mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 07:21:37 +08:00
Compare commits
5 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4c3a57a3a | ||
|
|
eb3f021e2a | ||
|
|
b4cb803954 | ||
|
|
b5928530d3 | ||
|
|
ef120b5422 |
@@ -480,6 +480,12 @@ agent:
|
||||
# Fires once per run when inactivity reaches this threshold (seconds).
|
||||
# Set to 0 to disable the warning.
|
||||
# gateway_timeout_warning: 900
|
||||
|
||||
# Graceful drain timeout for gateway stop/restart (seconds).
|
||||
# The gateway stops accepting new work, waits for in-flight agents to
|
||||
# finish, then interrupts anything still running after this timeout.
|
||||
# 0 = no drain, interrupt immediately.
|
||||
# restart_drain_timeout: 60
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
||||
@@ -673,6 +673,32 @@ class SendResult:
|
||||
retryable: bool = False # True for transient connection errors — base will retry automatically
|
||||
|
||||
|
||||
def merge_pending_message_event(
|
||||
pending_messages: Dict[str, MessageEvent],
|
||||
session_key: str,
|
||||
event: MessageEvent,
|
||||
) -> None:
|
||||
"""Store or merge a pending event for a session.
|
||||
|
||||
Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
|
||||
events. Merge those into the existing queued event so the next turn sees
|
||||
the whole burst, while non-photo follow-ups still replace the pending
|
||||
event normally.
|
||||
"""
|
||||
existing = pending_messages.get(session_key)
|
||||
if (
|
||||
existing
|
||||
and getattr(existing, "message_type", None) == MessageType.PHOTO
|
||||
and event.message_type == MessageType.PHOTO
|
||||
):
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
return
|
||||
pending_messages[session_key] = event
|
||||
|
||||
|
||||
# Error substrings that indicate a transient *connection* failure worth retrying.
|
||||
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
||||
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
||||
@@ -727,6 +753,7 @@ class BasePlatformAdapter(ABC):
|
||||
# working on a task after --replace or manual restarts.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
||||
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
# Chats where typing indicator is paused (e.g. during approval waits).
|
||||
@@ -815,6 +842,10 @@ class BasePlatformAdapter(ABC):
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
|
||||
"""Set an optional handler for messages arriving during active sessions."""
|
||||
self._busy_session_handler = handler
|
||||
|
||||
def set_session_store(self, session_store: Any) -> None:
|
||||
"""
|
||||
@@ -1396,7 +1427,7 @@ class BasePlatformAdapter(ABC):
|
||||
# session lifecycle and its cleanup races with the running task
|
||||
# (see PR #4926).
|
||||
cmd = event.get_command()
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"):
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart"):
|
||||
logger.debug(
|
||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||
self.name, cmd, session_key,
|
||||
@@ -1415,19 +1446,19 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
|
||||
return
|
||||
|
||||
if self._busy_session_handler is not None:
|
||||
try:
|
||||
if await self._busy_session_handler(event, session_key):
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True)
|
||||
|
||||
# Special case: photo bursts/albums frequently arrive as multiple near-
|
||||
# simultaneous messages. Queue them without interrupting the active run,
|
||||
# then process them immediately after the current task finishes.
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
||||
existing = self._pending_messages.get(session_key)
|
||||
if existing and existing.message_type == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = self._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
self._pending_messages[session_key] = event
|
||||
merge_pending_message_event(self._pending_messages, session_key, event)
|
||||
return # Don't interrupt now - will run after current task completes
|
||||
|
||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||
|
||||
20
gateway/restart.py
Normal file
20
gateway/restart.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Shared gateway restart constants and parsing helpers."""
|
||||
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
|
||||
# the gateway after a graceful drain/reload path completes.
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
|
||||
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
|
||||
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
|
||||
)
|
||||
|
||||
|
||||
def parse_restart_drain_timeout(raw: object) -> float:
|
||||
"""Parse a configured drain timeout, falling back to the shared default."""
|
||||
try:
|
||||
value = float(raw) if str(raw or "").strip() else DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
return max(0.0, value)
|
||||
499
gateway/run.py
499
gateway/run.py
@@ -186,6 +186,12 @@ if _config_path.exists():
|
||||
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
|
||||
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
|
||||
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
|
||||
if "restart_drain_timeout" in _agent_cfg and "HERMES_RESTART_DRAIN_TIMEOUT" not in os.environ:
|
||||
os.environ["HERMES_RESTART_DRAIN_TIMEOUT"] = str(_agent_cfg["restart_drain_timeout"])
|
||||
_display_cfg = _cfg.get("display", {})
|
||||
if _display_cfg and isinstance(_display_cfg, dict):
|
||||
if "busy_input_mode" in _display_cfg and "HERMES_GATEWAY_BUSY_INPUT_MODE" not in os.environ:
|
||||
os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"])
|
||||
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
||||
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
||||
_tz_cfg = _cfg.get("timezone", "")
|
||||
@@ -235,7 +241,17 @@ from gateway.session import (
|
||||
build_session_key,
|
||||
)
|
||||
from gateway.delivery import DeliveryRouter
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
merge_pending_message_event,
|
||||
)
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
parse_restart_drain_timeout,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_whatsapp_identifier(value: str) -> str:
|
||||
@@ -471,6 +487,16 @@ class GatewayRunner:
|
||||
# Class-level defaults so partial construction in tests doesn't
|
||||
# blow up on attribute access.
|
||||
_running_agents_ts: Dict[str, float] = {}
|
||||
_busy_input_mode: str = "interrupt"
|
||||
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
_exit_code: Optional[int] = None
|
||||
_draining: bool = False
|
||||
_restart_requested: bool = False
|
||||
_restart_task_started: bool = False
|
||||
_restart_detached: bool = False
|
||||
_restart_via_service: bool = False
|
||||
_stop_task: Optional[asyncio.Task] = None
|
||||
_session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||
self.config = config or load_gateway_config()
|
||||
@@ -483,6 +509,8 @@ class GatewayRunner:
|
||||
self._reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
self._show_reasoning = self._load_show_reasoning()
|
||||
self._busy_input_mode = self._load_busy_input_mode()
|
||||
self._restart_drain_timeout = self._load_restart_drain_timeout()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
self._fallback_model = self._load_fallback_model()
|
||||
self._smart_model_routing = self._load_smart_model_routing()
|
||||
@@ -499,6 +527,13 @@ class GatewayRunner:
|
||||
self._exit_cleanly = False
|
||||
self._exit_with_failure = False
|
||||
self._exit_reason: Optional[str] = None
|
||||
self._exit_code: Optional[int] = None
|
||||
self._draining = False
|
||||
self._restart_requested = False
|
||||
self._restart_task_started = False
|
||||
self._restart_detached = False
|
||||
self._restart_via_service = False
|
||||
self._stop_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Track running agents per session for interrupt support
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
@@ -759,6 +794,10 @@ class GatewayRunner:
|
||||
def exit_reason(self) -> Optional[str]:
|
||||
return self._exit_reason
|
||||
|
||||
@property
|
||||
def exit_code(self) -> Optional[int]:
|
||||
return self._exit_code
|
||||
|
||||
def _session_key_for_source(self, source: SessionSource) -> str:
|
||||
"""Resolve the current session key for a source, honoring gateway config when available."""
|
||||
if hasattr(self, "session_store") and self.session_store is not None:
|
||||
@@ -868,6 +907,30 @@ class GatewayRunner:
|
||||
self._exit_cleanly = True
|
||||
self._exit_reason = reason
|
||||
self._shutdown_event.set()
|
||||
|
||||
def _running_agent_count(self) -> int:
|
||||
return len(self._running_agents)
|
||||
|
||||
def _status_action_label(self) -> str:
|
||||
return "restart" if self._restart_requested else "shutdown"
|
||||
|
||||
def _status_action_gerund(self) -> str:
|
||||
return "restarting" if self._restart_requested else "shutting down"
|
||||
|
||||
def _queue_during_drain_enabled(self) -> bool:
|
||||
return self._restart_requested and self._busy_input_mode == "queue"
|
||||
|
||||
def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(
|
||||
gateway_state=gateway_state,
|
||||
exit_reason=exit_reason,
|
||||
restart_requested=self._restart_requested,
|
||||
active_agents=self._running_agent_count(),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||
@@ -994,6 +1057,48 @@ class GatewayRunner:
|
||||
pass
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _load_busy_input_mode() -> str:
|
||||
"""Load gateway drain-time busy-input behavior from config/env."""
|
||||
mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
|
||||
if not mode:
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower()
|
||||
except Exception:
|
||||
pass
|
||||
return "queue" if mode == "queue" else "interrupt"
|
||||
|
||||
@staticmethod
|
||||
def _load_restart_drain_timeout() -> float:
|
||||
"""Load graceful gateway restart/stop drain timeout in seconds."""
|
||||
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
|
||||
if not raw:
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
value = parse_restart_drain_timeout(raw)
|
||||
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
|
||||
try:
|
||||
float(raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid restart_drain_timeout '%s', using default %.0fs",
|
||||
raw,
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _load_background_notifications_mode() -> str:
|
||||
"""Load background process notification mode from config or env var.
|
||||
@@ -1078,6 +1183,155 @@ class GatewayRunner:
|
||||
pass
|
||||
return {}
|
||||
|
||||
def _snapshot_running_agents(self) -> Dict[str, Any]:
|
||||
return {
|
||||
session_key: agent
|
||||
for session_key, agent in self._running_agents.items()
|
||||
if agent is not _AGENT_PENDING_SENTINEL
|
||||
}
|
||||
|
||||
def _queue_or_replace_pending_event(self, session_key: str, event: MessageEvent) -> None:
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not adapter:
|
||||
return
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, event)
|
||||
|
||||
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
|
||||
if not self._draining:
|
||||
return False
|
||||
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
if not adapter:
|
||||
return True
|
||||
|
||||
thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
if self._queue_during_drain_enabled():
|
||||
self._queue_or_replace_pending_event(session_key, event)
|
||||
message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
|
||||
else:
|
||||
message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||
|
||||
await adapter._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=message,
|
||||
reply_to=event.message_id,
|
||||
metadata=thread_meta,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
||||
snapshot = self._snapshot_running_agents()
|
||||
last_active_count = self._running_agent_count()
|
||||
last_status_at = 0.0
|
||||
|
||||
def _maybe_update_status(force: bool = False) -> None:
|
||||
nonlocal last_active_count, last_status_at
|
||||
now = asyncio.get_running_loop().time()
|
||||
active_count = self._running_agent_count()
|
||||
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
|
||||
self._update_runtime_status("draining")
|
||||
last_active_count = active_count
|
||||
last_status_at = now
|
||||
|
||||
if not self._running_agents:
|
||||
_maybe_update_status(force=True)
|
||||
return snapshot, False
|
||||
|
||||
_maybe_update_status(force=True)
|
||||
if timeout <= 0:
|
||||
return snapshot, True
|
||||
|
||||
deadline = asyncio.get_running_loop().time() + timeout
|
||||
while self._running_agents and asyncio.get_running_loop().time() < deadline:
|
||||
_maybe_update_status()
|
||||
await asyncio.sleep(0.1)
|
||||
timed_out = bool(self._running_agents)
|
||||
_maybe_update_status(force=True)
|
||||
return snapshot, timed_out
|
||||
|
||||
def _interrupt_running_agents(self, reason: str) -> None:
|
||||
for session_key, agent in list(self._running_agents.items()):
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
continue
|
||||
try:
|
||||
agent.interrupt(reason)
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
except Exception as e:
|
||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||
|
||||
def _finalize_shutdown_agents(self, active_agents: Dict[str, Any]) -> None:
|
||||
for agent in active_agents.values():
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook(
|
||||
"on_session_finalize",
|
||||
session_id=getattr(agent, "session_id", None),
|
||||
platform="gateway",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if hasattr(agent, "shutdown_memory_provider"):
|
||||
agent.shutdown_memory_provider()
|
||||
except Exception:
|
||||
pass
|
||||
# Close tool resources (terminal sandboxes, browser daemons,
|
||||
# background processes, httpx clients) to prevent zombie
|
||||
# process accumulation.
|
||||
try:
|
||||
if hasattr(agent, 'close'):
|
||||
agent.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _launch_detached_restart_command(self) -> None:
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
hermes_cmd = _resolve_hermes_bin()
|
||||
if not hermes_cmd:
|
||||
logger.error("Could not locate hermes binary for detached /restart")
|
||||
return
|
||||
|
||||
current_pid = os.getpid()
|
||||
cmd = " ".join(shlex.quote(part) for part in hermes_cmd)
|
||||
shell_cmd = (
|
||||
f"while kill -0 {current_pid} 2>/dev/null; do sleep 0.2; done; "
|
||||
f"{cmd} gateway restart"
|
||||
)
|
||||
setsid_bin = shutil.which("setsid")
|
||||
if setsid_bin:
|
||||
subprocess.Popen(
|
||||
[setsid_bin, "bash", "-lc", shell_cmd],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
start_new_session=True,
|
||||
)
|
||||
else:
|
||||
subprocess.Popen(
|
||||
["bash", "-lc", shell_cmd],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
def request_restart(self, *, detached: bool = False, via_service: bool = False) -> bool:
|
||||
if self._restart_task_started:
|
||||
return False
|
||||
self._restart_requested = True
|
||||
self._restart_detached = detached
|
||||
self._restart_via_service = via_service
|
||||
self._restart_task_started = True
|
||||
|
||||
async def _run_restart() -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
await self.stop(restart=True, detached_restart=detached, service_restart=via_service)
|
||||
|
||||
task = asyncio.create_task(_run_restart())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return True
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the gateway and all configured platform adapters.
|
||||
@@ -1165,6 +1419,7 @@ class GatewayRunner:
|
||||
adapter.set_message_handler(self._handle_message)
|
||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
adapter.set_session_store(self.session_store)
|
||||
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||
|
||||
# Try to connect
|
||||
logger.info("Connecting to %s...", platform.value)
|
||||
@@ -1240,11 +1495,7 @@ class GatewayRunner:
|
||||
self.delivery_router.adapters = self.adapters
|
||||
|
||||
self._running = True
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="running", exit_reason=None)
|
||||
except Exception:
|
||||
pass
|
||||
self._update_runtime_status("running")
|
||||
|
||||
# Emit gateway:startup hook
|
||||
hook_count = len(self.hooks.loaded_hooks)
|
||||
@@ -1479,6 +1730,7 @@ class GatewayRunner:
|
||||
adapter.set_message_handler(self._handle_message)
|
||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
adapter.set_session_store(self.session_store)
|
||||
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||
|
||||
success = await adapter.connect()
|
||||
if success:
|
||||
@@ -1525,90 +1777,108 @@ class GatewayRunner:
|
||||
return
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
async def stop(
|
||||
self,
|
||||
*,
|
||||
restart: bool = False,
|
||||
detached_restart: bool = False,
|
||||
service_restart: bool = False,
|
||||
) -> None:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
logger.info("Stopping gateway...")
|
||||
self._running = False
|
||||
if restart:
|
||||
self._restart_requested = True
|
||||
self._restart_detached = detached_restart
|
||||
self._restart_via_service = service_restart
|
||||
if self._stop_task is not None:
|
||||
await self._stop_task
|
||||
return
|
||||
|
||||
for session_key, agent in list(self._running_agents.items()):
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
continue
|
||||
async def _stop_impl() -> None:
|
||||
logger.info(
|
||||
"Stopping gateway%s...",
|
||||
" for restart" if self._restart_requested else "",
|
||||
)
|
||||
self._running = False
|
||||
self._draining = True
|
||||
|
||||
timeout = self._restart_drain_timeout
|
||||
active_agents, timed_out = await self._drain_active_agents(timeout)
|
||||
if timed_out:
|
||||
logger.warning(
|
||||
"Gateway drain timed out after %.1fs with %d active agent(s); interrupting remaining work.",
|
||||
timeout,
|
||||
self._running_agent_count(),
|
||||
)
|
||||
self._interrupt_running_agents(
|
||||
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
|
||||
)
|
||||
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
|
||||
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
|
||||
self._update_runtime_status("draining")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if self._restart_requested and self._restart_detached:
|
||||
try:
|
||||
await self._launch_detached_restart_command()
|
||||
except Exception as e:
|
||||
logger.error("Failed to launch detached gateway restart: %s", e)
|
||||
|
||||
self._finalize_shutdown_agents(active_agents)
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.cancel_background_tasks()
|
||||
except Exception as e:
|
||||
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
except Exception as e:
|
||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||
|
||||
for _task in list(self._background_tasks):
|
||||
if _task is self._stop_task:
|
||||
continue
|
||||
_task.cancel()
|
||||
self._background_tasks.clear()
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||
# to a specific agent (catch-all for zombie prevention).
|
||||
try:
|
||||
agent.interrupt("Gateway shutting down")
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
except Exception as e:
|
||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||
# Fire plugin on_session_finalize hook before memory shutdown
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook("on_session_finalize",
|
||||
session_id=getattr(agent, 'session_id', None),
|
||||
platform="gateway")
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all()
|
||||
except Exception:
|
||||
pass
|
||||
# Shut down memory provider at actual session boundary
|
||||
try:
|
||||
if hasattr(agent, 'shutdown_memory_provider'):
|
||||
agent.shutdown_memory_provider()
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
# Close tool resources (terminal sandboxes, browser daemons,
|
||||
# background processes, httpx clients) to prevent zombie
|
||||
# process accumulation.
|
||||
try:
|
||||
if hasattr(agent, 'close'):
|
||||
agent.close()
|
||||
from tools.browser_tool import cleanup_all_browsers
|
||||
cleanup_all_browsers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.cancel_background_tasks()
|
||||
except Exception as e:
|
||||
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
except Exception as e:
|
||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||
from gateway.status import remove_pid_file
|
||||
remove_pid_file()
|
||||
|
||||
# Cancel any pending background tasks
|
||||
for _task in list(self._background_tasks):
|
||||
_task.cancel()
|
||||
self._background_tasks.clear()
|
||||
if self._restart_requested and self._restart_via_service:
|
||||
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
||||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
self._shutdown_event.set()
|
||||
self._draining = False
|
||||
self._update_runtime_status("stopped", self._exit_reason)
|
||||
logger.info("Gateway stopped")
|
||||
|
||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||
# to a specific agent (catch-all for zombie prevention).
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
process_registry.kill_all()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from tools.browser_tool import cleanup_all_browsers
|
||||
cleanup_all_browsers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from gateway.status import remove_pid_file, write_runtime_status
|
||||
remove_pid_file()
|
||||
try:
|
||||
write_runtime_status(gateway_state="stopped", exit_reason=self._exit_reason)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("Gateway stopped")
|
||||
self._stop_task = asyncio.create_task(_stop_impl())
|
||||
await self._stop_task
|
||||
|
||||
async def wait_for_shutdown(self) -> None:
|
||||
"""Wait for shutdown signal."""
|
||||
@@ -2014,6 +2284,9 @@ class GatewayRunner:
|
||||
_evt_cmd = event.get_command()
|
||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "restart":
|
||||
return await self._handle_restart_command(event)
|
||||
|
||||
# /stop must hard-kill the session when an agent is running.
|
||||
# A soft interrupt (agent.interrupt()) doesn't help when the agent
|
||||
# is truly hung — the executor thread is blocked and never checks
|
||||
@@ -2094,18 +2367,7 @@ class GatewayRunner:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
||||
if _quick_key in adapter._pending_messages:
|
||||
existing = adapter._pending_messages[_quick_key]
|
||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
||||
existing.media_urls.extend(event.media_urls)
|
||||
existing.media_types.extend(event.media_types)
|
||||
if event.text:
|
||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
else:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
||||
return None
|
||||
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
@@ -2123,6 +2385,14 @@ class GatewayRunner:
|
||||
if adapter:
|
||||
adapter._pending_messages[_quick_key] = event
|
||||
return None
|
||||
if self._draining:
|
||||
if self._queue_during_drain_enabled():
|
||||
self._queue_or_replace_pending_event(_quick_key, event)
|
||||
return (
|
||||
f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
|
||||
if self._queue_during_drain_enabled()
|
||||
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||
)
|
||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
||||
running_agent.interrupt(event.text)
|
||||
if _quick_key in self._pending_messages:
|
||||
@@ -2164,6 +2434,9 @@ class GatewayRunner:
|
||||
|
||||
if canonical == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if canonical == "restart":
|
||||
return await self._handle_restart_command(event)
|
||||
|
||||
if canonical == "stop":
|
||||
return await self._handle_stop_command(event)
|
||||
@@ -2262,6 +2535,9 @@ class GatewayRunner:
|
||||
if canonical == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
if self._draining:
|
||||
return f"⏳ Gateway is {self._status_action_gerund()} and is not accepting new work right now."
|
||||
|
||||
# User-defined quick commands (bypass agent loop, no LLM call)
|
||||
if command:
|
||||
if isinstance(self.config, dict):
|
||||
@@ -3556,7 +3832,21 @@ class GatewayRunner:
|
||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
|
||||
async def _handle_restart_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /restart command - drain active work, then restart the gateway."""
|
||||
if self._restart_requested or self._draining:
|
||||
count = self._running_agent_count()
|
||||
if count:
|
||||
return f"⏳ Draining {count} active agent(s) before restart..."
|
||||
return "⏳ Gateway restart already in progress..."
|
||||
|
||||
active_agents = self._running_agent_count()
|
||||
self.request_restart(detached=True, via_service=False)
|
||||
if active_agents:
|
||||
return f"⏳ Draining {active_agents} active agent(s) before restart..."
|
||||
return "♻ Restarting gateway..."
|
||||
|
||||
async def _handle_help_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /help command - list available commands."""
|
||||
from hermes_cli.commands import gateway_help_lines
|
||||
@@ -3679,7 +3969,7 @@ class GatewayRunner:
|
||||
# Check for session override
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
override = getattr(self, "_session_model_overrides", {}).get(session_key, {})
|
||||
override = self._session_model_overrides.get(session_key, {})
|
||||
if override:
|
||||
current_model = override.get("model", current_model)
|
||||
current_provider = override.get("provider", current_provider)
|
||||
@@ -3761,8 +4051,6 @@ class GatewayRunner:
|
||||
f"via {result.provider_label or result.target_provider}. "
|
||||
f"Adjust your self-identification accordingly.]"
|
||||
)
|
||||
if not hasattr(_self, "_session_model_overrides"):
|
||||
_self._session_model_overrides = {}
|
||||
_self._session_model_overrides[_session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
@@ -3876,8 +4164,6 @@ class GatewayRunner:
|
||||
)
|
||||
|
||||
# Store session override so next agent creation uses the new model
|
||||
if not hasattr(self, "_session_model_overrides"):
|
||||
self._session_model_overrides = {}
|
||||
self._session_model_overrides[session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
@@ -7363,6 +7649,8 @@ class GatewayRunner:
|
||||
await asyncio.sleep(0.05)
|
||||
if session_key:
|
||||
self._running_agents[session_key] = agent_holder[0]
|
||||
if self._draining:
|
||||
self._update_runtime_status("draining")
|
||||
|
||||
tracking_task = asyncio.create_task(track_agent())
|
||||
|
||||
@@ -7608,6 +7896,14 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self._draining and pending:
|
||||
logger.info(
|
||||
"Discarding pending follow-up for session %s during gateway %s",
|
||||
session_key[:20] if session_key else "?",
|
||||
self._status_action_label(),
|
||||
)
|
||||
pending = None
|
||||
|
||||
if pending:
|
||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||
|
||||
@@ -7684,6 +7980,8 @@ class GatewayRunner:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
if self._draining:
|
||||
self._update_runtime_status("draining")
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]:
|
||||
@@ -7881,13 +8179,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
# Set up signal handlers
|
||||
def signal_handler():
|
||||
def shutdown_signal_handler():
|
||||
asyncio.create_task(runner.stop())
|
||||
|
||||
def restart_signal_handler():
|
||||
runner.request_restart(detached=False, via_service=True)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
loop.add_signal_handler(sig, shutdown_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
if hasattr(signal, "SIGUSR1"):
|
||||
try:
|
||||
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
@@ -7937,6 +8243,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if runner.exit_code is not None:
|
||||
raise SystemExit(runner.exit_code)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -158,6 +158,8 @@ def _build_runtime_status_record() -> dict[str, Any]:
|
||||
payload.update({
|
||||
"gateway_state": "starting",
|
||||
"exit_reason": None,
|
||||
"restart_requested": False,
|
||||
"active_agents": 0,
|
||||
"platforms": {},
|
||||
"updated_at": _utc_now_iso(),
|
||||
})
|
||||
@@ -218,6 +220,8 @@ def write_runtime_status(
|
||||
*,
|
||||
gateway_state: Optional[str] = None,
|
||||
exit_reason: Optional[str] = None,
|
||||
restart_requested: Optional[bool] = None,
|
||||
active_agents: Optional[int] = None,
|
||||
platform: Optional[str] = None,
|
||||
platform_state: Optional[str] = None,
|
||||
error_code: Optional[str] = None,
|
||||
@@ -236,6 +240,10 @@ def write_runtime_status(
|
||||
payload["gateway_state"] = gateway_state
|
||||
if exit_reason is not None:
|
||||
payload["exit_reason"] = exit_reason
|
||||
if restart_requested is not None:
|
||||
payload["restart_requested"] = bool(restart_requested)
|
||||
if active_agents is not None:
|
||||
payload["active_agents"] = max(0, int(active_agents))
|
||||
|
||||
if platform is not None:
|
||||
platform_payload = payload["platforms"].get(platform, {})
|
||||
|
||||
@@ -140,6 +140,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
|
||||
gateway_only=True, args_hint="[page]"),
|
||||
CommandDef("help", "Show available commands", "Info"),
|
||||
CommandDef("restart", "Gracefully restart the gateway after draining active runs", "Session",
|
||||
gateway_only=True),
|
||||
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
|
||||
CommandDef("insights", "Show usage insights and analytics", "Info",
|
||||
args_hint="[days]"),
|
||||
|
||||
@@ -269,6 +269,11 @@ DEFAULT_CONFIG = {
|
||||
# tools or receiving API responses. Only fires when the agent has
|
||||
# been completely idle for this duration. 0 = unlimited.
|
||||
"gateway_timeout": 1800,
|
||||
# Graceful drain timeout for gateway stop/restart (seconds).
|
||||
# The gateway stops accepting new work, waits for running agents
|
||||
# to finish, then interrupts any remaining runs after the timeout.
|
||||
# 0 = no drain, interrupt immediately.
|
||||
"restart_drain_timeout": 60,
|
||||
"service_tier": "",
|
||||
# Tool-use enforcement: injects system prompt guidance that tells the
|
||||
# model to actually call tools instead of describing intended actions.
|
||||
|
||||
@@ -15,7 +15,19 @@ from pathlib import Path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from gateway.status import terminate_pid
|
||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
parse_restart_drain_timeout,
|
||||
)
|
||||
from hermes_cli.config import (
|
||||
get_env_value,
|
||||
get_hermes_home,
|
||||
is_managed,
|
||||
managed_error,
|
||||
read_raw_config,
|
||||
save_env_value,
|
||||
)
|
||||
# display_hermes_home is imported lazily at call sites to avoid ImportError
|
||||
# when hermes_constants is cached from a pre-update version during `hermes update`.
|
||||
from hermes_cli.setup import (
|
||||
@@ -92,6 +104,59 @@ def _get_service_pids() -> set:
|
||||
return pids
|
||||
|
||||
|
||||
def _get_parent_pid(pid: int) -> int | None:
|
||||
"""Return the parent PID for ``pid``, or ``None`` when unavailable."""
|
||||
if pid <= 1:
|
||||
return None
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ps", "-o", "ppid=", "-p", str(pid)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
raw = result.stdout.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parent_pid = int(raw.splitlines()[-1].strip())
|
||||
except ValueError:
|
||||
return None
|
||||
return parent_pid if parent_pid > 0 else None
|
||||
|
||||
|
||||
def _is_pid_ancestor_of_current_process(target_pid: int) -> bool:
|
||||
"""Return True when ``target_pid`` is this process or one of its ancestors."""
|
||||
if target_pid <= 0:
|
||||
return False
|
||||
|
||||
pid = os.getpid()
|
||||
seen: set[int] = set()
|
||||
while pid and pid not in seen:
|
||||
if pid == target_pid:
|
||||
return True
|
||||
seen.add(pid)
|
||||
pid = _get_parent_pid(pid) or 0
|
||||
return False
|
||||
|
||||
|
||||
def _request_gateway_self_restart(pid: int) -> bool:
|
||||
"""Ask a running gateway ancestor to restart itself asynchronously."""
|
||||
if not hasattr(signal, "SIGUSR1"):
|
||||
return False
|
||||
if not _is_pid_ancestor_of_current_process(pid):
|
||||
return False
|
||||
try:
|
||||
os.kill(pid, signal.SIGUSR1)
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
@@ -665,6 +730,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
path_entries.append(resolved_node_dir)
|
||||
|
||||
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
|
||||
restart_timeout = max(60, int(_get_restart_drain_timeout() or 0))
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
@@ -703,9 +769,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=30
|
||||
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
ExecReload=/bin/kill -USR1 $MAINPID
|
||||
TimeoutStopSec={restart_timeout}
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -733,9 +801,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=30
|
||||
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
ExecReload=/bin/kill -USR1 $MAINPID
|
||||
TimeoutStopSec={restart_timeout}
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -838,6 +908,20 @@ def _select_systemd_scope(system: bool = False) -> bool:
|
||||
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
|
||||
|
||||
|
||||
def _get_restart_drain_timeout() -> float:
|
||||
"""Return the configured gateway restart drain timeout in seconds."""
|
||||
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
|
||||
if not raw:
|
||||
cfg = read_raw_config()
|
||||
agent_cfg = cfg.get("agent", {}) if isinstance(cfg, dict) else {}
|
||||
raw = str(
|
||||
agent_cfg.get(
|
||||
"restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
)
|
||||
return parse_restart_drain_timeout(raw)
|
||||
|
||||
|
||||
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
||||
if system:
|
||||
_require_root_for_system_service("install")
|
||||
@@ -923,7 +1007,13 @@ def systemd_restart(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("restart")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True, timeout=90)
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
pid = get_running_pid()
|
||||
if pid is not None and _request_gateway_self_restart(pid):
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restart requested")
|
||||
return
|
||||
subprocess.run(_systemctl_cmd(system) + ["reload-or-restart", get_service_name()], check=True, timeout=90)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
|
||||
|
||||
@@ -1211,7 +1301,7 @@ def launchd_stop():
|
||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||
print("✓ Service stopped")
|
||||
|
||||
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float | None = 5.0) -> bool:
|
||||
"""Wait for the gateway process (by saved PID) to exit.
|
||||
|
||||
Uses the PID from the gateway.pid file — not launchd labels — so this
|
||||
@@ -1226,21 +1316,21 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
force_deadline = time.monotonic() + force_after
|
||||
force_deadline = (time.monotonic() + force_after) if force_after is not None else None
|
||||
force_sent = False
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
pid = get_running_pid()
|
||||
if pid is None:
|
||||
return # Process exited cleanly.
|
||||
return True # Process exited cleanly.
|
||||
|
||||
if not force_sent and time.monotonic() >= force_deadline:
|
||||
if force_after is not None and not force_sent and time.monotonic() >= force_deadline:
|
||||
# Grace period expired — force-kill the specific PID.
|
||||
try:
|
||||
terminate_pid(pid, force=True)
|
||||
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
return # Already gone or we can't touch it.
|
||||
return True # Already gone or we can't touch it.
|
||||
force_sent = True
|
||||
|
||||
time.sleep(0.3)
|
||||
@@ -1249,15 +1339,30 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
||||
remaining_pid = get_running_pid()
|
||||
if remaining_pid is not None:
|
||||
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def launchd_restart():
|
||||
label = get_launchd_label()
|
||||
target = f"{_launchd_domain()}/{label}"
|
||||
# Use kickstart -k so launchd performs an atomic kill+restart.
|
||||
# A two-step stop/start from inside the gateway's own process tree
|
||||
# would kill the shell before the start command is reached.
|
||||
drain_timeout = _get_restart_drain_timeout()
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
try:
|
||||
pid = get_running_pid()
|
||||
if pid is not None and _request_gateway_self_restart(pid):
|
||||
print("✓ Service restart requested")
|
||||
return
|
||||
if pid is not None:
|
||||
try:
|
||||
terminate_pid(pid, force=False)
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pid = None
|
||||
if pid is not None:
|
||||
exited = _wait_for_gateway_exit(timeout=drain_timeout, force_after=None)
|
||||
if not exited:
|
||||
print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart")
|
||||
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
|
||||
print("✓ Service restarted")
|
||||
except subprocess.CalledProcessError as e:
|
||||
@@ -1728,6 +1833,8 @@ def _runtime_health_lines() -> list[str]:
|
||||
lines: list[str] = []
|
||||
gateway_state = state.get("gateway_state")
|
||||
exit_reason = state.get("exit_reason")
|
||||
active_agents = state.get("active_agents")
|
||||
restart_requested = state.get("restart_requested")
|
||||
platforms = state.get("platforms", {}) or {}
|
||||
|
||||
for platform, pdata in platforms.items():
|
||||
@@ -1737,6 +1844,10 @@ def _runtime_health_lines() -> list[str]:
|
||||
|
||||
if gateway_state == "startup_failed" and exit_reason:
|
||||
lines.append(f"⚠ Last startup issue: {exit_reason}")
|
||||
elif gateway_state == "draining":
|
||||
action = "restart" if restart_requested else "shutdown"
|
||||
count = int(active_agents or 0)
|
||||
lines.append(f"⏳ Gateway draining for {action} ({count} active agent(s))")
|
||||
elif gateway_state == "stopped" and exit_reason:
|
||||
lines.append(f"⚠ Last shutdown reason: {exit_reason}")
|
||||
|
||||
|
||||
110
tests/gateway/restart_test_helpers.py
Normal file
110
tests/gateway/restart_test_helpers.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class RestartTestAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
self.sent.append(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
|
||||
|
||||
def make_restart_runner(
|
||||
adapter: BasePlatformAdapter | None = None,
|
||||
) -> tuple[GatewayRunner, BasePlatformAdapter]:
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_code = None
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._background_tasks = set()
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
runner._stop_task = None
|
||||
runner._busy_input_mode = "interrupt"
|
||||
runner._update_prompt_pending = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
runner._queue_or_replace_pending_event = GatewayRunner._queue_or_replace_pending_event.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._session_key_for_source = GatewayRunner._session_key_for_source.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._handle_active_session_busy_message = (
|
||||
GatewayRunner._handle_active_session_busy_message.__get__(runner, GatewayRunner)
|
||||
)
|
||||
runner._handle_restart_command = GatewayRunner._handle_restart_command.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._status_action_label = GatewayRunner._status_action_label.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._status_action_gerund = GatewayRunner._status_action_gerund.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._queue_during_drain_enabled = GatewayRunner._queue_during_drain_enabled.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._running_agent_count = GatewayRunner._running_agent_count.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner._launch_detached_restart_command = GatewayRunner._launch_detached_restart_command.__get__(
|
||||
runner, GatewayRunner
|
||||
)
|
||||
runner.request_restart = GatewayRunner.request_restart.__get__(runner, GatewayRunner)
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.delivery_router = MagicMock()
|
||||
|
||||
platform_adapter = adapter or RestartTestAdapter()
|
||||
platform_adapter.set_message_handler(AsyncMock(return_value=None))
|
||||
platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
|
||||
runner.adapters = {Platform.TELEGRAM: platform_adapter}
|
||||
return runner, platform_adapter
|
||||
@@ -3,43 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self):
|
||||
return True
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm"):
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
)
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
from gateway.session import build_session_key
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
adapter = StubAdapter()
|
||||
_runner, adapter = make_restart_runner()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
@@ -47,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
@@ -65,17 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._pending_messages = {"session": "pending text"}
|
||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||
runner._background_tasks = set()
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
runner._restart_drain_timeout = 0.0
|
||||
|
||||
adapter = StubAdapter()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def block_forever(_event):
|
||||
@@ -83,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||
return None
|
||||
|
||||
adapter.set_message_handler(block_forever)
|
||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
||||
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
@@ -93,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {session_key: running_agent}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
@@ -105,3 +70,78 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
||||
assert runner._pending_messages == {}
|
||||
assert runner._pending_approvals == {}
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_drains_running_agents_before_disconnect():
|
||||
runner, adapter = make_restart_runner()
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {"session": running_agent}
|
||||
|
||||
async def finish_agent():
|
||||
await asyncio.sleep(0.05)
|
||||
runner._running_agents.clear()
|
||||
|
||||
asyncio.create_task(finish_agent())
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
running_agent.interrupt.assert_not_called()
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_interrupts_after_drain_timeout():
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._restart_drain_timeout = 0.05
|
||||
|
||||
disconnect_mock = AsyncMock()
|
||||
adapter.disconnect = disconnect_mock
|
||||
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents = {"session": running_agent}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
running_agent.interrupt.assert_called_once_with("Gateway shutting down")
|
||||
disconnect_mock.assert_awaited_once()
|
||||
assert runner._shutdown_event.is_set() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_stop_service_restart_sets_named_exit_code():
|
||||
runner, adapter = make_restart_runner()
|
||||
adapter.disconnect = AsyncMock()
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop(restart=True, service_restart=True)
|
||||
|
||||
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_active_agents_throttles_status_updates():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner._update_runtime_status = MagicMock()
|
||||
|
||||
runner._running_agents = {"a": MagicMock(), "b": MagicMock()}
|
||||
|
||||
async def finish_agents():
|
||||
await asyncio.sleep(0.12)
|
||||
runner._running_agents.pop("a")
|
||||
await asyncio.sleep(0.12)
|
||||
runner._running_agents.clear()
|
||||
|
||||
task = asyncio.create_task(finish_agents())
|
||||
await runner._drain_active_agents(1.0)
|
||||
await task
|
||||
|
||||
# Start, one count-change update, and final update. Allow one extra update
|
||||
# if the loop observes the zero-agent state before exiting.
|
||||
assert 3 <= runner._update_runtime_status.call_count <= 4
|
||||
|
||||
160
tests/gateway/test_restart_drain.py
Normal file
160
tests/gateway/test_restart_drain.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
from gateway.session import build_session_key
|
||||
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_command_while_busy_requests_drain_without_interrupt():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.request_restart = MagicMock(return_value=True)
|
||||
event = MessageEvent(
|
||||
text="/restart",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[session_key] = running_agent
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "⏳ Draining 1 active agent(s) before restart..."
|
||||
running_agent.interrupt.assert_not_called()
|
||||
runner.request_restart.assert_called_once_with(detached=True, via_service=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
|
||||
runner, adapter = make_restart_runner()
|
||||
runner._draining = True
|
||||
runner._restart_requested = True
|
||||
runner._busy_input_mode = "queue"
|
||||
|
||||
event = MessageEvent(
|
||||
text="follow up",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source(),
|
||||
message_id="m2",
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(event)
|
||||
|
||||
assert session_key in adapter._pending_messages
|
||||
assert adapter._pending_messages[session_key].text == "follow up"
|
||||
assert not adapter._active_sessions[session_key].is_set()
|
||||
assert any("queued for the next turn" in message for message in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draining_rejects_new_session_messages():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner._draining = True
|
||||
runner._restart_requested = True
|
||||
|
||||
event = MessageEvent(
|
||||
text="hello",
|
||||
message_type=MessageType.TEXT,
|
||||
source=make_restart_source("fresh"),
|
||||
message_id="m3",
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "⏳ Gateway is restarting and is not accepting new work right now."
|
||||
|
||||
|
||||
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
|
||||
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n busy_input_mode: queue\n", encoding="utf-8"
|
||||
)
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"
|
||||
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
|
||||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||
|
||||
|
||||
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
|
||||
tmp_path, monkeypatch, caplog
|
||||
):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||
|
||||
assert (
|
||||
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"agent:\n restart_drain_timeout: 12\n", encoding="utf-8"
|
||||
)
|
||||
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 12.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "7")
|
||||
assert gateway_run.GatewayRunner._load_restart_drain_timeout() == 7.0
|
||||
|
||||
monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
|
||||
assert (
|
||||
gateway_run.GatewayRunner._load_restart_drain_timeout()
|
||||
== DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
)
|
||||
assert "Invalid restart_drain_timeout" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_restart_is_idempotent():
|
||||
runner, _adapter = make_restart_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
assert runner.request_restart(detached=True, via_service=False) is True
|
||||
first_task = next(iter(runner._background_tasks))
|
||||
assert runner.request_restart(detached=True, via_service=False) is False
|
||||
|
||||
await first_task
|
||||
|
||||
runner.stop.assert_awaited_once_with(
|
||||
restart=True, detached_restart=True, service_restart=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
|
||||
runner, _adapter = make_restart_runner()
|
||||
popen_calls = []
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
|
||||
monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
|
||||
monkeypatch.setattr(shutil, "which", lambda cmd: "/usr/bin/setsid" if cmd == "setsid" else None)
|
||||
|
||||
def fake_popen(cmd, **kwargs):
|
||||
popen_calls.append((cmd, kwargs))
|
||||
return MagicMock()
|
||||
|
||||
monkeypatch.setattr(subprocess, "Popen", fake_popen)
|
||||
|
||||
await runner._launch_detached_restart_command()
|
||||
|
||||
assert len(popen_calls) == 1
|
||||
cmd, kwargs = popen_calls[0]
|
||||
assert cmd[:2] == ["/usr/bin/setsid", "bash"]
|
||||
assert "gateway restart" in cmd[-1]
|
||||
assert "kill -0 321" in cmd[-1]
|
||||
assert kwargs["start_new_session"] is True
|
||||
assert kwargs["stdout"] is subprocess.DEVNULL
|
||||
assert kwargs["stderr"] is subprocess.DEVNULL
|
||||
@@ -127,6 +127,16 @@ async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
|
||||
runner._shutdown_event = MagicMock()
|
||||
runner.adapters = {}
|
||||
runner._exit_reason = "test"
|
||||
runner._exit_code = None
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = 0.0
|
||||
runner._stop_task = None
|
||||
runner._running_agents_ts = {}
|
||||
runner._update_runtime_status = MagicMock()
|
||||
|
||||
agent1 = MagicMock()
|
||||
agent1.session_id = "sess-a"
|
||||
|
||||
@@ -41,6 +41,15 @@ def _make_runner():
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
runner._background_tasks = set()
|
||||
runner._draining = False
|
||||
runner._restart_requested = False
|
||||
runner._restart_task_started = False
|
||||
runner._restart_detached = False
|
||||
runner._restart_via_service = False
|
||||
runner._restart_drain_timeout = 0.0
|
||||
runner._stop_task = None
|
||||
runner._exit_code = None
|
||||
runner._update_runtime_status = MagicMock()
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
|
||||
@@ -5,6 +5,10 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
from gateway.restart import (
|
||||
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemdServiceRefresh:
|
||||
@@ -74,7 +78,7 @@ class TestSystemdServiceRefresh:
|
||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||
assert calls[:2] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "restart", gateway_cli.get_service_name()],
|
||||
["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()],
|
||||
]
|
||||
|
||||
|
||||
@@ -84,6 +88,8 @@ class TestGeneratedSystemdUnits:
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
|
||||
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
||||
@@ -98,6 +104,8 @@ class TestGeneratedSystemdUnits:
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
assert "WantedBy=multi-user.target" in unit
|
||||
|
||||
@@ -157,6 +165,31 @@ class TestGatewayStopCleanup:
|
||||
|
||||
|
||||
class TestLaunchdServiceRecovery:
|
||||
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
    """Precedence order: env var beats config, config beats the built-in default."""
    monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
    monkeypatch.setattr(gateway_cli, "read_raw_config", lambda: {})

    # No env var and no config entry -> the shipped default.
    assert gateway_cli._get_restart_drain_timeout() == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    # A config value alone wins over the default and is coerced to float.
    monkeypatch.setattr(
        gateway_cli,
        "read_raw_config",
        lambda: {"agent": {"restart_drain_timeout": 14}},
    )
    assert gateway_cli._get_restart_drain_timeout() == 14.0

    # A parseable env value overrides the config value.
    monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "9")
    assert gateway_cli._get_restart_drain_timeout() == 9.0

    # An unparseable env value falls back to the default, not the config value.
    monkeypatch.setenv("HERMES_RESTART_DRAIN_TIMEOUT", "invalid")
    assert gateway_cli._get_restart_drain_timeout() == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||
|
||||
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
||||
@@ -234,6 +267,55 @@ class TestLaunchdServiceRecovery:
|
||||
["launchctl", "kickstart", target],
|
||||
]
|
||||
|
||||
def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch):
    """When the gateway refuses a self-restart, drain the old pid then kickstart."""
    observed = []
    target = f"{gateway_cli._launchd_domain()}/{gateway_cli.get_launchd_label()}"

    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 321)
    monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", lambda: 12.0)
    # Force the external drain path by refusing the in-process restart.
    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", lambda pid: False)
    monkeypatch.setattr(
        gateway_cli, "_wait_for_gateway_exit", lambda timeout, force_after=None: True
    )
    monkeypatch.setattr(
        gateway_cli,
        "terminate_pid",
        lambda pid, force=False: observed.append(("term", pid, force)),
    )

    def record_run(cmd, check=False, **kwargs):
        observed.append(cmd)
        return SimpleNamespace(returncode=0, stdout="", stderr="")

    monkeypatch.setattr(gateway_cli.subprocess, "run", record_run)

    gateway_cli.launchd_restart()

    # Graceful terminate of the old process first, then a forced kickstart.
    assert observed == [
        ("term", 321, False),
        ["launchctl", "kickstart", "-k", target],
    ]
|
||||
|
||||
def test_launchd_restart_self_requests_graceful_restart_without_kickstart(self, monkeypatch, capsys):
    """A gateway that accepts the self-restart must short-circuit launchctl."""
    observed = []

    def accept_self_restart(pid):
        observed.append(("self", pid))
        return True

    def forbid_subprocess(*args, **kwargs):
        raise AssertionError("launchctl should not run")

    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 321)
    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", accept_self_restart)
    monkeypatch.setattr(gateway_cli.subprocess, "run", forbid_subprocess)

    gateway_cli.launchd_restart()

    assert observed == [("self", 321)]
    assert "restart requested" in capsys.readouterr().out.lower()
|
||||
|
||||
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
|
||||
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
|
||||
label = gateway_cli.get_launchd_label()
|
||||
@@ -337,6 +419,31 @@ class TestGatewayServiceDetection:
|
||||
|
||||
|
||||
class TestGatewaySystemServiceRouting:
|
||||
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
    """A gateway that accepts the self-restart must skip systemctl entirely."""
    observed = []

    monkeypatch.setattr(gateway_cli, "_select_systemd_scope", lambda system=False: False)
    monkeypatch.setattr(
        gateway_cli,
        "refresh_systemd_unit_if_needed",
        lambda system=False: observed.append(("refresh", system)),
    )
    monkeypatch.setattr("gateway.status.get_running_pid", lambda: 654)

    def accept_self_restart(pid):
        observed.append(("self", pid))
        return True

    def forbid_subprocess(*args, **kwargs):
        raise AssertionError("systemctl should not run")

    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", accept_self_restart)
    monkeypatch.setattr(gateway_cli.subprocess, "run", forbid_subprocess)

    gateway_cli.systemd_restart()

    # The unit file is still refreshed, but the restart is delegated in-process.
    assert observed == [("refresh", False), ("self", 654)]
    assert "restart requested" in capsys.readouterr().out.lower()
|
||||
|
||||
def test_gateway_install_passes_system_flags(self, monkeypatch):
|
||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||
|
||||
Reference in New Issue
Block a user