Compare commits

...

5 Commits

Author SHA1 Message Date
Teknium
c4c3a57a3a fix: restore agent.close() cleanup and correct /restart category
- Add agent.close() call to _finalize_shutdown_agents() to prevent
  zombie processes (terminal sandboxes, browser daemons, httpx clients)
- Global cleanup (process_registry, environments, browsers) preserved
  in _stop_impl() during conflict resolution
- Move /restart CommandDef from 'Info' to 'Session' category to match
  /stop and /status
2026-04-10 18:55:28 -07:00
Kenny Xie
eb3f021e2a fix(gateway): address restart review feedback 2026-04-10 18:54:48 -07:00
aquaright1
b4cb803954 fix(gateway): self-request service restarts when invoked in-process 2026-04-10 18:54:48 -07:00
Kenny Xie
b5928530d3 fix(gateway): tolerate partial runner construction 2026-04-10 18:54:48 -07:00
Kenny Xie
ef120b5422 fix(gateway): drain in-flight work before restart 2026-04-10 18:54:48 -07:00
14 changed files with 1089 additions and 161 deletions

View File

@@ -480,6 +480,12 @@ agent:
# Fires once per run when inactivity reaches this threshold (seconds). # Fires once per run when inactivity reaches this threshold (seconds).
# Set to 0 to disable the warning. # Set to 0 to disable the warning.
# gateway_timeout_warning: 900 # gateway_timeout_warning: 900
# Graceful drain timeout for gateway stop/restart (seconds).
# The gateway stops accepting new work, waits for in-flight agents to
# finish, then interrupts anything still running after this timeout.
# 0 = no drain, interrupt immediately.
# restart_drain_timeout: 60
# Enable verbose logging # Enable verbose logging
verbose: false verbose: false

View File

@@ -673,6 +673,32 @@ class SendResult:
retryable: bool = False # True for transient connection errors — base will retry automatically retryable: bool = False # True for transient connection errors — base will retry automatically
def merge_pending_message_event(
pending_messages: Dict[str, MessageEvent],
session_key: str,
event: MessageEvent,
) -> None:
"""Store or merge a pending event for a session.
Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
events. Merge those into the existing queued event so the next turn sees
the whole burst, while non-photo follow-ups still replace the pending
event normally.
"""
existing = pending_messages.get(session_key)
if (
existing
and getattr(existing, "message_type", None) == MessageType.PHOTO
and event.message_type == MessageType.PHOTO
):
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
return
pending_messages[session_key] = event
# Error substrings that indicate a transient *connection* failure worth retrying. # Error substrings that indicate a transient *connection* failure worth retrying.
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally # "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message) # excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
@@ -727,6 +753,7 @@ class BasePlatformAdapter(ABC):
# working on a task after --replace or manual restarts. # working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set() self._background_tasks: set[asyncio.Task] = set()
self._expected_cancelled_tasks: set[asyncio.Task] = set() self._expected_cancelled_tasks: set[asyncio.Task] = set()
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
# Chats where auto-TTS on voice input is disabled (set by /voice off) # Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set() self._auto_tts_disabled_chats: set = set()
# Chats where typing indicator is paused (e.g. during approval waits). # Chats where typing indicator is paused (e.g. during approval waits).
@@ -815,6 +842,10 @@ class BasePlatformAdapter(ABC):
an optional response string. an optional response string.
""" """
self._message_handler = handler self._message_handler = handler
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
"""Set an optional handler for messages arriving during active sessions."""
self._busy_session_handler = handler
def set_session_store(self, session_store: Any) -> None: def set_session_store(self, session_store: Any) -> None:
""" """
@@ -1396,7 +1427,7 @@ class BasePlatformAdapter(ABC):
# session lifecycle and its cleanup races with the running task # session lifecycle and its cleanup races with the running task
# (see PR #4926). # (see PR #4926).
cmd = event.get_command() cmd = event.get_command()
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"): if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart"):
logger.debug( logger.debug(
"[%s] Command '/%s' bypassing active-session guard for %s", "[%s] Command '/%s' bypassing active-session guard for %s",
self.name, cmd, session_key, self.name, cmd, session_key,
@@ -1415,19 +1446,19 @@ class BasePlatformAdapter(ABC):
logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True) logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
return return
if self._busy_session_handler is not None:
try:
if await self._busy_session_handler(event, session_key):
return
except Exception as e:
logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True)
# Special case: photo bursts/albums frequently arrive as multiple near- # Special case: photo bursts/albums frequently arrive as multiple near-
# simultaneous messages. Queue them without interrupting the active run, # simultaneous messages. Queue them without interrupting the active run,
# then process them immediately after the current task finishes. # then process them immediately after the current task finishes.
if event.message_type == MessageType.PHOTO: if event.message_type == MessageType.PHOTO:
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key) logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
existing = self._pending_messages.get(session_key) merge_pending_message_event(self._pending_messages, session_key, event)
if existing and existing.message_type == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = self._merge_caption(existing.text, event.text)
else:
self._pending_messages[session_key] = event
return # Don't interrupt now - will run after current task completes return # Don't interrupt now - will run after current task completes
# Default behavior for non-photo follow-ups: interrupt the running agent # Default behavior for non-photo follow-ups: interrupt the running agent

20
gateway/restart.py Normal file
View File

@@ -0,0 +1,20 @@
"""Shared gateway restart constants and parsing helpers."""
from hermes_cli.config import DEFAULT_CONFIG
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
# the gateway after a graceful drain/reload path completes.
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
)
def parse_restart_drain_timeout(raw: object) -> float:
"""Parse a configured drain timeout, falling back to the shared default."""
try:
value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
except (TypeError, ValueError):
return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
return max(0.0, value)

View File

@@ -186,6 +186,12 @@ if _config_path.exists():
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"]) os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ: if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"]) os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
if "restart_drain_timeout" in _agent_cfg and "HERMES_RESTART_DRAIN_TIMEOUT" not in os.environ:
os.environ["HERMES_RESTART_DRAIN_TIMEOUT"] = str(_agent_cfg["restart_drain_timeout"])
_display_cfg = _cfg.get("display", {})
if _display_cfg and isinstance(_display_cfg, dict):
if "busy_input_mode" in _display_cfg and "HERMES_GATEWAY_BUSY_INPUT_MODE" not in os.environ:
os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"])
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var. # Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
# HERMES_TIMEZONE from .env takes precedence (already in os.environ). # HERMES_TIMEZONE from .env takes precedence (already in os.environ).
_tz_cfg = _cfg.get("timezone", "") _tz_cfg = _cfg.get("timezone", "")
@@ -235,7 +241,17 @@ from gateway.session import (
build_session_key, build_session_key,
) )
from gateway.delivery import DeliveryRouter from gateway.delivery import DeliveryRouter
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
merge_pending_message_event,
)
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
def _normalize_whatsapp_identifier(value: str) -> str: def _normalize_whatsapp_identifier(value: str) -> str:
@@ -471,6 +487,16 @@ class GatewayRunner:
# Class-level defaults so partial construction in tests doesn't # Class-level defaults so partial construction in tests doesn't
# blow up on attribute access. # blow up on attribute access.
_running_agents_ts: Dict[str, float] = {} _running_agents_ts: Dict[str, float] = {}
_busy_input_mode: str = "interrupt"
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
_exit_code: Optional[int] = None
_draining: bool = False
_restart_requested: bool = False
_restart_task_started: bool = False
_restart_detached: bool = False
_restart_via_service: bool = False
_stop_task: Optional[asyncio.Task] = None
_session_model_overrides: Dict[str, Dict[str, str]] = {}
def __init__(self, config: Optional[GatewayConfig] = None): def __init__(self, config: Optional[GatewayConfig] = None):
self.config = config or load_gateway_config() self.config = config or load_gateway_config()
@@ -483,6 +509,8 @@ class GatewayRunner:
self._reasoning_config = self._load_reasoning_config() self._reasoning_config = self._load_reasoning_config()
self._service_tier = self._load_service_tier() self._service_tier = self._load_service_tier()
self._show_reasoning = self._load_show_reasoning() self._show_reasoning = self._load_show_reasoning()
self._busy_input_mode = self._load_busy_input_mode()
self._restart_drain_timeout = self._load_restart_drain_timeout()
self._provider_routing = self._load_provider_routing() self._provider_routing = self._load_provider_routing()
self._fallback_model = self._load_fallback_model() self._fallback_model = self._load_fallback_model()
self._smart_model_routing = self._load_smart_model_routing() self._smart_model_routing = self._load_smart_model_routing()
@@ -499,6 +527,13 @@ class GatewayRunner:
self._exit_cleanly = False self._exit_cleanly = False
self._exit_with_failure = False self._exit_with_failure = False
self._exit_reason: Optional[str] = None self._exit_reason: Optional[str] = None
self._exit_code: Optional[int] = None
self._draining = False
self._restart_requested = False
self._restart_task_started = False
self._restart_detached = False
self._restart_via_service = False
self._stop_task: Optional[asyncio.Task] = None
# Track running agents per session for interrupt support # Track running agents per session for interrupt support
# Key: session_key, Value: AIAgent instance # Key: session_key, Value: AIAgent instance
@@ -759,6 +794,10 @@ class GatewayRunner:
def exit_reason(self) -> Optional[str]: def exit_reason(self) -> Optional[str]:
return self._exit_reason return self._exit_reason
@property
def exit_code(self) -> Optional[int]:
return self._exit_code
def _session_key_for_source(self, source: SessionSource) -> str: def _session_key_for_source(self, source: SessionSource) -> str:
"""Resolve the current session key for a source, honoring gateway config when available.""" """Resolve the current session key for a source, honoring gateway config when available."""
if hasattr(self, "session_store") and self.session_store is not None: if hasattr(self, "session_store") and self.session_store is not None:
@@ -868,6 +907,30 @@ class GatewayRunner:
self._exit_cleanly = True self._exit_cleanly = True
self._exit_reason = reason self._exit_reason = reason
self._shutdown_event.set() self._shutdown_event.set()
def _running_agent_count(self) -> int:
return len(self._running_agents)
def _status_action_label(self) -> str:
return "restart" if self._restart_requested else "shutdown"
def _status_action_gerund(self) -> str:
return "restarting" if self._restart_requested else "shutting down"
def _queue_during_drain_enabled(self) -> bool:
return self._restart_requested and self._busy_input_mode == "queue"
def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
try:
from gateway.status import write_runtime_status
write_runtime_status(
gateway_state=gateway_state,
exit_reason=exit_reason,
restart_requested=self._restart_requested,
active_agents=self._running_agent_count(),
)
except Exception:
pass
@staticmethod @staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]: def _load_prefill_messages() -> List[Dict[str, Any]]:
@@ -994,6 +1057,48 @@ class GatewayRunner:
pass pass
return False return False
@staticmethod
def _load_busy_input_mode() -> str:
"""Load gateway drain-time busy-input behavior from config/env."""
mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
if not mode:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower()
except Exception:
pass
return "queue" if mode == "queue" else "interrupt"
@staticmethod
def _load_restart_drain_timeout() -> float:
"""Load graceful gateway restart/stop drain timeout in seconds."""
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
if not raw:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
except Exception:
pass
value = parse_restart_drain_timeout(raw)
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
try:
float(raw)
except (TypeError, ValueError):
logger.warning(
"Invalid restart_drain_timeout '%s', using default %.0fs",
raw,
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
)
return value
@staticmethod @staticmethod
def _load_background_notifications_mode() -> str: def _load_background_notifications_mode() -> str:
"""Load background process notification mode from config or env var. """Load background process notification mode from config or env var.
@@ -1078,6 +1183,155 @@ class GatewayRunner:
pass pass
return {} return {}
def _snapshot_running_agents(self) -> Dict[str, Any]:
return {
session_key: agent
for session_key, agent in self._running_agents.items()
if agent is not _AGENT_PENDING_SENTINEL
}
def _queue_or_replace_pending_event(self, session_key: str, event: MessageEvent) -> None:
adapter = self.adapters.get(event.source.platform)
if not adapter:
return
merge_pending_message_event(adapter._pending_messages, session_key, event)
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
if not self._draining:
return False
adapter = self.adapters.get(event.source.platform)
if not adapter:
return True
thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
if self._queue_during_drain_enabled():
self._queue_or_replace_pending_event(session_key, event)
message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
else:
message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
await adapter._send_with_retry(
chat_id=event.source.chat_id,
content=message,
reply_to=event.message_id,
metadata=thread_meta,
)
return True
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
snapshot = self._snapshot_running_agents()
last_active_count = self._running_agent_count()
last_status_at = 0.0
def _maybe_update_status(force: bool = False) -> None:
nonlocal last_active_count, last_status_at
now = asyncio.get_running_loop().time()
active_count = self._running_agent_count()
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
self._update_runtime_status("draining")
last_active_count = active_count
last_status_at = now
if not self._running_agents:
_maybe_update_status(force=True)
return snapshot, False
_maybe_update_status(force=True)
if timeout <= 0:
return snapshot, True
deadline = asyncio.get_running_loop().time() + timeout
while self._running_agents and asyncio.get_running_loop().time() < deadline:
_maybe_update_status()
await asyncio.sleep(0.1)
timed_out = bool(self._running_agents)
_maybe_update_status(force=True)
return snapshot, timed_out
def _interrupt_running_agents(self, reason: str) -> None:
for session_key, agent in list(self._running_agents.items()):
if agent is _AGENT_PENDING_SENTINEL:
continue
try:
agent.interrupt(reason)
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
except Exception as e:
logger.debug("Failed interrupting agent during shutdown: %s", e)
def _finalize_shutdown_agents(self, active_agents: Dict[str, Any]) -> None:
for agent in active_agents.values():
try:
from hermes_cli.plugins import invoke_hook as _invoke_hook
_invoke_hook(
"on_session_finalize",
session_id=getattr(agent, "session_id", None),
platform="gateway",
)
except Exception:
pass
try:
if hasattr(agent, "shutdown_memory_provider"):
agent.shutdown_memory_provider()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, 'close'):
agent.close()
except Exception:
pass
async def _launch_detached_restart_command(self) -> None:
import shutil
import subprocess
hermes_cmd = _resolve_hermes_bin()
if not hermes_cmd:
logger.error("Could not locate hermes binary for detached /restart")
return
current_pid = os.getpid()
cmd = " ".join(shlex.quote(part) for part in hermes_cmd)
shell_cmd = (
f"while kill -0 {current_pid} 2>/dev/null; do sleep 0.2; done; "
f"{cmd} gateway restart"
)
setsid_bin = shutil.which("setsid")
if setsid_bin:
subprocess.Popen(
[setsid_bin, "bash", "-lc", shell_cmd],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
else:
subprocess.Popen(
["bash", "-lc", shell_cmd],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
def request_restart(self, *, detached: bool = False, via_service: bool = False) -> bool:
if self._restart_task_started:
return False
self._restart_requested = True
self._restart_detached = detached
self._restart_via_service = via_service
self._restart_task_started = True
async def _run_restart() -> None:
await asyncio.sleep(0.05)
await self.stop(restart=True, detached_restart=detached, service_restart=via_service)
task = asyncio.create_task(_run_restart())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return True
async def start(self) -> bool: async def start(self) -> bool:
""" """
Start the gateway and all configured platform adapters. Start the gateway and all configured platform adapters.
@@ -1165,6 +1419,7 @@ class GatewayRunner:
adapter.set_message_handler(self._handle_message) adapter.set_message_handler(self._handle_message)
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store) adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
# Try to connect # Try to connect
logger.info("Connecting to %s...", platform.value) logger.info("Connecting to %s...", platform.value)
@@ -1240,11 +1495,7 @@ class GatewayRunner:
self.delivery_router.adapters = self.adapters self.delivery_router.adapters = self.adapters
self._running = True self._running = True
try: self._update_runtime_status("running")
from gateway.status import write_runtime_status
write_runtime_status(gateway_state="running", exit_reason=None)
except Exception:
pass
# Emit gateway:startup hook # Emit gateway:startup hook
hook_count = len(self.hooks.loaded_hooks) hook_count = len(self.hooks.loaded_hooks)
@@ -1479,6 +1730,7 @@ class GatewayRunner:
adapter.set_message_handler(self._handle_message) adapter.set_message_handler(self._handle_message)
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
adapter.set_session_store(self.session_store) adapter.set_session_store(self.session_store)
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
success = await adapter.connect() success = await adapter.connect()
if success: if success:
@@ -1525,90 +1777,108 @@ class GatewayRunner:
return return
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None: async def stop(
self,
*,
restart: bool = False,
detached_restart: bool = False,
service_restart: bool = False,
) -> None:
"""Stop the gateway and disconnect all adapters.""" """Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...") if restart:
self._running = False self._restart_requested = True
self._restart_detached = detached_restart
self._restart_via_service = service_restart
if self._stop_task is not None:
await self._stop_task
return
for session_key, agent in list(self._running_agents.items()): async def _stop_impl() -> None:
if agent is _AGENT_PENDING_SENTINEL: logger.info(
continue "Stopping gateway%s...",
" for restart" if self._restart_requested else "",
)
self._running = False
self._draining = True
timeout = self._restart_drain_timeout
active_agents, timed_out = await self._drain_active_agents(timeout)
if timed_out:
logger.warning(
"Gateway drain timed out after %.1fs with %d active agent(s); interrupting remaining work.",
timeout,
self._running_agent_count(),
)
self._interrupt_running_agents(
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
)
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
self._update_runtime_status("draining")
await asyncio.sleep(0.1)
if self._restart_requested and self._restart_detached:
try:
await self._launch_detached_restart_command()
except Exception as e:
logger.error("Failed to launch detached gateway restart: %s", e)
self._finalize_shutdown_agents(active_agents)
for platform, adapter in list(self.adapters.items()):
try:
await adapter.cancel_background_tasks()
except Exception as e:
logger.debug("%s background-task cancel error: %s", platform.value, e)
try:
await adapter.disconnect()
logger.info("%s disconnected", platform.value)
except Exception as e:
logger.error("%s disconnect error: %s", platform.value, e)
for _task in list(self._background_tasks):
if _task is self._stop_task:
continue
_task.cancel()
self._background_tasks.clear()
self.adapters.clear()
self._running_agents.clear()
self._pending_messages.clear()
self._pending_approvals.clear()
self._shutdown_event.set()
# Global cleanup: kill any remaining tool subprocesses not tied
# to a specific agent (catch-all for zombie prevention).
try: try:
agent.interrupt("Gateway shutting down") from tools.process_registry import process_registry
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20]) process_registry.kill_all()
except Exception as e:
logger.debug("Failed interrupting agent during shutdown: %s", e)
# Fire plugin on_session_finalize hook before memory shutdown
try:
from hermes_cli.plugins import invoke_hook as _invoke_hook
_invoke_hook("on_session_finalize",
session_id=getattr(agent, 'session_id', None),
platform="gateway")
except Exception: except Exception:
pass pass
# Shut down memory provider at actual session boundary
try: try:
if hasattr(agent, 'shutdown_memory_provider'): from tools.terminal_tool import cleanup_all_environments
agent.shutdown_memory_provider() cleanup_all_environments()
except Exception: except Exception:
pass pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try: try:
if hasattr(agent, 'close'): from tools.browser_tool import cleanup_all_browsers
agent.close() cleanup_all_browsers()
except Exception: except Exception:
pass pass
for platform, adapter in list(self.adapters.items()): from gateway.status import remove_pid_file
try: remove_pid_file()
await adapter.cancel_background_tasks()
except Exception as e:
logger.debug("%s background-task cancel error: %s", platform.value, e)
try:
await adapter.disconnect()
logger.info("%s disconnected", platform.value)
except Exception as e:
logger.error("%s disconnect error: %s", platform.value, e)
# Cancel any pending background tasks if self._restart_requested and self._restart_via_service:
for _task in list(self._background_tasks): self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
_task.cancel() self._exit_reason = self._exit_reason or "Gateway restart requested"
self._background_tasks.clear()
self.adapters.clear() self._draining = False
self._running_agents.clear() self._update_runtime_status("stopped", self._exit_reason)
self._pending_messages.clear() logger.info("Gateway stopped")
self._pending_approvals.clear()
self._shutdown_event.set()
# Global cleanup: kill any remaining tool subprocesses not tied self._stop_task = asyncio.create_task(_stop_impl())
# to a specific agent (catch-all for zombie prevention). await self._stop_task
try:
from tools.process_registry import process_registry
process_registry.kill_all()
except Exception:
pass
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
try:
from tools.browser_tool import cleanup_all_browsers
cleanup_all_browsers()
except Exception:
pass
from gateway.status import remove_pid_file, write_runtime_status
remove_pid_file()
try:
write_runtime_status(gateway_state="stopped", exit_reason=self._exit_reason)
except Exception:
pass
logger.info("Gateway stopped")
async def wait_for_shutdown(self) -> None: async def wait_for_shutdown(self) -> None:
"""Wait for shutdown signal.""" """Wait for shutdown signal."""
@@ -2014,6 +2284,9 @@ class GatewayRunner:
_evt_cmd = event.get_command() _evt_cmd = event.get_command()
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None _cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
if _cmd_def_inner and _cmd_def_inner.name == "restart":
return await self._handle_restart_command(event)
# /stop must hard-kill the session when an agent is running. # /stop must hard-kill the session when an agent is running.
# A soft interrupt (agent.interrupt()) doesn't help when the agent # A soft interrupt (agent.interrupt()) doesn't help when the agent
# is truly hung — the executor thread is blocked and never checks # is truly hung — the executor thread is blocked and never checks
@@ -2094,18 +2367,7 @@ class GatewayRunner:
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20]) logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
adapter = self.adapters.get(source.platform) adapter = self.adapters.get(source.platform)
if adapter: if adapter:
# Reuse adapter queue semantics so photo bursts merge cleanly. merge_pending_message_event(adapter._pending_messages, _quick_key, event)
if _quick_key in adapter._pending_messages:
existing = adapter._pending_messages[_quick_key]
if getattr(existing, "message_type", None) == MessageType.PHOTO:
existing.media_urls.extend(event.media_urls)
existing.media_types.extend(event.media_types)
if event.text:
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
else:
adapter._pending_messages[_quick_key] = event
else:
adapter._pending_messages[_quick_key] = event
return None return None
running_agent = self._running_agents.get(_quick_key) running_agent = self._running_agents.get(_quick_key)
@@ -2123,6 +2385,14 @@ class GatewayRunner:
if adapter: if adapter:
adapter._pending_messages[_quick_key] = event adapter._pending_messages[_quick_key] = event
return None return None
if self._draining:
if self._queue_during_drain_enabled():
self._queue_or_replace_pending_event(_quick_key, event)
return (
f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
if self._queue_during_drain_enabled()
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
)
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20]) logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
running_agent.interrupt(event.text) running_agent.interrupt(event.text)
if _quick_key in self._pending_messages: if _quick_key in self._pending_messages:
@@ -2164,6 +2434,9 @@ class GatewayRunner:
if canonical == "status": if canonical == "status":
return await self._handle_status_command(event) return await self._handle_status_command(event)
if canonical == "restart":
return await self._handle_restart_command(event)
if canonical == "stop": if canonical == "stop":
return await self._handle_stop_command(event) return await self._handle_stop_command(event)
@@ -2262,6 +2535,9 @@ class GatewayRunner:
if canonical == "voice": if canonical == "voice":
return await self._handle_voice_command(event) return await self._handle_voice_command(event)
if self._draining:
return f"⏳ Gateway is {self._status_action_gerund()} and is not accepting new work right now."
# User-defined quick commands (bypass agent loop, no LLM call) # User-defined quick commands (bypass agent loop, no LLM call)
if command: if command:
if isinstance(self.config, dict): if isinstance(self.config, dict):
@@ -3556,7 +3832,21 @@ class GatewayRunner:
return "⚡ Force-stopped. The session is unlocked — you can send a new message." return "⚡ Force-stopped. The session is unlocked — you can send a new message."
else: else:
return "No active task to stop." return "No active task to stop."
async def _handle_restart_command(self, event: MessageEvent) -> str:
"""Handle /restart command - drain active work, then restart the gateway."""
if self._restart_requested or self._draining:
count = self._running_agent_count()
if count:
return f"⏳ Draining {count} active agent(s) before restart..."
return "⏳ Gateway restart already in progress..."
active_agents = self._running_agent_count()
self.request_restart(detached=True, via_service=False)
if active_agents:
return f"⏳ Draining {active_agents} active agent(s) before restart..."
return "♻ Restarting gateway..."
async def _handle_help_command(self, event: MessageEvent) -> str: async def _handle_help_command(self, event: MessageEvent) -> str:
"""Handle /help command - list available commands.""" """Handle /help command - list available commands."""
from hermes_cli.commands import gateway_help_lines from hermes_cli.commands import gateway_help_lines
@@ -3679,7 +3969,7 @@ class GatewayRunner:
# Check for session override # Check for session override
source = event.source source = event.source
session_key = self._session_key_for_source(source) session_key = self._session_key_for_source(source)
override = getattr(self, "_session_model_overrides", {}).get(session_key, {}) override = self._session_model_overrides.get(session_key, {})
if override: if override:
current_model = override.get("model", current_model) current_model = override.get("model", current_model)
current_provider = override.get("provider", current_provider) current_provider = override.get("provider", current_provider)
@@ -3761,8 +4051,6 @@ class GatewayRunner:
f"via {result.provider_label or result.target_provider}. " f"via {result.provider_label or result.target_provider}. "
f"Adjust your self-identification accordingly.]" f"Adjust your self-identification accordingly.]"
) )
if not hasattr(_self, "_session_model_overrides"):
_self._session_model_overrides = {}
_self._session_model_overrides[_session_key] = { _self._session_model_overrides[_session_key] = {
"model": result.new_model, "model": result.new_model,
"provider": result.target_provider, "provider": result.target_provider,
@@ -3876,8 +4164,6 @@ class GatewayRunner:
) )
# Store session override so next agent creation uses the new model # Store session override so next agent creation uses the new model
if not hasattr(self, "_session_model_overrides"):
self._session_model_overrides = {}
self._session_model_overrides[session_key] = { self._session_model_overrides[session_key] = {
"model": result.new_model, "model": result.new_model,
"provider": result.target_provider, "provider": result.target_provider,
@@ -7363,6 +7649,8 @@ class GatewayRunner:
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
if session_key: if session_key:
self._running_agents[session_key] = agent_holder[0] self._running_agents[session_key] = agent_holder[0]
if self._draining:
self._update_runtime_status("draining")
tracking_task = asyncio.create_task(track_agent()) tracking_task = asyncio.create_task(track_agent())
@@ -7608,6 +7896,14 @@ class GatewayRunner:
except Exception: except Exception:
pass pass
if self._draining and pending:
logger.info(
"Discarding pending follow-up for session %s during gateway %s",
session_key[:20] if session_key else "?",
self._status_action_label(),
)
pending = None
if pending: if pending:
logger.debug("Processing pending message: '%s...'", pending[:40]) logger.debug("Processing pending message: '%s...'", pending[:40])
@@ -7684,6 +7980,8 @@ class GatewayRunner:
del self._running_agents[session_key] del self._running_agents[session_key]
if session_key: if session_key:
self._running_agents_ts.pop(session_key, None) self._running_agents_ts.pop(session_key, None)
if self._draining:
self._update_runtime_status("draining")
# Wait for cancelled tasks # Wait for cancelled tasks
for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]: for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]:
@@ -7881,13 +8179,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
runner = GatewayRunner(config) runner = GatewayRunner(config)
# Set up signal handlers # Set up signal handlers
def signal_handler(): def shutdown_signal_handler():
asyncio.create_task(runner.stop()) asyncio.create_task(runner.stop())
def restart_signal_handler():
runner.request_restart(detached=False, via_service=True)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for sig in (signal.SIGINT, signal.SIGTERM): for sig in (signal.SIGINT, signal.SIGTERM):
try: try:
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, shutdown_signal_handler)
except NotImplementedError:
pass
if hasattr(signal, "SIGUSR1"):
try:
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
except NotImplementedError: except NotImplementedError:
pass pass
@@ -7937,6 +8243,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
except Exception: except Exception:
pass pass
if runner.exit_code is not None:
raise SystemExit(runner.exit_code)
return True return True

View File

@@ -158,6 +158,8 @@ def _build_runtime_status_record() -> dict[str, Any]:
payload.update({ payload.update({
"gateway_state": "starting", "gateway_state": "starting",
"exit_reason": None, "exit_reason": None,
"restart_requested": False,
"active_agents": 0,
"platforms": {}, "platforms": {},
"updated_at": _utc_now_iso(), "updated_at": _utc_now_iso(),
}) })
@@ -218,6 +220,8 @@ def write_runtime_status(
*, *,
gateway_state: Optional[str] = None, gateway_state: Optional[str] = None,
exit_reason: Optional[str] = None, exit_reason: Optional[str] = None,
restart_requested: Optional[bool] = None,
active_agents: Optional[int] = None,
platform: Optional[str] = None, platform: Optional[str] = None,
platform_state: Optional[str] = None, platform_state: Optional[str] = None,
error_code: Optional[str] = None, error_code: Optional[str] = None,
@@ -236,6 +240,10 @@ def write_runtime_status(
payload["gateway_state"] = gateway_state payload["gateway_state"] = gateway_state
if exit_reason is not None: if exit_reason is not None:
payload["exit_reason"] = exit_reason payload["exit_reason"] = exit_reason
if restart_requested is not None:
payload["restart_requested"] = bool(restart_requested)
if active_agents is not None:
payload["active_agents"] = max(0, int(active_agents))
if platform is not None: if platform is not None:
platform_payload = payload["platforms"].get(platform, {}) platform_payload = payload["platforms"].get(platform, {})

View File

@@ -140,6 +140,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
CommandDef("commands", "Browse all commands and skills (paginated)", "Info", CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
gateway_only=True, args_hint="[page]"), gateway_only=True, args_hint="[page]"),
CommandDef("help", "Show available commands", "Info"), CommandDef("help", "Show available commands", "Info"),
CommandDef("restart", "Gracefully restart the gateway after draining active runs", "Session",
gateway_only=True),
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"), CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
CommandDef("insights", "Show usage insights and analytics", "Info", CommandDef("insights", "Show usage insights and analytics", "Info",
args_hint="[days]"), args_hint="[days]"),

View File

@@ -269,6 +269,11 @@ DEFAULT_CONFIG = {
# tools or receiving API responses. Only fires when the agent has # tools or receiving API responses. Only fires when the agent has
# been completely idle for this duration. 0 = unlimited. # been completely idle for this duration. 0 = unlimited.
"gateway_timeout": 1800, "gateway_timeout": 1800,
# Graceful drain timeout for gateway stop/restart (seconds).
# The gateway stops accepting new work, waits for running agents
# to finish, then interrupts any remaining runs after the timeout.
# 0 = no drain, interrupt immediately.
"restart_drain_timeout": 60,
"service_tier": "", "service_tier": "",
# Tool-use enforcement: injects system prompt guidance that tells the # Tool-use enforcement: injects system prompt guidance that tells the
# model to actually call tools instead of describing intended actions. # model to actually call tools instead of describing intended actions.

View File

@@ -15,7 +15,19 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve() PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from gateway.status import terminate_pid from gateway.status import terminate_pid
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
parse_restart_drain_timeout,
)
from hermes_cli.config import (
get_env_value,
get_hermes_home,
is_managed,
managed_error,
read_raw_config,
save_env_value,
)
# display_hermes_home is imported lazily at call sites to avoid ImportError # display_hermes_home is imported lazily at call sites to avoid ImportError
# when hermes_constants is cached from a pre-update version during `hermes update`. # when hermes_constants is cached from a pre-update version during `hermes update`.
from hermes_cli.setup import ( from hermes_cli.setup import (
@@ -92,6 +104,59 @@ def _get_service_pids() -> set:
return pids return pids
def _get_parent_pid(pid: int) -> int | None:
"""Return the parent PID for ``pid``, or ``None`` when unavailable."""
if pid <= 1:
return None
try:
result = subprocess.run(
["ps", "-o", "ppid=", "-p", str(pid)],
capture_output=True,
text=True,
timeout=5,
)
except (FileNotFoundError, subprocess.TimeoutExpired):
return None
if result.returncode != 0:
return None
raw = result.stdout.strip()
if not raw:
return None
try:
parent_pid = int(raw.splitlines()[-1].strip())
except ValueError:
return None
return parent_pid if parent_pid > 0 else None
def _is_pid_ancestor_of_current_process(target_pid: int) -> bool:
"""Return True when ``target_pid`` is this process or one of its ancestors."""
if target_pid <= 0:
return False
pid = os.getpid()
seen: set[int] = set()
while pid and pid not in seen:
if pid == target_pid:
return True
seen.add(pid)
pid = _get_parent_pid(pid) or 0
return False
def _request_gateway_self_restart(pid: int) -> bool:
"""Ask a running gateway ancestor to restart itself asynchronously."""
if not hasattr(signal, "SIGUSR1"):
return False
if not _is_pid_ancestor_of_current_process(pid):
return False
try:
os.kill(pid, signal.SIGUSR1)
except (ProcessLookupError, PermissionError, OSError):
return False
return True
def find_gateway_pids(exclude_pids: set | None = None) -> list: def find_gateway_pids(exclude_pids: set | None = None) -> list:
"""Find PIDs of running gateway processes. """Find PIDs of running gateway processes.
@@ -665,6 +730,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
path_entries.append(resolved_node_dir) path_entries.append(resolved_node_dir)
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"] common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
restart_timeout = max(60, int(_get_restart_drain_timeout() or 0))
if system: if system:
username, group_name, home_dir = _system_service_identity(run_as_user) username, group_name, home_dir = _system_service_identity(run_as_user)
@@ -703,9 +769,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}" Environment="HERMES_HOME={hermes_home}"
Restart=on-failure Restart=on-failure
RestartSec=30 RestartSec=30
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed KillMode=mixed
KillSignal=SIGTERM KillSignal=SIGTERM
TimeoutStopSec=60 ExecReload=/bin/kill -USR1 $MAINPID
TimeoutStopSec={restart_timeout}
StandardOutput=journal StandardOutput=journal
StandardError=journal StandardError=journal
@@ -733,9 +801,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
Environment="HERMES_HOME={hermes_home}" Environment="HERMES_HOME={hermes_home}"
Restart=on-failure Restart=on-failure
RestartSec=30 RestartSec=30
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
KillMode=mixed KillMode=mixed
KillSignal=SIGTERM KillSignal=SIGTERM
TimeoutStopSec=60 ExecReload=/bin/kill -USR1 $MAINPID
TimeoutStopSec={restart_timeout}
StandardOutput=journal StandardOutput=journal
StandardError=journal StandardError=journal
@@ -838,6 +908,20 @@ def _select_systemd_scope(system: bool = False) -> bool:
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists() return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
def _get_restart_drain_timeout() -> float:
"""Return the configured gateway restart drain timeout in seconds."""
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
if not raw:
cfg = read_raw_config()
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
raw = str(
agent_cfg.get(
"restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
)
return parse_restart_drain_timeout(raw)
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None): def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
if system: if system:
_require_root_for_system_service("install") _require_root_for_system_service("install")
@@ -923,7 +1007,13 @@ def systemd_restart(system: bool = False):
if system: if system:
_require_root_for_system_service("restart") _require_root_for_system_service("restart")
refresh_systemd_unit_if_needed(system=system) refresh_systemd_unit_if_needed(system=system)
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True, timeout=90) from gateway.status import get_running_pid
pid = get_running_pid()
if pid is not None and _request_gateway_self_restart(pid):
print(f"{_service_scope_label(system).capitalize()} service restart requested")
return
subprocess.run(_systemctl_cmd(system) + ["reload-or-restart", get_service_name()], check=True, timeout=90)
print(f"{_service_scope_label(system).capitalize()} service restarted") print(f"{_service_scope_label(system).capitalize()} service restarted")
@@ -1211,7 +1301,7 @@ def launchd_stop():
_wait_for_gateway_exit(timeout=10.0, force_after=5.0) _wait_for_gateway_exit(timeout=10.0, force_after=5.0)
print("✓ Service stopped") print("✓ Service stopped")
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0): def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float | None = 5.0) -> bool:
"""Wait for the gateway process (by saved PID) to exit. """Wait for the gateway process (by saved PID) to exit.
Uses the PID from the gateway.pid file — not launchd labels — so this Uses the PID from the gateway.pid file — not launchd labels — so this
@@ -1226,21 +1316,21 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
from gateway.status import get_running_pid from gateway.status import get_running_pid
deadline = time.monotonic() + timeout deadline = time.monotonic() + timeout
force_deadline = time.monotonic() + force_after force_deadline = (time.monotonic() + force_after) if force_after is not None else None
force_sent = False force_sent = False
while time.monotonic() < deadline: while time.monotonic() < deadline:
pid = get_running_pid() pid = get_running_pid()
if pid is None: if pid is None:
return # Process exited cleanly. return True # Process exited cleanly.
if not force_sent and time.monotonic() >= force_deadline: if force_after is not None and not force_sent and time.monotonic() >= force_deadline:
# Grace period expired — force-kill the specific PID. # Grace period expired — force-kill the specific PID.
try: try:
terminate_pid(pid, force=True) terminate_pid(pid, force=True)
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL") print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
except (ProcessLookupError, PermissionError, OSError): except (ProcessLookupError, PermissionError, OSError):
return # Already gone or we can't touch it. return True # Already gone or we can't touch it.
force_sent = True force_sent = True
time.sleep(0.3) time.sleep(0.3)
@@ -1249,15 +1339,30 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
remaining_pid = get_running_pid() remaining_pid = get_running_pid()
if remaining_pid is not None: if remaining_pid is not None:
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail") print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
return False
return True
def launchd_restart(): def launchd_restart():
label = get_launchd_label() label = get_launchd_label()
target = f"{_launchd_domain()}/{label}" target = f"{_launchd_domain()}/{label}"
# Use kickstart -k so launchd performs an atomic kill+restart. drain_timeout = _get_restart_drain_timeout()
# A two-step stop/start from inside the gateway's own process tree from gateway.status import get_running_pid
# would kill the shell before the start command is reached.
try: try:
pid = get_running_pid()
if pid is not None and _request_gateway_self_restart(pid):
print("✓ Service restart requested")
return
if pid is not None:
try:
terminate_pid(pid, force=False)
except (ProcessLookupError, PermissionError, OSError):
pid = None
if pid is not None:
exited = _wait_for_gateway_exit(timeout=drain_timeout, force_after=None)
if not exited:
print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart")
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90) subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
print("✓ Service restarted") print("✓ Service restarted")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -1728,6 +1833,8 @@ def _runtime_health_lines() -> list[str]:
lines: list[str] = [] lines: list[str] = []
gateway_state = state.get("gateway_state") gateway_state = state.get("gateway_state")
exit_reason = state.get("exit_reason") exit_reason = state.get("exit_reason")
active_agents = state.get("active_agents")
restart_requested = state.get("restart_requested")
platforms = state.get("platforms", {}) or {} platforms = state.get("platforms", {}) or {}
for platform, pdata in platforms.items(): for platform, pdata in platforms.items():
@@ -1737,6 +1844,10 @@ def _runtime_health_lines() -> list[str]:
if gateway_state == "startup_failed" and exit_reason: if gateway_state == "startup_failed" and exit_reason:
lines.append(f"⚠ Last startup issue: {exit_reason}") lines.append(f"⚠ Last startup issue: {exit_reason}")
elif gateway_state == "draining":
action = "restart" if restart_requested else "shutdown"
count = int(active_agents or 0)
lines.append(f"⏳ Gateway draining for {action} ({count} active agent(s))")
elif gateway_state == "stopped" and exit_reason: elif gateway_state == "stopped" and exit_reason:
lines.append(f"⚠ Last shutdown reason: {exit_reason}") lines.append(f"⚠ Last shutdown reason: {exit_reason}")

View File

@@ -0,0 +1,110 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.run import GatewayRunner
from gateway.session import SessionSource
class RestartTestAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
self.sent: list[str] = []
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
self.sent.append(content)
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
)
def make_restart_runner(
adapter: BasePlatformAdapter | None = None,
) -> tuple[GatewayRunner, BasePlatformAdapter]:
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
)
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._exit_code = None
runner._running_agents = {}
runner._running_agents_ts = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._pending_model_notes = {}
runner._background_tasks = set()
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
runner._stop_task = None
runner._busy_input_mode = "interrupt"
runner._update_prompt_pending = {}
runner._voice_mode = {}
runner._session_model_overrides = {}
runner._shutdown_all_gateway_honcho = lambda: None
runner._update_runtime_status = MagicMock()
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
runner, GatewayRunner
)
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
runner, GatewayRunner
)
runner._handle_active_session_busy_message = (
GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
)
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(
runner, GatewayRunner
)
runner._status_action_label = GatewayRunner._status_action_label.__get__(
runner, GatewayRunner
)
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(
runner, GatewayRunner
)
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(
runner, GatewayRunner
)
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(
runner, GatewayRunner
)
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
runner, GatewayRunner
)
runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner)
runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock()
runner.pairing_store = MagicMock()
runner.session_store = MagicMock()
runner.delivery_router = MagicMock()
platform_adapter = adapter or RestartTestAdapter()
platform_adapter.set_message_handler(AsyncMock(return_value=None))
platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
runner.adapters = {Platform.TELEGRAM: platform_adapter}
return runner, platform_adapter

View File

@@ -3,43 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import MessageEvent
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
from gateway.run import GatewayRunner from gateway.session import build_session_key
from gateway.session import SessionSource, build_session_key from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
class StubAdapter(BasePlatformAdapter):
def __init__(self):
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
async def connect(self):
return True
async def disconnect(self):
return None
async def send(self, chat_id, content, reply_to=None, metadata=None):
return SendResult(success=True, message_id="1")
async def send_typing(self, chat_id, metadata=None):
return None
async def get_chat_info(self, chat_id):
return {"id": chat_id}
def _source(chat_id="123456", chat_type="dm"):
return SessionSource(
platform=Platform.TELEGRAM,
chat_id=chat_id,
chat_type=chat_type,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_background_tasks_cancels_inflight_message_processing(): async def test_cancel_background_tasks_cancels_inflight_message_processing():
adapter = StubAdapter() _runner, adapter = make_restart_runner()
release = asyncio.Event() release = asyncio.Event()
async def block_forever(_event): async def block_forever(_event):
@@ -47,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
return None return None
adapter.set_message_handler(block_forever) adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1") event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event) await adapter.handle_message(event)
await asyncio.sleep(0) await asyncio.sleep(0)
@@ -65,17 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(): async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
runner = object.__new__(GatewayRunner) runner, adapter = make_restart_runner()
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
runner._running = True
runner._shutdown_event = asyncio.Event()
runner._exit_reason = None
runner._pending_messages = {"session": "pending text"} runner._pending_messages = {"session": "pending text"}
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}} runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
runner._background_tasks = set() runner._restart_drain_timeout = 0.0
runner._shutdown_all_gateway_honcho = lambda: None
adapter = StubAdapter()
release = asyncio.Event() release = asyncio.Event()
async def block_forever(_event): async def block_forever(_event):
@@ -83,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
return None return None
adapter.set_message_handler(block_forever) adapter.set_message_handler(block_forever)
event = MessageEvent(text="work", source=_source(), message_id="1") event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
await adapter.handle_message(event) await adapter.handle_message(event)
await asyncio.sleep(0) await asyncio.sleep(0)
@@ -93,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
running_agent = MagicMock() running_agent = MagicMock()
runner._running_agents = {session_key: running_agent} runner._running_agents = {session_key: running_agent}
runner.adapters = {Platform.TELEGRAM: adapter}
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"): with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop() await runner.stop()
@@ -105,3 +70,78 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
assert runner._pending_messages == {} assert runner._pending_messages == {}
assert runner._pending_approvals == {} assert runner._pending_approvals == {}
assert runner._shutdown_event.is_set() is True assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_drains_running_agents_before_disconnect():
runner, adapter = make_restart_runner()
disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock
running_agent = MagicMock()
runner._running_agents = {"session": running_agent}
async def finish_agent():
await asyncio.sleep(0.05)
runner._running_agents.clear()
asyncio.create_task(finish_agent())
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
running_agent.interrupt.assert_not_called()
disconnect_mock.assert_awaited_once()
assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_after_drain_timeout():
runner, adapter = make_restart_runner()
runner._restart_drain_timeout = 0.05
disconnect_mock = AsyncMock()
adapter.disconnect = disconnect_mock
running_agent = MagicMock()
runner._running_agents = {"session": running_agent}
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
disconnect_mock.assert_awaited_once()
assert runner._shutdown_event.is_set() is True
@pytest.mark.asyncio
async def test_gateway_stop_service_restart_sets_named_exit_code():
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop(restart=True, service_restart=True)
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
@pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates():
runner, _adapter = make_restart_runner()
runner._update_runtime_status = MagicMock()
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
async def finish_agents():
await asyncio.sleep(0.12)
runner._running_agents.pop("a")
await asyncio.sleep(0.12)
runner._running_agents.clear()
task = asyncio.create_task(finish_agents())
await runner._drain_active_agents(1.0)
await task
# Start, one count-change update, and final update. Allow one extra update
# if the loop observes the zero-agent state before exiting.
assert 3 <= runner._update_runtime_status.call_count <= 4

View File

@@ -0,0 +1,160 @@
import asyncio
import shutil
import subprocess
from unittest.mock import AsyncMock, MagicMock
import pytest
import gateway.run as gateway_run
from gateway.platforms.base import MessageEvent, MessageType
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@pytest.mark.asyncio
async def test_restart_command_while_busy_requests_drain_without_interrupt():
runner, _adapter = make_restart_runner()
runner.request_restart = MagicMock(return_value=True)
event = MessageEvent(
text="/restart",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m1",
)
session_key = build_session_key(event.source)
running_agent = MagicMock()
runner._running_agents[session_key] = running_agent
result = await runner._handle_message(event)
assert result == "⏳ Draining 1 active agent(s) before restart..."
running_agent.interrupt.assert_not_called()
runner.request_restart.assert_called_once_with(detached=True, via_service=False)
@pytest.mark.asyncio
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
runner, adapter = make_restart_runner()
runner._draining = True
runner._restart_requested = True
runner._busy_input_mode = "queue"
event = MessageEvent(
text="follow up",
message_type=MessageType.TEXT,
source=make_restart_source(),
message_id="m2",
)
session_key = build_session_key(event.source)
adapter._active_sessions[session_key] = asyncio.Event()
await adapter.handle_message(event)
assert session_key in adapter._pending_messages
assert adapter._pending_messages[session_key].text == "follow up"
assert not adapter._active_sessions[session_key].is_set()
assert any("queued for the next turn" in message for message in adapter.sent)
@pytest.mark.asyncio
async def test_draining_rejects_new_session_messages():
runner, _adapter = make_restart_runner()
runner._draining = True
runner._restart_requested = True
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=make_restart_source("fresh"),
message_id="m3",
)
result = await runner._handle_message(event)
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
(tmp_path / "config.yaml").write_text(
"display:\n busy_input_mode: queue\n", encoding="utf-8"
)
assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"
monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
tmp_path, monkeypatch, caplog
):
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
assert (
gateway_run.GatewayRunner._load_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
(tmp_path / "config.yaml").write_text(
"agent:\n restart_drain_timeout: 12\n", encoding="utf-8"
)
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7")
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
assert (
gateway_run.GatewayRunner._load_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
assert "Invalid restart_drain_timeout" in caplog.text
@pytest.mark.asyncio
async def test_request_restart_is_idempotent():
runner, _adapter = make_restart_runner()
runner.stop = AsyncMock()
assert runner.request_restart(detached=True, via_service=False) is True
first_task = next(iter(runner._background_tasks))
assert runner.request_restart(detached=True, via_service=False) is False
await first_task
runner.stop.assert_awaited_once_with(
restart=True, detached_restart=True, service_restart=False
)
@pytest.mark.asyncio
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
runner, _adapter = make_restart_runner()
popen_calls = []
monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None)
def fake_popen(cmd, **kwargs):
popen_calls.append((cmd, kwargs))
return MagicMock()
monkeypatch.setattr(subprocess, "Popen", fake_popen)
await runner._launch_detached_restart_command()
assert len(popen_calls) == 1
cmd, kwargs = popen_calls[0]
assert cmd[:2] == ["/usr/bin/setsid", "bash"]
assert "gateway restart" in cmd[-1]
assert "kill -0 321" in cmd[-1]
assert kwargs["start_new_session"] is True
assert kwargs["stdout"] is subprocess.DEVNULL
assert kwargs["stderr"] is subprocess.DEVNULL

View File

@@ -127,6 +127,16 @@ async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
runner._shutdown_event = MagicMock() runner._shutdown_event = MagicMock()
runner.adapters = {} runner.adapters = {}
runner._exit_reason = "test" runner._exit_reason = "test"
runner._exit_code = None
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 0.0
runner._stop_task = None
runner._running_agents_ts = {}
runner._update_runtime_status = MagicMock()
agent1 = MagicMock() agent1 = MagicMock()
agent1.session_id = "sess-a" agent1.session_id = "sess-a"

View File

@@ -41,6 +41,15 @@ def _make_runner():
runner._pending_approvals = {} runner._pending_approvals = {}
runner._voice_mode = {} runner._voice_mode = {}
runner._background_tasks = set() runner._background_tasks = set()
runner._draining = False
runner._restart_requested = False
runner._restart_task_started = False
runner._restart_detached = False
runner._restart_via_service = False
runner._restart_drain_timeout = 0.0
runner._stop_task = None
runner._exit_code = None
runner._update_runtime_status = MagicMock()
runner._is_user_authorized = lambda _source: True runner._is_user_authorized = lambda _source: True
runner.hooks = MagicMock() runner.hooks = MagicMock()
runner.hooks.emit = AsyncMock() runner.hooks.emit = AsyncMock()

View File

@@ -5,6 +5,10 @@ from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import hermes_cli.gateway as gateway_cli import hermes_cli.gateway as gateway_cli
from gateway.restart import (
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
GATEWAY_SERVICE_RESTART_EXIT_CODE,
)
class TestSystemdServiceRefresh: class TestSystemdServiceRefresh:
@@ -74,7 +78,7 @@ class TestSystemdServiceRefresh:
assert unit_path.read_text(encoding="utf-8") == "new unit\n" assert unit_path.read_text(encoding="utf-8") == "new unit\n"
assert calls[:2] == [ assert calls[:2] == [
["systemctl", "--user", "daemon-reload"], ["systemctl", "--user", "daemon-reload"],
["systemctl", "--user", "restart", gateway_cli.get_service_name()], ["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()],
] ]
@@ -84,6 +88,8 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit assert "ExecStart=" in unit
assert "ExecStop=" not in unit assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit assert "TimeoutStopSec=60" in unit
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch): def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
@@ -98,6 +104,8 @@ class TestGeneratedSystemdUnits:
assert "ExecStart=" in unit assert "ExecStart=" in unit
assert "ExecStop=" not in unit assert "ExecStop=" not in unit
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
assert "TimeoutStopSec=60" in unit assert "TimeoutStopSec=60" in unit
assert "WantedBy=multi-user.target" in unit assert "WantedBy=multi-user.target" in unit
@@ -157,6 +165,31 @@ class TestGatewayStopCleanup:
class TestLaunchdServiceRecovery: class TestLaunchdServiceRecovery:
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})
assert (
gateway_cli._get_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
monkeypatch.setattr(
gateway_cli,
"read_raw_config",
lambda: {"agent": {"restart_drain_timeout": 14}},
)
assert gateway_cli._get_restart_drain_timeout() == 14.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9")
assert gateway_cli._get_restart_drain_timeout() == 9.0
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
assert (
gateway_cli._get_restart_drain_timeout()
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
)
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch): def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
plist_path = tmp_path / "ai.hermes.gateway.plist" plist_path = tmp_path / "ai.hermes.gateway.plist"
plist_path.write_text("<plist>old content</plist>", encoding="utf-8") plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
@@ -234,6 +267,55 @@ class TestLaunchdServiceRecovery:
["launchctl", "kickstart", target], ["launchctl", "kickstart", target],
] ]
def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch):
calls = []
target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}"
monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0)
monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", lambda pid: False)
monkeypatch.setattr(gateway_cli, "_wait_for_gateway_exit", lambda timeout, force_after=None: True)
monkeypatch.setattr(gateway_cli, "terminate_pid", lambda pid, force=False: calls.append(("term", pid, force)))
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 321,
)
def fake_run(cmd, check=False, **kwargs):
calls.append(cmd)
return SimpleNamespace(returncode=0, stdout="", stderr="")
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
gateway_cli.launchd_restart()
assert calls == [
("term", 321, False),
["launchctl", "kickstart", "-k", target],
]
def test_launchd_restart_self_requests_graceful_restart_without_kickstart(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 321,
)
monkeypatch.setattr(
gateway_cli,
"_request_gateway_self_restart",
lambda pid: calls.append(("self", pid)) or True,
)
monkeypatch.setattr(
gateway_cli.subprocess,
"run",
lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("launchctl should not run")),
)
gateway_cli.launchd_restart()
assert calls == [("self", 321)]
assert "restart requested" in capsys.readouterr().out.lower()
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch): def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it.""" """launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
label = gateway_cli.get_launchd_label() label = gateway_cli.get_launchd_label()
@@ -337,6 +419,31 @@ class TestGatewayServiceDetection:
class TestGatewaySystemServiceRouting: class TestGatewaySystemServiceRouting:
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
calls = []
monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", lambda system=False: calls.append(("refresh", system)))
monkeypatch.setattr(
"gateway.status.get_running_pid",
lambda: 654,
)
monkeypatch.setattr(
gateway_cli,
"_request_gateway_self_restart",
lambda pid: calls.append(("self", pid)) or True,
)
monkeypatch.setattr(
gateway_cli.subprocess,
"run",
lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("systemctl should not run")),
)
gateway_cli.systemd_restart()
assert calls == [("refresh", False), ("self", 654)]
assert "restart requested" in capsys.readouterr().out.lower()
def test_gateway_install_passes_system_flags(self, monkeypatch): def test_gateway_install_passes_system_flags(self, monkeypatch):
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True) monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False) monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)