mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 07:51:45 +08:00
Compare commits
5 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4c3a57a3a | ||
|
|
eb3f021e2a | ||
|
|
b4cb803954 | ||
|
|
b5928530d3 | ||
|
|
ef120b5422 |
@@ -480,6 +480,12 @@ agent:
|
|||||||
# Fires once per run when inactivity reaches this threshold (seconds).
|
# Fires once per run when inactivity reaches this threshold (seconds).
|
||||||
# Set to 0 to disable the warning.
|
# Set to 0 to disable the warning.
|
||||||
# gateway_timeout_warning: 900
|
# gateway_timeout_warning: 900
|
||||||
|
|
||||||
|
# Graceful drain timeout for gateway stop/restart (seconds).
|
||||||
|
# The gateway stops accepting new work, waits for in-flight agents to
|
||||||
|
# finish, then interrupts anything still running after this timeout.
|
||||||
|
# 0 = no drain, interrupt immediately.
|
||||||
|
# restart_drain_timeout: 60
|
||||||
|
|
||||||
# Enable verbose logging
|
# Enable verbose logging
|
||||||
verbose: false
|
verbose: false
|
||||||
|
|||||||
@@ -673,6 +673,32 @@ class SendResult:
|
|||||||
retryable: bool = False # True for transient connection errors — base will retry automatically
|
retryable: bool = False # True for transient connection errors — base will retry automatically
|
||||||
|
|
||||||
|
|
||||||
|
def merge_pending_message_event(
    pending_messages: Dict[str, MessageEvent],
    session_key: str,
    event: MessageEvent,
) -> None:
    """Queue *event* for *session_key*, folding photo bursts together.

    When the session already has a queued PHOTO event and *event* is a PHOTO
    too, the new media (and caption, if any) is appended to the queued event
    in place so the next turn receives the whole album. In every other case
    *event* simply replaces whatever was queued for the session.
    """
    queued = pending_messages.get(session_key)
    is_photo_burst = (
        bool(queued)
        and getattr(queued, "message_type", None) == MessageType.PHOTO
        and event.message_type == MessageType.PHOTO
    )
    if not is_photo_burst:
        # Anything that is not "photo on top of photo" overwrites the slot.
        pending_messages[session_key] = event
        return

    queued.media_urls.extend(event.media_urls)
    queued.media_types.extend(event.media_types)
    if event.text:
        queued.text = BasePlatformAdapter._merge_caption(queued.text, event.text)
|
||||||
|
|
||||||
|
|
||||||
# Error substrings that indicate a transient *connection* failure worth retrying.
|
# Error substrings that indicate a transient *connection* failure worth retrying.
|
||||||
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
|
||||||
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
|
||||||
@@ -727,6 +753,7 @@ class BasePlatformAdapter(ABC):
|
|||||||
# working on a task after --replace or manual restarts.
|
# working on a task after --replace or manual restarts.
|
||||||
self._background_tasks: set[asyncio.Task] = set()
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
||||||
|
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
|
||||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||||
self._auto_tts_disabled_chats: set = set()
|
self._auto_tts_disabled_chats: set = set()
|
||||||
# Chats where typing indicator is paused (e.g. during approval waits).
|
# Chats where typing indicator is paused (e.g. during approval waits).
|
||||||
@@ -815,6 +842,10 @@ class BasePlatformAdapter(ABC):
|
|||||||
an optional response string.
|
an optional response string.
|
||||||
"""
|
"""
|
||||||
self._message_handler = handler
|
self._message_handler = handler
|
||||||
|
|
||||||
|
def set_busy_session_handler(
    self,
    handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]],
) -> None:
    """Register the callback consulted for messages hitting an active session.

    Args:
        handler: Awaitable callable taking ``(event, session_key)``, or
            ``None`` to clear a previously registered callback.
    """
    self._busy_session_handler = handler
|
||||||
|
|
||||||
def set_session_store(self, session_store: Any) -> None:
|
def set_session_store(self, session_store: Any) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -1396,7 +1427,7 @@ class BasePlatformAdapter(ABC):
|
|||||||
# session lifecycle and its cleanup races with the running task
|
# session lifecycle and its cleanup races with the running task
|
||||||
# (see PR #4926).
|
# (see PR #4926).
|
||||||
cmd = event.get_command()
|
cmd = event.get_command()
|
||||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background"):
|
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart"):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||||
self.name, cmd, session_key,
|
self.name, cmd, session_key,
|
||||||
@@ -1415,19 +1446,19 @@ class BasePlatformAdapter(ABC):
|
|||||||
logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
|
logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._busy_session_handler is not None:
|
||||||
|
try:
|
||||||
|
if await self._busy_session_handler(event, session_key):
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True)
|
||||||
|
|
||||||
# Special case: photo bursts/albums frequently arrive as multiple near-
|
# Special case: photo bursts/albums frequently arrive as multiple near-
|
||||||
# simultaneous messages. Queue them without interrupting the active run,
|
# simultaneous messages. Queue them without interrupting the active run,
|
||||||
# then process them immediately after the current task finishes.
|
# then process them immediately after the current task finishes.
|
||||||
if event.message_type == MessageType.PHOTO:
|
if event.message_type == MessageType.PHOTO:
|
||||||
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
|
||||||
existing = self._pending_messages.get(session_key)
|
merge_pending_message_event(self._pending_messages, session_key, event)
|
||||||
if existing and existing.message_type == MessageType.PHOTO:
|
|
||||||
existing.media_urls.extend(event.media_urls)
|
|
||||||
existing.media_types.extend(event.media_types)
|
|
||||||
if event.text:
|
|
||||||
existing.text = self._merge_caption(existing.text, event.text)
|
|
||||||
else:
|
|
||||||
self._pending_messages[session_key] = event
|
|
||||||
return # Don't interrupt now - will run after current task completes
|
return # Don't interrupt now - will run after current task completes
|
||||||
|
|
||||||
# Default behavior for non-photo follow-ups: interrupt the running agent
|
# Default behavior for non-photo follow-ups: interrupt the running agent
|
||||||
|
|||||||
20
gateway/restart.py
Normal file
20
gateway/restart.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Shared gateway restart constants and parsing helpers."""
|
||||||
|
|
||||||
|
from hermes_cli.config import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
# EX_TEMPFAIL from sysexits.h — used to ask the service manager to restart
|
||||||
|
# the gateway after a graceful drain/reload path completes.
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE = 75
|
||||||
|
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT = float(
|
||||||
|
DEFAULT_CONFIG["agent"]["restart_drain_timeout"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_restart_drain_timeout(raw: object) -> float:
    """Parse a configured drain timeout, falling back to the shared default.

    Args:
        raw: Value taken from an environment variable or config file. May be
            a number, a numeric string, ``None`` or a blank string.

    Returns:
        The timeout in seconds, clamped to be non-negative. A numeric zero
        (``0``, ``0.0``, ``"0"``) is returned as ``0.0`` rather than being
        mistaken for "not configured". ``None``, blank strings, booleans,
        unparseable values and non-finite numbers (``inf``/``nan``) yield
        ``DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT``.
    """
    import math  # local import so the module's top-level imports stay untouched

    # bool is a subclass of int; a YAML true/false is not a duration.
    if raw is None or isinstance(raw, bool):
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    if isinstance(raw, str) and not raw.strip():
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    # float() accepts "inf"/"nan"; an infinite deadline would never expire.
    if not math.isfinite(value):
        return DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    return max(0.0, value)
|
||||||
499
gateway/run.py
499
gateway/run.py
@@ -186,6 +186,12 @@ if _config_path.exists():
|
|||||||
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
|
os.environ["HERMES_AGENT_TIMEOUT"] = str(_agent_cfg["gateway_timeout"])
|
||||||
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
|
if "gateway_timeout_warning" in _agent_cfg and "HERMES_AGENT_TIMEOUT_WARNING" not in os.environ:
|
||||||
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
|
os.environ["HERMES_AGENT_TIMEOUT_WARNING"] = str(_agent_cfg["gateway_timeout_warning"])
|
||||||
|
if "restart_drain_timeout" in _agent_cfg and "HERMES_RESTART_DRAIN_TIMEOUT" not in os.environ:
|
||||||
|
os.environ["HERMES_RESTART_DRAIN_TIMEOUT"] = str(_agent_cfg["restart_drain_timeout"])
|
||||||
|
_display_cfg = _cfg.get("display", {})
|
||||||
|
if _display_cfg and isinstance(_display_cfg, dict):
|
||||||
|
if "busy_input_mode" in _display_cfg and "HERMES_GATEWAY_BUSY_INPUT_MODE" not in os.environ:
|
||||||
|
os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"])
|
||||||
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
||||||
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
||||||
_tz_cfg = _cfg.get("timezone", "")
|
_tz_cfg = _cfg.get("timezone", "")
|
||||||
@@ -235,7 +241,17 @@ from gateway.session import (
|
|||||||
build_session_key,
|
build_session_key,
|
||||||
)
|
)
|
||||||
from gateway.delivery import DeliveryRouter
|
from gateway.delivery import DeliveryRouter
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
from gateway.platforms.base import (
|
||||||
|
BasePlatformAdapter,
|
||||||
|
MessageEvent,
|
||||||
|
MessageType,
|
||||||
|
merge_pending_message_event,
|
||||||
|
)
|
||||||
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
parse_restart_drain_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_whatsapp_identifier(value: str) -> str:
|
def _normalize_whatsapp_identifier(value: str) -> str:
|
||||||
@@ -471,6 +487,16 @@ class GatewayRunner:
|
|||||||
# Class-level defaults so partial construction in tests doesn't
|
# Class-level defaults so partial construction in tests doesn't
|
||||||
# blow up on attribute access.
|
# blow up on attribute access.
|
||||||
_running_agents_ts: Dict[str, float] = {}
|
_running_agents_ts: Dict[str, float] = {}
|
||||||
|
_busy_input_mode: str = "interrupt"
|
||||||
|
_restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
_exit_code: Optional[int] = None
|
||||||
|
_draining: bool = False
|
||||||
|
_restart_requested: bool = False
|
||||||
|
_restart_task_started: bool = False
|
||||||
|
_restart_detached: bool = False
|
||||||
|
_restart_via_service: bool = False
|
||||||
|
_stop_task: Optional[asyncio.Task] = None
|
||||||
|
_session_model_overrides: Dict[str, Dict[str, str]] = {}
|
||||||
|
|
||||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||||
self.config = config or load_gateway_config()
|
self.config = config or load_gateway_config()
|
||||||
@@ -483,6 +509,8 @@ class GatewayRunner:
|
|||||||
self._reasoning_config = self._load_reasoning_config()
|
self._reasoning_config = self._load_reasoning_config()
|
||||||
self._service_tier = self._load_service_tier()
|
self._service_tier = self._load_service_tier()
|
||||||
self._show_reasoning = self._load_show_reasoning()
|
self._show_reasoning = self._load_show_reasoning()
|
||||||
|
self._busy_input_mode = self._load_busy_input_mode()
|
||||||
|
self._restart_drain_timeout = self._load_restart_drain_timeout()
|
||||||
self._provider_routing = self._load_provider_routing()
|
self._provider_routing = self._load_provider_routing()
|
||||||
self._fallback_model = self._load_fallback_model()
|
self._fallback_model = self._load_fallback_model()
|
||||||
self._smart_model_routing = self._load_smart_model_routing()
|
self._smart_model_routing = self._load_smart_model_routing()
|
||||||
@@ -499,6 +527,13 @@ class GatewayRunner:
|
|||||||
self._exit_cleanly = False
|
self._exit_cleanly = False
|
||||||
self._exit_with_failure = False
|
self._exit_with_failure = False
|
||||||
self._exit_reason: Optional[str] = None
|
self._exit_reason: Optional[str] = None
|
||||||
|
self._exit_code: Optional[int] = None
|
||||||
|
self._draining = False
|
||||||
|
self._restart_requested = False
|
||||||
|
self._restart_task_started = False
|
||||||
|
self._restart_detached = False
|
||||||
|
self._restart_via_service = False
|
||||||
|
self._stop_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# Track running agents per session for interrupt support
|
# Track running agents per session for interrupt support
|
||||||
# Key: session_key, Value: AIAgent instance
|
# Key: session_key, Value: AIAgent instance
|
||||||
@@ -759,6 +794,10 @@ class GatewayRunner:
|
|||||||
def exit_reason(self) -> Optional[str]:
|
def exit_reason(self) -> Optional[str]:
|
||||||
return self._exit_reason
|
return self._exit_reason
|
||||||
|
|
||||||
|
@property
def exit_code(self) -> Optional[int]:
    """Read-only view of ``self._exit_code`` (``None`` when unset)."""
    code = self._exit_code
    return code
|
||||||
|
|
||||||
def _session_key_for_source(self, source: SessionSource) -> str:
|
def _session_key_for_source(self, source: SessionSource) -> str:
|
||||||
"""Resolve the current session key for a source, honoring gateway config when available."""
|
"""Resolve the current session key for a source, honoring gateway config when available."""
|
||||||
if hasattr(self, "session_store") and self.session_store is not None:
|
if hasattr(self, "session_store") and self.session_store is not None:
|
||||||
@@ -868,6 +907,30 @@ class GatewayRunner:
|
|||||||
self._exit_cleanly = True
|
self._exit_cleanly = True
|
||||||
self._exit_reason = reason
|
self._exit_reason = reason
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
def _running_agent_count(self) -> int:
|
||||||
|
return len(self._running_agents)
|
||||||
|
|
||||||
|
def _status_action_label(self) -> str:
|
||||||
|
return "restart" if self._restart_requested else "shutdown"
|
||||||
|
|
||||||
|
def _status_action_gerund(self) -> str:
|
||||||
|
return "restarting" if self._restart_requested else "shutting down"
|
||||||
|
|
||||||
|
def _queue_during_drain_enabled(self) -> bool:
|
||||||
|
return self._restart_requested and self._busy_input_mode == "queue"
|
||||||
|
|
||||||
|
def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
    """Best-effort write of the gateway's runtime status.

    Forwards *gateway_state* and *exit_reason* to
    ``gateway.status.write_runtime_status`` together with the current
    restart flag and running-agent count.

    Args:
        gateway_state: State label to record (callers pass values such as
            ``"running"``, ``"draining"`` and ``"stopped"``).
        exit_reason: Optional human-readable reason to record.

    Any failure (import error or write error) is logged at debug level and
    otherwise ignored, so status reporting can never break start-up, drain
    or shutdown.
    """
    try:
        from gateway.status import write_runtime_status

        write_runtime_status(
            gateway_state=gateway_state,
            exit_reason=exit_reason,
            restart_requested=self._restart_requested,
            active_agents=self._running_agent_count(),
        )
    except Exception:
        # Deliberately non-fatal, but leave a trace instead of swallowing
        # the error silently so a stale status file can be diagnosed.
        logger.debug("Failed to write gateway runtime status", exc_info=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||||
@@ -994,6 +1057,48 @@ class GatewayRunner:
|
|||||||
pass
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_busy_input_mode() -> str:
|
||||||
|
"""Load gateway drain-time busy-input behavior from config/env."""
|
||||||
|
mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
|
||||||
|
if not mode:
|
||||||
|
try:
|
||||||
|
import yaml as _y
|
||||||
|
cfg_path = _hermes_home / "config.yaml"
|
||||||
|
if cfg_path.exists():
|
||||||
|
with open(cfg_path, encoding="utf-8") as _f:
|
||||||
|
cfg = _y.safe_load(_f) or {}
|
||||||
|
mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "queue" if mode == "queue" else "interrupt"
|
||||||
|
|
||||||
|
@staticmethod
def _load_restart_drain_timeout() -> float:
    """Load graceful gateway restart/stop drain timeout in seconds.

    ``HERMES_RESTART_DRAIN_TIMEOUT`` wins when it is non-blank; otherwise
    ``agent.restart_drain_timeout`` is read from ``config.yaml``. The raw
    text is handed to ``parse_restart_drain_timeout``; a warning is logged
    when a value was supplied but is not a number.

    Returns:
        The parsed timeout as returned by ``parse_restart_drain_timeout``.
    """
    raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
    if not raw:
        try:
            import yaml as _y

            cfg_path = _hermes_home / "config.yaml"
            if cfg_path.exists():
                with open(cfg_path, encoding="utf-8") as _f:
                    cfg = _y.safe_load(_f) or {}
                configured = (cfg.get("agent") or {}).get("restart_drain_timeout")
                # Test against None, not truthiness: a literal 0 is a valid
                # setting ("no drain") and must not be discarded as missing.
                if configured is not None:
                    raw = str(configured).strip()
        except Exception as e:
            # Best effort: an unreadable config falls back to the default.
            logger.debug("Could not read restart_drain_timeout from config.yaml: %s", e)

    value = parse_restart_drain_timeout(raw)
    if raw:
        # Tell the operator when their setting was ignored as non-numeric.
        try:
            float(raw)
        except (TypeError, ValueError):
            logger.warning(
                "Invalid restart_drain_timeout '%s', using default %.0fs",
                raw,
                DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
            )
    return value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_background_notifications_mode() -> str:
|
def _load_background_notifications_mode() -> str:
|
||||||
"""Load background process notification mode from config or env var.
|
"""Load background process notification mode from config or env var.
|
||||||
@@ -1078,6 +1183,155 @@ class GatewayRunner:
|
|||||||
pass
|
pass
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def _snapshot_running_agents(self) -> Dict[str, Any]:
    """Return a new dict of running agents, leaving out pending placeholders.

    Entries whose value is ``_AGENT_PENDING_SENTINEL`` are skipped; the
    returned mapping is independent of ``self._running_agents``.
    """
    current = self._running_agents
    return {
        key: running
        for key, running in current.items()
        if running is not _AGENT_PENDING_SENTINEL
    }
|
||||||
|
|
||||||
|
def _queue_or_replace_pending_event(self, session_key: str, event: MessageEvent) -> None:
    """Queue *event* on the adapter that owns its platform.

    Does nothing when no adapter is registered for ``event.source.platform``;
    otherwise delegates to ``merge_pending_message_event`` on that adapter's
    pending-message map.
    """
    owner = self.adapters.get(event.source.platform)
    if owner:
        merge_pending_message_event(owner._pending_messages, session_key, event)
|
||||||
|
|
||||||
|
async def _handle_active_session_busy_message(self, event: MessageEvent, session_key: str) -> bool:
    """Handle a message that arrives for a busy session while draining.

    Args:
        event: The incoming message.
        session_key: Key of the session that is currently active.

    Returns:
        ``False`` when the gateway is not draining, so the caller applies
        its normal busy-session behaviour. ``True`` when the message was
        consumed here: it is queued if queue-during-drain is enabled,
        otherwise dropped, and the sender is sent a short notice.
    """
    if not self._draining:
        return False

    adapter = self.adapters.get(event.source.platform)
    if not adapter:
        # Nowhere to reply; still consume the message during a drain.
        return True

    thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
    if self._queue_during_drain_enabled():
        self._queue_or_replace_pending_event(session_key, event)
        message = f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
    else:
        message = f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."

    try:
        await adapter._send_with_retry(
            chat_id=event.source.chat_id,
            content=message,
            reply_to=event.message_id,
            metadata=thread_meta,
        )
    except Exception as e:
        # The notice is a courtesy. If it escaped, the adapter would treat
        # the handler as failed and fall back to its default busy path,
        # interrupting the very agent that is being drained.
        logger.warning(
            "Failed to send drain notice for session %s: %s",
            session_key[:20],
            e,
        )
    return True
|
||||||
|
|
||||||
|
async def _drain_active_agents(self, timeout: float) -> tuple[Dict[str, Any], bool]:
|
||||||
|
snapshot = self._snapshot_running_agents()
|
||||||
|
last_active_count = self._running_agent_count()
|
||||||
|
last_status_at = 0.0
|
||||||
|
|
||||||
|
def _maybe_update_status(force: bool = False) -> None:
|
||||||
|
nonlocal last_active_count, last_status_at
|
||||||
|
now = asyncio.get_running_loop().time()
|
||||||
|
active_count = self._running_agent_count()
|
||||||
|
if force or active_count != last_active_count or (now - last_status_at) >= 1.0:
|
||||||
|
self._update_runtime_status("draining")
|
||||||
|
last_active_count = active_count
|
||||||
|
last_status_at = now
|
||||||
|
|
||||||
|
if not self._running_agents:
|
||||||
|
_maybe_update_status(force=True)
|
||||||
|
return snapshot, False
|
||||||
|
|
||||||
|
_maybe_update_status(force=True)
|
||||||
|
if timeout <= 0:
|
||||||
|
return snapshot, True
|
||||||
|
|
||||||
|
deadline = asyncio.get_running_loop().time() + timeout
|
||||||
|
while self._running_agents and asyncio.get_running_loop().time() < deadline:
|
||||||
|
_maybe_update_status()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
timed_out = bool(self._running_agents)
|
||||||
|
_maybe_update_status(force=True)
|
||||||
|
return snapshot, timed_out
|
||||||
|
|
||||||
|
def _interrupt_running_agents(self, reason: str) -> None:
    """Call ``interrupt(reason)`` on every running agent.

    Pending placeholders (``_AGENT_PENDING_SENTINEL``) are skipped. A
    failure for one agent is logged at debug level and does not stop the
    remaining agents from being interrupted.
    """
    # Work on a copy so the registry may change while we iterate.
    targets = [
        (key, running)
        for key, running in list(self._running_agents.items())
        if running is not _AGENT_PENDING_SENTINEL
    ]
    for session_key, agent in targets:
        try:
            agent.interrupt(reason)
            logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
        except Exception as e:
            logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||||
|
|
||||||
|
def _finalize_shutdown_agents(self, active_agents: Dict[str, Any]) -> None:
|
||||||
|
for agent in active_agents.values():
|
||||||
|
try:
|
||||||
|
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||||
|
_invoke_hook(
|
||||||
|
"on_session_finalize",
|
||||||
|
session_id=getattr(agent, "session_id", None),
|
||||||
|
platform="gateway",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if hasattr(agent, "shutdown_memory_provider"):
|
||||||
|
agent.shutdown_memory_provider()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Close tool resources (terminal sandboxes, browser daemons,
|
||||||
|
# background processes, httpx clients) to prevent zombie
|
||||||
|
# process accumulation.
|
||||||
|
try:
|
||||||
|
if hasattr(agent, 'close'):
|
||||||
|
agent.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _launch_detached_restart_command(self) -> None:
    """Spawn a detached shell that restarts the gateway once this process exits.

    The shell polls ``kill -0 <pid>`` every 0.2 s until the current process
    is gone, then runs ``<hermes> gateway restart``. It is started in a new
    session (through ``setsid`` too when that binary is on PATH) with stdout
    and stderr discarded. Logs an error and returns when the hermes binary
    cannot be resolved.
    """
    import shutil
    import subprocess

    hermes_cmd = _resolve_hermes_bin()
    if not hermes_cmd:
        logger.error("Could not locate hermes binary for detached /restart")
        return

    watched_pid = os.getpid()
    quoted_cmd = " ".join(map(shlex.quote, hermes_cmd))
    watcher_script = (
        f"while kill -0 {watched_pid} 2>/dev/null; do sleep 0.2; done; "
        f"{quoted_cmd} gateway restart"
    )

    argv = ["bash", "-lc", watcher_script]
    setsid_bin = shutil.which("setsid")
    if setsid_bin:
        argv.insert(0, setsid_bin)

    subprocess.Popen(
        argv,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True,
    )
|
||||||
|
|
||||||
|
def request_restart(self, *, detached: bool = False, via_service: bool = False) -> bool:
    """Schedule a gateway restart unless one has already been scheduled.

    Records the restart flags, then starts a background task that waits
    50 ms and awaits ``self.stop(restart=True, ...)``.

    Args:
        detached: Forwarded to ``stop`` as ``detached_restart``.
        via_service: Forwarded to ``stop`` as ``service_restart``.

    Returns:
        ``True`` when this call scheduled the restart, ``False`` when a
        restart task had already been started.
    """
    if self._restart_task_started:
        return False

    self._restart_requested = True
    self._restart_detached = detached
    self._restart_via_service = via_service
    self._restart_task_started = True

    async def _deferred_stop() -> None:
        await asyncio.sleep(0.05)
        await self.stop(
            restart=True,
            detached_restart=detached,
            service_restart=via_service,
        )

    restart_task = asyncio.create_task(_deferred_stop())
    # Hold a reference until completion so the task cannot be collected early.
    self._background_tasks.add(restart_task)
    restart_task.add_done_callback(self._background_tasks.discard)
    return True
|
||||||
|
|
||||||
async def start(self) -> bool:
|
async def start(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Start the gateway and all configured platform adapters.
|
Start the gateway and all configured platform adapters.
|
||||||
@@ -1165,6 +1419,7 @@ class GatewayRunner:
|
|||||||
adapter.set_message_handler(self._handle_message)
|
adapter.set_message_handler(self._handle_message)
|
||||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||||
adapter.set_session_store(self.session_store)
|
adapter.set_session_store(self.session_store)
|
||||||
|
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||||
|
|
||||||
# Try to connect
|
# Try to connect
|
||||||
logger.info("Connecting to %s...", platform.value)
|
logger.info("Connecting to %s...", platform.value)
|
||||||
@@ -1240,11 +1495,7 @@ class GatewayRunner:
|
|||||||
self.delivery_router.adapters = self.adapters
|
self.delivery_router.adapters = self.adapters
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
try:
|
self._update_runtime_status("running")
|
||||||
from gateway.status import write_runtime_status
|
|
||||||
write_runtime_status(gateway_state="running", exit_reason=None)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Emit gateway:startup hook
|
# Emit gateway:startup hook
|
||||||
hook_count = len(self.hooks.loaded_hooks)
|
hook_count = len(self.hooks.loaded_hooks)
|
||||||
@@ -1479,6 +1730,7 @@ class GatewayRunner:
|
|||||||
adapter.set_message_handler(self._handle_message)
|
adapter.set_message_handler(self._handle_message)
|
||||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||||
adapter.set_session_store(self.session_store)
|
adapter.set_session_store(self.session_store)
|
||||||
|
adapter.set_busy_session_handler(self._handle_active_session_busy_message)
|
||||||
|
|
||||||
success = await adapter.connect()
|
success = await adapter.connect()
|
||||||
if success:
|
if success:
|
||||||
@@ -1525,90 +1777,108 @@ class GatewayRunner:
|
|||||||
return
|
return
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
restart: bool = False,
|
||||||
|
detached_restart: bool = False,
|
||||||
|
service_restart: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Stop the gateway and disconnect all adapters."""
|
"""Stop the gateway and disconnect all adapters."""
|
||||||
logger.info("Stopping gateway...")
|
if restart:
|
||||||
self._running = False
|
self._restart_requested = True
|
||||||
|
self._restart_detached = detached_restart
|
||||||
|
self._restart_via_service = service_restart
|
||||||
|
if self._stop_task is not None:
|
||||||
|
await self._stop_task
|
||||||
|
return
|
||||||
|
|
||||||
for session_key, agent in list(self._running_agents.items()):
|
async def _stop_impl() -> None:
|
||||||
if agent is _AGENT_PENDING_SENTINEL:
|
logger.info(
|
||||||
continue
|
"Stopping gateway%s...",
|
||||||
|
" for restart" if self._restart_requested else "",
|
||||||
|
)
|
||||||
|
self._running = False
|
||||||
|
self._draining = True
|
||||||
|
|
||||||
|
timeout = self._restart_drain_timeout
|
||||||
|
active_agents, timed_out = await self._drain_active_agents(timeout)
|
||||||
|
if timed_out:
|
||||||
|
logger.warning(
|
||||||
|
"Gateway drain timed out after %.1fs with %d active agent(s); interrupting remaining work.",
|
||||||
|
timeout,
|
||||||
|
self._running_agent_count(),
|
||||||
|
)
|
||||||
|
self._interrupt_running_agents(
|
||||||
|
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
|
||||||
|
)
|
||||||
|
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
|
||||||
|
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
|
||||||
|
self._update_runtime_status("draining")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
if self._restart_requested and self._restart_detached:
|
||||||
|
try:
|
||||||
|
await self._launch_detached_restart_command()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to launch detached gateway restart: %s", e)
|
||||||
|
|
||||||
|
self._finalize_shutdown_agents(active_agents)
|
||||||
|
|
||||||
|
for platform, adapter in list(self.adapters.items()):
|
||||||
|
try:
|
||||||
|
await adapter.cancel_background_tasks()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
||||||
|
try:
|
||||||
|
await adapter.disconnect()
|
||||||
|
logger.info("✓ %s disconnected", platform.value)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
||||||
|
|
||||||
|
for _task in list(self._background_tasks):
|
||||||
|
if _task is self._stop_task:
|
||||||
|
continue
|
||||||
|
_task.cancel()
|
||||||
|
self._background_tasks.clear()
|
||||||
|
|
||||||
|
self.adapters.clear()
|
||||||
|
self._running_agents.clear()
|
||||||
|
self._pending_messages.clear()
|
||||||
|
self._pending_approvals.clear()
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
|
# Global cleanup: kill any remaining tool subprocesses not tied
|
||||||
|
# to a specific agent (catch-all for zombie prevention).
|
||||||
try:
|
try:
|
||||||
agent.interrupt("Gateway shutting down")
|
from tools.process_registry import process_registry
|
||||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
process_registry.kill_all()
|
||||||
except Exception as e:
|
|
||||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
|
||||||
# Fire plugin on_session_finalize hook before memory shutdown
|
|
||||||
try:
|
|
||||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
|
||||||
_invoke_hook("on_session_finalize",
|
|
||||||
session_id=getattr(agent, 'session_id', None),
|
|
||||||
platform="gateway")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Shut down memory provider at actual session boundary
|
|
||||||
try:
|
try:
|
||||||
if hasattr(agent, 'shutdown_memory_provider'):
|
from tools.terminal_tool import cleanup_all_environments
|
||||||
agent.shutdown_memory_provider()
|
cleanup_all_environments()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Close tool resources (terminal sandboxes, browser daemons,
|
|
||||||
# background processes, httpx clients) to prevent zombie
|
|
||||||
# process accumulation.
|
|
||||||
try:
|
try:
|
||||||
if hasattr(agent, 'close'):
|
from tools.browser_tool import cleanup_all_browsers
|
||||||
agent.close()
|
cleanup_all_browsers()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for platform, adapter in list(self.adapters.items()):
|
from gateway.status import remove_pid_file
|
||||||
try:
|
remove_pid_file()
|
||||||
await adapter.cancel_background_tasks()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("✗ %s background-task cancel error: %s", platform.value, e)
|
|
||||||
try:
|
|
||||||
await adapter.disconnect()
|
|
||||||
logger.info("✓ %s disconnected", platform.value)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("✗ %s disconnect error: %s", platform.value, e)
|
|
||||||
|
|
||||||
# Cancel any pending background tasks
|
if self._restart_requested and self._restart_via_service:
|
||||||
for _task in list(self._background_tasks):
|
self._exit_code = GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
_task.cancel()
|
self._exit_reason = self._exit_reason or "Gateway restart requested"
|
||||||
self._background_tasks.clear()
|
|
||||||
|
|
||||||
self.adapters.clear()
|
self._draining = False
|
||||||
self._running_agents.clear()
|
self._update_runtime_status("stopped", self._exit_reason)
|
||||||
self._pending_messages.clear()
|
logger.info("Gateway stopped")
|
||||||
self._pending_approvals.clear()
|
|
||||||
self._shutdown_event.set()
|
|
||||||
|
|
||||||
# Global cleanup: kill any remaining tool subprocesses not tied
|
self._stop_task = asyncio.create_task(_stop_impl())
|
||||||
# to a specific agent (catch-all for zombie prevention).
|
await self._stop_task
|
||||||
try:
|
|
||||||
from tools.process_registry import process_registry
|
|
||||||
process_registry.kill_all()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
from tools.terminal_tool import cleanup_all_environments
|
|
||||||
cleanup_all_environments()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
from tools.browser_tool import cleanup_all_browsers
|
|
||||||
cleanup_all_browsers()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from gateway.status import remove_pid_file, write_runtime_status
|
|
||||||
remove_pid_file()
|
|
||||||
try:
|
|
||||||
write_runtime_status(gateway_state="stopped", exit_reason=self._exit_reason)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info("Gateway stopped")
|
|
||||||
|
|
||||||
async def wait_for_shutdown(self) -> None:
|
async def wait_for_shutdown(self) -> None:
|
||||||
"""Wait for shutdown signal."""
|
"""Wait for shutdown signal."""
|
||||||
@@ -2014,6 +2284,9 @@ class GatewayRunner:
|
|||||||
_evt_cmd = event.get_command()
|
_evt_cmd = event.get_command()
|
||||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||||
|
|
||||||
|
if _cmd_def_inner and _cmd_def_inner.name == "restart":
|
||||||
|
return await self._handle_restart_command(event)
|
||||||
|
|
||||||
# /stop must hard-kill the session when an agent is running.
|
# /stop must hard-kill the session when an agent is running.
|
||||||
# A soft interrupt (agent.interrupt()) doesn't help when the agent
|
# A soft interrupt (agent.interrupt()) doesn't help when the agent
|
||||||
# is truly hung — the executor thread is blocked and never checks
|
# is truly hung — the executor thread is blocked and never checks
|
||||||
@@ -2094,18 +2367,7 @@ class GatewayRunner:
|
|||||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
if adapter:
|
if adapter:
|
||||||
# Reuse adapter queue semantics so photo bursts merge cleanly.
|
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
||||||
if _quick_key in adapter._pending_messages:
|
|
||||||
existing = adapter._pending_messages[_quick_key]
|
|
||||||
if getattr(existing, "message_type", None) == MessageType.PHOTO:
|
|
||||||
existing.media_urls.extend(event.media_urls)
|
|
||||||
existing.media_types.extend(event.media_types)
|
|
||||||
if event.text:
|
|
||||||
existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
|
|
||||||
else:
|
|
||||||
adapter._pending_messages[_quick_key] = event
|
|
||||||
else:
|
|
||||||
adapter._pending_messages[_quick_key] = event
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
running_agent = self._running_agents.get(_quick_key)
|
running_agent = self._running_agents.get(_quick_key)
|
||||||
@@ -2123,6 +2385,14 @@ class GatewayRunner:
|
|||||||
if adapter:
|
if adapter:
|
||||||
adapter._pending_messages[_quick_key] = event
|
adapter._pending_messages[_quick_key] = event
|
||||||
return None
|
return None
|
||||||
|
if self._draining:
|
||||||
|
if self._queue_during_drain_enabled():
|
||||||
|
self._queue_or_replace_pending_event(_quick_key, event)
|
||||||
|
return (
|
||||||
|
f"⏳ Gateway {self._status_action_gerund()} — queued for the next turn after it comes back."
|
||||||
|
if self._queue_during_drain_enabled()
|
||||||
|
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||||
|
)
|
||||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
||||||
running_agent.interrupt(event.text)
|
running_agent.interrupt(event.text)
|
||||||
if _quick_key in self._pending_messages:
|
if _quick_key in self._pending_messages:
|
||||||
@@ -2164,6 +2434,9 @@ class GatewayRunner:
|
|||||||
|
|
||||||
if canonical == "status":
|
if canonical == "status":
|
||||||
return await self._handle_status_command(event)
|
return await self._handle_status_command(event)
|
||||||
|
|
||||||
|
if canonical == "restart":
|
||||||
|
return await self._handle_restart_command(event)
|
||||||
|
|
||||||
if canonical == "stop":
|
if canonical == "stop":
|
||||||
return await self._handle_stop_command(event)
|
return await self._handle_stop_command(event)
|
||||||
@@ -2262,6 +2535,9 @@ class GatewayRunner:
|
|||||||
if canonical == "voice":
|
if canonical == "voice":
|
||||||
return await self._handle_voice_command(event)
|
return await self._handle_voice_command(event)
|
||||||
|
|
||||||
|
if self._draining:
|
||||||
|
return f"⏳ Gateway is {self._status_action_gerund()} and is not accepting new work right now."
|
||||||
|
|
||||||
# User-defined quick commands (bypass agent loop, no LLM call)
|
# User-defined quick commands (bypass agent loop, no LLM call)
|
||||||
if command:
|
if command:
|
||||||
if isinstance(self.config, dict):
|
if isinstance(self.config, dict):
|
||||||
@@ -3556,7 +3832,21 @@ class GatewayRunner:
|
|||||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||||
else:
|
else:
|
||||||
return "No active task to stop."
|
return "No active task to stop."
|
||||||
|
|
||||||
|
async def _handle_restart_command(self, event: MessageEvent) -> str:
    """Handle the /restart command and return the reply text.

    When a restart has already been requested, or the gateway is already
    draining, no new restart is requested; the reply only reports progress.
    Otherwise ``request_restart(detached=True, via_service=False)`` is
    called. In both cases the reply mentions the number of running agents
    when that number is non-zero.

    ``event`` is accepted for handler-signature parity and is not read here.
    """
    already_in_progress = self._restart_requested or self._draining
    # Sample the count before requesting the restart so the reply reflects
    # the work that was active when the command arrived.
    running = self._running_agent_count()

    if not already_in_progress:
        self.request_restart(detached=True, via_service=False)

    if running:
        return f"⏳ Draining {running} active agent(s) before restart..."
    if already_in_progress:
        return "⏳ Gateway restart already in progress..."
    return "♻ Restarting gateway..."
|
||||||
|
|
||||||
async def _handle_help_command(self, event: MessageEvent) -> str:
|
async def _handle_help_command(self, event: MessageEvent) -> str:
|
||||||
"""Handle /help command - list available commands."""
|
"""Handle /help command - list available commands."""
|
||||||
from hermes_cli.commands import gateway_help_lines
|
from hermes_cli.commands import gateway_help_lines
|
||||||
@@ -3679,7 +3969,7 @@ class GatewayRunner:
|
|||||||
# Check for session override
|
# Check for session override
|
||||||
source = event.source
|
source = event.source
|
||||||
session_key = self._session_key_for_source(source)
|
session_key = self._session_key_for_source(source)
|
||||||
override = getattr(self, "_session_model_overrides", {}).get(session_key, {})
|
override = self._session_model_overrides.get(session_key, {})
|
||||||
if override:
|
if override:
|
||||||
current_model = override.get("model", current_model)
|
current_model = override.get("model", current_model)
|
||||||
current_provider = override.get("provider", current_provider)
|
current_provider = override.get("provider", current_provider)
|
||||||
@@ -3761,8 +4051,6 @@ class GatewayRunner:
|
|||||||
f"via {result.provider_label or result.target_provider}. "
|
f"via {result.provider_label or result.target_provider}. "
|
||||||
f"Adjust your self-identification accordingly.]"
|
f"Adjust your self-identification accordingly.]"
|
||||||
)
|
)
|
||||||
if not hasattr(_self, "_session_model_overrides"):
|
|
||||||
_self._session_model_overrides = {}
|
|
||||||
_self._session_model_overrides[_session_key] = {
|
_self._session_model_overrides[_session_key] = {
|
||||||
"model": result.new_model,
|
"model": result.new_model,
|
||||||
"provider": result.target_provider,
|
"provider": result.target_provider,
|
||||||
@@ -3876,8 +4164,6 @@ class GatewayRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store session override so next agent creation uses the new model
|
# Store session override so next agent creation uses the new model
|
||||||
if not hasattr(self, "_session_model_overrides"):
|
|
||||||
self._session_model_overrides = {}
|
|
||||||
self._session_model_overrides[session_key] = {
|
self._session_model_overrides[session_key] = {
|
||||||
"model": result.new_model,
|
"model": result.new_model,
|
||||||
"provider": result.target_provider,
|
"provider": result.target_provider,
|
||||||
@@ -7363,6 +7649,8 @@ class GatewayRunner:
|
|||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
if session_key:
|
if session_key:
|
||||||
self._running_agents[session_key] = agent_holder[0]
|
self._running_agents[session_key] = agent_holder[0]
|
||||||
|
if self._draining:
|
||||||
|
self._update_runtime_status("draining")
|
||||||
|
|
||||||
tracking_task = asyncio.create_task(track_agent())
|
tracking_task = asyncio.create_task(track_agent())
|
||||||
|
|
||||||
@@ -7608,6 +7896,14 @@ class GatewayRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if self._draining and pending:
|
||||||
|
logger.info(
|
||||||
|
"Discarding pending follow-up for session %s during gateway %s",
|
||||||
|
session_key[:20] if session_key else "?",
|
||||||
|
self._status_action_label(),
|
||||||
|
)
|
||||||
|
pending = None
|
||||||
|
|
||||||
if pending:
|
if pending:
|
||||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||||
|
|
||||||
@@ -7684,6 +7980,8 @@ class GatewayRunner:
|
|||||||
del self._running_agents[session_key]
|
del self._running_agents[session_key]
|
||||||
if session_key:
|
if session_key:
|
||||||
self._running_agents_ts.pop(session_key, None)
|
self._running_agents_ts.pop(session_key, None)
|
||||||
|
if self._draining:
|
||||||
|
self._update_runtime_status("draining")
|
||||||
|
|
||||||
# Wait for cancelled tasks
|
# Wait for cancelled tasks
|
||||||
for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]:
|
for task in [progress_task, interrupt_monitor, tracking_task, _notify_task]:
|
||||||
@@ -7881,13 +8179,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||||||
runner = GatewayRunner(config)
|
runner = GatewayRunner(config)
|
||||||
|
|
||||||
# Set up signal handlers
|
# Set up signal handlers
|
||||||
def signal_handler():
|
def shutdown_signal_handler():
|
||||||
asyncio.create_task(runner.stop())
|
asyncio.create_task(runner.stop())
|
||||||
|
|
||||||
|
def restart_signal_handler():
|
||||||
|
runner.request_restart(detached=False, via_service=True)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||||
try:
|
try:
|
||||||
loop.add_signal_handler(sig, signal_handler)
|
loop.add_signal_handler(sig, shutdown_signal_handler)
|
||||||
|
except NotImplementedError:
|
||||||
|
pass
|
||||||
|
if hasattr(signal, "SIGUSR1"):
|
||||||
|
try:
|
||||||
|
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -7937,6 +8243,9 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if runner.exit_code is not None:
|
||||||
|
raise SystemExit(runner.exit_code)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,8 @@ def _build_runtime_status_record() -> dict[str, Any]:
|
|||||||
payload.update({
|
payload.update({
|
||||||
"gateway_state": "starting",
|
"gateway_state": "starting",
|
||||||
"exit_reason": None,
|
"exit_reason": None,
|
||||||
|
"restart_requested": False,
|
||||||
|
"active_agents": 0,
|
||||||
"platforms": {},
|
"platforms": {},
|
||||||
"updated_at": _utc_now_iso(),
|
"updated_at": _utc_now_iso(),
|
||||||
})
|
})
|
||||||
@@ -218,6 +220,8 @@ def write_runtime_status(
|
|||||||
*,
|
*,
|
||||||
gateway_state: Optional[str] = None,
|
gateway_state: Optional[str] = None,
|
||||||
exit_reason: Optional[str] = None,
|
exit_reason: Optional[str] = None,
|
||||||
|
restart_requested: Optional[bool] = None,
|
||||||
|
active_agents: Optional[int] = None,
|
||||||
platform: Optional[str] = None,
|
platform: Optional[str] = None,
|
||||||
platform_state: Optional[str] = None,
|
platform_state: Optional[str] = None,
|
||||||
error_code: Optional[str] = None,
|
error_code: Optional[str] = None,
|
||||||
@@ -236,6 +240,10 @@ def write_runtime_status(
|
|||||||
payload["gateway_state"] = gateway_state
|
payload["gateway_state"] = gateway_state
|
||||||
if exit_reason is not None:
|
if exit_reason is not None:
|
||||||
payload["exit_reason"] = exit_reason
|
payload["exit_reason"] = exit_reason
|
||||||
|
if restart_requested is not None:
|
||||||
|
payload["restart_requested"] = bool(restart_requested)
|
||||||
|
if active_agents is not None:
|
||||||
|
payload["active_agents"] = max(0, int(active_agents))
|
||||||
|
|
||||||
if platform is not None:
|
if platform is not None:
|
||||||
platform_payload = payload["platforms"].get(platform, {})
|
platform_payload = payload["platforms"].get(platform, {})
|
||||||
|
|||||||
@@ -140,6 +140,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||||||
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
|
CommandDef("commands", "Browse all commands and skills (paginated)", "Info",
|
||||||
gateway_only=True, args_hint="[page]"),
|
gateway_only=True, args_hint="[page]"),
|
||||||
CommandDef("help", "Show available commands", "Info"),
|
CommandDef("help", "Show available commands", "Info"),
|
||||||
|
CommandDef("restart", "Gracefully restart the gateway after draining active runs", "Session",
|
||||||
|
gateway_only=True),
|
||||||
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
|
CommandDef("usage", "Show token usage and rate limits for the current session", "Info"),
|
||||||
CommandDef("insights", "Show usage insights and analytics", "Info",
|
CommandDef("insights", "Show usage insights and analytics", "Info",
|
||||||
args_hint="[days]"),
|
args_hint="[days]"),
|
||||||
|
|||||||
@@ -269,6 +269,11 @@ DEFAULT_CONFIG = {
|
|||||||
# tools or receiving API responses. Only fires when the agent has
|
# tools or receiving API responses. Only fires when the agent has
|
||||||
# been completely idle for this duration. 0 = unlimited.
|
# been completely idle for this duration. 0 = unlimited.
|
||||||
"gateway_timeout": 1800,
|
"gateway_timeout": 1800,
|
||||||
|
# Graceful drain timeout for gateway stop/restart (seconds).
|
||||||
|
# The gateway stops accepting new work, waits for running agents
|
||||||
|
# to finish, then interrupts any remaining runs after the timeout.
|
||||||
|
# 0 = no drain, interrupt immediately.
|
||||||
|
"restart_drain_timeout": 60,
|
||||||
"service_tier": "",
|
"service_tier": "",
|
||||||
# Tool-use enforcement: injects system prompt guidance that tells the
|
# Tool-use enforcement: injects system prompt guidance that tells the
|
||||||
# model to actually call tools instead of describing intended actions.
|
# model to actually call tools instead of describing intended actions.
|
||||||
|
|||||||
@@ -15,7 +15,19 @@ from pathlib import Path
|
|||||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||||
|
|
||||||
from gateway.status import terminate_pid
|
from gateway.status import terminate_pid
|
||||||
from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
parse_restart_drain_timeout,
|
||||||
|
)
|
||||||
|
from hermes_cli.config import (
|
||||||
|
get_env_value,
|
||||||
|
get_hermes_home,
|
||||||
|
is_managed,
|
||||||
|
managed_error,
|
||||||
|
read_raw_config,
|
||||||
|
save_env_value,
|
||||||
|
)
|
||||||
# display_hermes_home is imported lazily at call sites to avoid ImportError
|
# display_hermes_home is imported lazily at call sites to avoid ImportError
|
||||||
# when hermes_constants is cached from a pre-update version during `hermes update`.
|
# when hermes_constants is cached from a pre-update version during `hermes update`.
|
||||||
from hermes_cli.setup import (
|
from hermes_cli.setup import (
|
||||||
@@ -92,6 +104,59 @@ def _get_service_pids() -> set:
|
|||||||
return pids
|
return pids
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parent_pid(pid: int) -> int | None:
|
||||||
|
"""Return the parent PID for ``pid``, or ``None`` when unavailable."""
|
||||||
|
if pid <= 1:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["ps", "-o", "ppid=", "-p", str(pid)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
return None
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
raw = result.stdout.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parent_pid = int(raw.splitlines()[-1].strip())
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return parent_pid if parent_pid > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pid_ancestor_of_current_process(target_pid: int) -> bool:
|
||||||
|
"""Return True when ``target_pid`` is this process or one of its ancestors."""
|
||||||
|
if target_pid <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
pid = os.getpid()
|
||||||
|
seen: set[int] = set()
|
||||||
|
while pid and pid not in seen:
|
||||||
|
if pid == target_pid:
|
||||||
|
return True
|
||||||
|
seen.add(pid)
|
||||||
|
pid = _get_parent_pid(pid) or 0
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _request_gateway_self_restart(pid: int) -> bool:
    """Signal a running gateway ancestor to restart itself.

    ``SIGUSR1`` is sent to ``pid`` only when the platform defines that signal
    and ``pid`` is the current process or one of its ancestors.

    Returns:
        ``True`` when the signal was sent; ``False`` when ``SIGUSR1`` is not
        available, ``pid`` is not an ancestor, or ``os.kill`` fails.
    """
    eligible = hasattr(signal, "SIGUSR1") and _is_pid_ancestor_of_current_process(pid)
    if not eligible:
        return False
    try:
        os.kill(pid, signal.SIGUSR1)
    except OSError:
        # ProcessLookupError and PermissionError are OSError subclasses.
        return False
    else:
        return True
|
||||||
|
|
||||||
|
|
||||||
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||||
"""Find PIDs of running gateway processes.
|
"""Find PIDs of running gateway processes.
|
||||||
|
|
||||||
@@ -665,6 +730,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
|||||||
path_entries.append(resolved_node_dir)
|
path_entries.append(resolved_node_dir)
|
||||||
|
|
||||||
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
|
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
|
||||||
|
restart_timeout = max(60, int(_get_restart_drain_timeout() or 0))
|
||||||
|
|
||||||
if system:
|
if system:
|
||||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||||
@@ -703,9 +769,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
|||||||
Environment="HERMES_HOME={hermes_home}"
|
Environment="HERMES_HOME={hermes_home}"
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=30
|
RestartSec=30
|
||||||
|
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||||
KillMode=mixed
|
KillMode=mixed
|
||||||
KillSignal=SIGTERM
|
KillSignal=SIGTERM
|
||||||
TimeoutStopSec=60
|
ExecReload=/bin/kill -USR1 $MAINPID
|
||||||
|
TimeoutStopSec={restart_timeout}
|
||||||
StandardOutput=journal
|
StandardOutput=journal
|
||||||
StandardError=journal
|
StandardError=journal
|
||||||
|
|
||||||
@@ -733,9 +801,11 @@ Environment="VIRTUAL_ENV={venv_dir}"
|
|||||||
Environment="HERMES_HOME={hermes_home}"
|
Environment="HERMES_HOME={hermes_home}"
|
||||||
Restart=on-failure
|
Restart=on-failure
|
||||||
RestartSec=30
|
RestartSec=30
|
||||||
|
RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}
|
||||||
KillMode=mixed
|
KillMode=mixed
|
||||||
KillSignal=SIGTERM
|
KillSignal=SIGTERM
|
||||||
TimeoutStopSec=60
|
ExecReload=/bin/kill -USR1 $MAINPID
|
||||||
|
TimeoutStopSec={restart_timeout}
|
||||||
StandardOutput=journal
|
StandardOutput=journal
|
||||||
StandardError=journal
|
StandardError=journal
|
||||||
|
|
||||||
@@ -838,6 +908,20 @@ def _select_systemd_scope(system: bool = False) -> bool:
|
|||||||
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
|
return get_systemd_unit_path(system=True).exists() and not get_systemd_unit_path(system=False).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_restart_drain_timeout() -> float:
    """Return the configured gateway restart drain timeout in seconds.

    The ``HERMES_RESTART_DRAIN_TIMEOUT`` environment variable wins when it is
    set to a non-blank value. Otherwise ``agent.restart_drain_timeout`` is
    read from the raw config, falling back to
    ``DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT``. The chosen value is passed as
    a string to ``parse_restart_drain_timeout``.
    """
    raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
    if not raw:
        cfg = read_raw_config()
        agent_cfg = cfg.get("agent") if isinstance(cfg, dict) else None
        # An ``agent:`` key whose children are all commented out loads as
        # None (or the key may hold a non-mapping); treat both as "no
        # settings" instead of failing on ``.get``.
        if not isinstance(agent_cfg, dict):
            agent_cfg = {}
        raw = str(
            agent_cfg.get(
                "restart_drain_timeout", DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
            )
        )
    return parse_restart_drain_timeout(raw)
|
||||||
|
|
||||||
|
|
||||||
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
def systemd_install(force: bool = False, system: bool = False, run_as_user: str | None = None):
|
||||||
if system:
|
if system:
|
||||||
_require_root_for_system_service("install")
|
_require_root_for_system_service("install")
|
||||||
@@ -923,7 +1007,13 @@ def systemd_restart(system: bool = False):
|
|||||||
if system:
|
if system:
|
||||||
_require_root_for_system_service("restart")
|
_require_root_for_system_service("restart")
|
||||||
refresh_systemd_unit_if_needed(system=system)
|
refresh_systemd_unit_if_needed(system=system)
|
||||||
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True, timeout=90)
|
from gateway.status import get_running_pid
|
||||||
|
|
||||||
|
pid = get_running_pid()
|
||||||
|
if pid is not None and _request_gateway_self_restart(pid):
|
||||||
|
print(f"✓ {_service_scope_label(system).capitalize()} service restart requested")
|
||||||
|
return
|
||||||
|
subprocess.run(_systemctl_cmd(system) + ["reload-or-restart", get_service_name()], check=True, timeout=90)
|
||||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||||
|
|
||||||
|
|
||||||
@@ -1211,7 +1301,7 @@ def launchd_stop():
|
|||||||
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
_wait_for_gateway_exit(timeout=10.0, force_after=5.0)
|
||||||
print("✓ Service stopped")
|
print("✓ Service stopped")
|
||||||
|
|
||||||
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float | None = 5.0) -> bool:
|
||||||
"""Wait for the gateway process (by saved PID) to exit.
|
"""Wait for the gateway process (by saved PID) to exit.
|
||||||
|
|
||||||
Uses the PID from the gateway.pid file — not launchd labels — so this
|
Uses the PID from the gateway.pid file — not launchd labels — so this
|
||||||
@@ -1226,21 +1316,21 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
|||||||
from gateway.status import get_running_pid
|
from gateway.status import get_running_pid
|
||||||
|
|
||||||
deadline = time.monotonic() + timeout
|
deadline = time.monotonic() + timeout
|
||||||
force_deadline = time.monotonic() + force_after
|
force_deadline = (time.monotonic() + force_after) if force_after is not None else None
|
||||||
force_sent = False
|
force_sent = False
|
||||||
|
|
||||||
while time.monotonic() < deadline:
|
while time.monotonic() < deadline:
|
||||||
pid = get_running_pid()
|
pid = get_running_pid()
|
||||||
if pid is None:
|
if pid is None:
|
||||||
return # Process exited cleanly.
|
return True # Process exited cleanly.
|
||||||
|
|
||||||
if not force_sent and time.monotonic() >= force_deadline:
|
if force_after is not None and not force_sent and time.monotonic() >= force_deadline:
|
||||||
# Grace period expired — force-kill the specific PID.
|
# Grace period expired — force-kill the specific PID.
|
||||||
try:
|
try:
|
||||||
terminate_pid(pid, force=True)
|
terminate_pid(pid, force=True)
|
||||||
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
|
print(f"⚠ Gateway PID {pid} did not exit gracefully; sent SIGKILL")
|
||||||
except (ProcessLookupError, PermissionError, OSError):
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
return # Already gone or we can't touch it.
|
return True # Already gone or we can't touch it.
|
||||||
force_sent = True
|
force_sent = True
|
||||||
|
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
@@ -1249,15 +1339,30 @@ def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0):
|
|||||||
remaining_pid = get_running_pid()
|
remaining_pid = get_running_pid()
|
||||||
if remaining_pid is not None:
|
if remaining_pid is not None:
|
||||||
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
|
print(f"⚠ Gateway PID {remaining_pid} still running after {timeout}s — restart may fail")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def launchd_restart():
|
def launchd_restart():
|
||||||
label = get_launchd_label()
|
label = get_launchd_label()
|
||||||
target = f"{_launchd_domain()}/{label}"
|
target = f"{_launchd_domain()}/{label}"
|
||||||
# Use kickstart -k so launchd performs an atomic kill+restart.
|
drain_timeout = _get_restart_drain_timeout()
|
||||||
# A two-step stop/start from inside the gateway's own process tree
|
from gateway.status import get_running_pid
|
||||||
# would kill the shell before the start command is reached.
|
|
||||||
try:
|
try:
|
||||||
|
pid = get_running_pid()
|
||||||
|
if pid is not None and _request_gateway_self_restart(pid):
|
||||||
|
print("✓ Service restart requested")
|
||||||
|
return
|
||||||
|
if pid is not None:
|
||||||
|
try:
|
||||||
|
terminate_pid(pid, force=False)
|
||||||
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
|
pid = None
|
||||||
|
if pid is not None:
|
||||||
|
exited = _wait_for_gateway_exit(timeout=drain_timeout, force_after=None)
|
||||||
|
if not exited:
|
||||||
|
print(f"⚠ Gateway drain timed out after {drain_timeout:.0f}s — forcing launchd restart")
|
||||||
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
|
subprocess.run(["launchctl", "kickstart", "-k", target], check=True, timeout=90)
|
||||||
print("✓ Service restarted")
|
print("✓ Service restarted")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
@@ -1728,6 +1833,8 @@ def _runtime_health_lines() -> list[str]:
|
|||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
gateway_state = state.get("gateway_state")
|
gateway_state = state.get("gateway_state")
|
||||||
exit_reason = state.get("exit_reason")
|
exit_reason = state.get("exit_reason")
|
||||||
|
active_agents = state.get("active_agents")
|
||||||
|
restart_requested = state.get("restart_requested")
|
||||||
platforms = state.get("platforms", {}) or {}
|
platforms = state.get("platforms", {}) or {}
|
||||||
|
|
||||||
for platform, pdata in platforms.items():
|
for platform, pdata in platforms.items():
|
||||||
@@ -1737,6 +1844,10 @@ def _runtime_health_lines() -> list[str]:
|
|||||||
|
|
||||||
if gateway_state == "startup_failed" and exit_reason:
|
if gateway_state == "startup_failed" and exit_reason:
|
||||||
lines.append(f"⚠ Last startup issue: {exit_reason}")
|
lines.append(f"⚠ Last startup issue: {exit_reason}")
|
||||||
|
elif gateway_state == "draining":
|
||||||
|
action = "restart" if restart_requested else "shutdown"
|
||||||
|
count = int(active_agents or 0)
|
||||||
|
lines.append(f"⏳ Gateway draining for {action} ({count} active agent(s))")
|
||||||
elif gateway_state == "stopped" and exit_reason:
|
elif gateway_state == "stopped" and exit_reason:
|
||||||
lines.append(f"⚠ Last shutdown reason: {exit_reason}")
|
lines.append(f"⚠ Last shutdown reason: {exit_reason}")
|
||||||
|
|
||||||
|
|||||||
110
tests/gateway/restart_test_helpers.py
Normal file
110
tests/gateway/restart_test_helpers.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||||
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||||
|
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
class RestartTestAdapter(BasePlatformAdapter):
    """Minimal in-memory adapter used by the gateway restart tests.

    It registers as an enabled Telegram platform and records the content of
    every ``send`` call in ``sent``.
    """

    def __init__(self):
        platform_config = PlatformConfig(enabled=True, token="***")
        super().__init__(platform_config, Platform.TELEGRAM)
        # Content of each message passed to ``send``, in call order.
        self.sent: list[str] = []

    async def connect(self):
        """Report a successful connection."""
        return True

    async def disconnect(self):
        """Do nothing; there is no connection to close."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Record ``content`` and return a successful ``SendResult``."""
        self.sent.append(content)
        outcome = SendResult(success=True, message_id="1")
        return outcome

    async def send_typing(self, chat_id, metadata=None):
        """Do nothing; typing indicators are not simulated."""
        return None

    async def get_chat_info(self, chat_id):
        """Return a chat-info dict holding only the chat id."""
        chat_info = {"id": chat_id}
        return chat_info
|
||||||
|
|
||||||
|
|
||||||
|
def make_restart_source(chat_id: str = "123456", chat_type: str = "dm") -> SessionSource:
    """Build a Telegram ``SessionSource`` for the given chat id and chat type."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
    }
    return SessionSource(**source_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def make_restart_runner(
    adapter: BasePlatformAdapter | None = None,
) -> tuple[GatewayRunner, BasePlatformAdapter]:
    """Assemble a ``GatewayRunner`` test double wired to one Telegram adapter.

    The runner is allocated with ``object.__new__`` so ``GatewayRunner.__init__``
    never runs; every attribute set below is populated by hand.

    Args:
        adapter: Adapter to register.  A fresh ``RestartTestAdapter`` is used
            when this is falsy.

    Returns:
        ``(runner, adapter)`` where ``adapter`` is the instance registered
        under ``Platform.TELEGRAM`` in ``runner.adapters``.
    """
    runner = object.__new__(GatewayRunner)

    telegram_config = PlatformConfig(enabled=True, token="***")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})

    # Lifecycle and exit bookkeeping.
    runner._running = True
    runner._shutdown_event = asyncio.Event()
    runner._exit_reason = None
    runner._exit_code = None

    # Per-session tracking containers, all starting empty.
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._pending_model_notes = {}
    runner._background_tasks = set()

    # Restart / drain state.
    runner._draining = False
    runner._restart_requested = False
    runner._restart_task_started = False
    runner._restart_detached = False
    runner._restart_via_service = False
    runner._restart_drain_timeout = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    runner._stop_task = None

    # Remaining per-runner settings and caches.
    runner._busy_input_mode = "interrupt"
    runner._update_prompt_pending = {}
    runner._voice_mode = {}
    runner._session_model_overrides = {}
    runner._shutdown_all_gateway_honcho = lambda: None
    runner._update_runtime_status = MagicMock()

    # Bind the real GatewayRunner implementations onto the instance.
    for method_name in (
        "_queue_or_replace_pending_event",
        "_session_key_for_source",
        "_handle_active_session_busy_message",
        "_handle_restart_command",
        "_status_action_label",
        "_status_action_gerund",
        "_queue_during_drain_enabled",
        "_running_agent_count",
        "_launch_detached_restart_command",
        "request_restart",
    ):
        implementation = getattr(GatewayRunner, method_name)
        setattr(runner, method_name, implementation.__get__(runner, GatewayRunner))

    # Collaborators are replaced with mocks; every user is authorized.
    runner._is_user_authorized = lambda _source: True
    hooks = MagicMock()
    hooks.emit = AsyncMock()
    runner.hooks = hooks
    runner.pairing_store = MagicMock()
    runner.session_store = MagicMock()
    runner.delivery_router = MagicMock()

    platform_adapter = adapter if adapter else RestartTestAdapter()
    platform_adapter.set_message_handler(AsyncMock(return_value=None))
    platform_adapter.set_busy_session_handler(runner._handle_active_session_busy_message)
    runner.adapters = {Platform.TELEGRAM: platform_adapter}

    return runner, platform_adapter
|
||||||
@@ -3,43 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
from gateway.platforms.base import MessageEvent
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
from gateway.restart import GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
from gateway.run import GatewayRunner
|
from gateway.session import build_session_key
|
||||||
from gateway.session import SessionSource, build_session_key
|
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||||
|
|
||||||
|
|
||||||
class StubAdapter(BasePlatformAdapter):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def disconnect(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
|
||||||
return SendResult(success=True, message_id="1")
|
|
||||||
|
|
||||||
async def send_typing(self, chat_id, metadata=None):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_chat_info(self, chat_id):
|
|
||||||
return {"id": chat_id}
|
|
||||||
|
|
||||||
|
|
||||||
def _source(chat_id="123456", chat_type="dm"):
|
|
||||||
return SessionSource(
|
|
||||||
platform=Platform.TELEGRAM,
|
|
||||||
chat_id=chat_id,
|
|
||||||
chat_type=chat_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
||||||
adapter = StubAdapter()
|
_runner, adapter = make_restart_runner()
|
||||||
release = asyncio.Event()
|
release = asyncio.Event()
|
||||||
|
|
||||||
async def block_forever(_event):
|
async def block_forever(_event):
|
||||||
@@ -47,7 +19,7 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
adapter.set_message_handler(block_forever)
|
adapter.set_message_handler(block_forever)
|
||||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||||
|
|
||||||
await adapter.handle_message(event)
|
await adapter.handle_message(event)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
@@ -65,17 +37,11 @@ async def test_cancel_background_tasks_cancels_inflight_message_processing():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks():
|
||||||
runner = object.__new__(GatewayRunner)
|
runner, adapter = make_restart_runner()
|
||||||
runner.config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")})
|
|
||||||
runner._running = True
|
|
||||||
runner._shutdown_event = asyncio.Event()
|
|
||||||
runner._exit_reason = None
|
|
||||||
runner._pending_messages = {"session": "pending text"}
|
runner._pending_messages = {"session": "pending text"}
|
||||||
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
runner._pending_approvals = {"session": {"command": "rm -rf /tmp/x"}}
|
||||||
runner._background_tasks = set()
|
runner._restart_drain_timeout = 0.0
|
||||||
runner._shutdown_all_gateway_honcho = lambda: None
|
|
||||||
|
|
||||||
adapter = StubAdapter()
|
|
||||||
release = asyncio.Event()
|
release = asyncio.Event()
|
||||||
|
|
||||||
async def block_forever(_event):
|
async def block_forever(_event):
|
||||||
@@ -83,7 +49,7 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
adapter.set_message_handler(block_forever)
|
adapter.set_message_handler(block_forever)
|
||||||
event = MessageEvent(text="work", source=_source(), message_id="1")
|
event = MessageEvent(text="work", source=make_restart_source(), message_id="1")
|
||||||
await adapter.handle_message(event)
|
await adapter.handle_message(event)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
@@ -93,7 +59,6 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
running_agent = MagicMock()
|
running_agent = MagicMock()
|
||||||
runner._running_agents = {session_key: running_agent}
|
runner._running_agents = {session_key: running_agent}
|
||||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
|
||||||
|
|
||||||
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
|
||||||
await runner.stop()
|
await runner.stop()
|
||||||
@@ -105,3 +70,78 @@ async def test_gateway_stop_interrupts_running_agents_and_cancels_adapter_tasks(
|
|||||||
assert runner._pending_messages == {}
|
assert runner._pending_messages == {}
|
||||||
assert runner._pending_approvals == {}
|
assert runner._pending_approvals == {}
|
||||||
assert runner._shutdown_event.is_set() is True
|
assert runner._shutdown_event.is_set() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gateway_stop_drains_running_agents_before_disconnect():
    """stop() waits for a running agent to finish instead of interrupting it.

    A helper task empties ``runner._running_agents`` shortly after ``stop()``
    starts.  The test expects the agent is never interrupted, the adapter is
    disconnected exactly once, and the shutdown event ends up set.
    """
    runner, adapter = make_restart_runner()
    disconnect_mock = AsyncMock()
    adapter.disconnect = disconnect_mock

    running_agent = MagicMock()
    runner._running_agents = {"session": running_agent}

    async def finish_agent():
        # Simulate the agent completing on its own while stop() is draining.
        await asyncio.sleep(0.05)
        runner._running_agents.clear()

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so a discarded create_task() result may be garbage-collected
    # before it runs to completion.
    finish_task = asyncio.create_task(finish_agent())

    with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
        await runner.stop()

    # Awaiting the helper surfaces any exception it raised and guarantees no
    # pending task is left behind when the test returns.
    await finish_task

    running_agent.interrupt.assert_not_called()
    disconnect_mock.assert_awaited_once()
    assert runner._shutdown_event.is_set() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gateway_stop_interrupts_after_drain_timeout():
    """An agent that never finishes is interrupted once the drain timeout passes.

    Expects the interrupt to carry the "Gateway shutting down" reason, the
    adapter to be disconnected exactly once, and the shutdown event to be set.
    """
    runner, adapter = make_restart_runner()
    # Short drain window so the timeout path is reached quickly.
    runner._restart_drain_timeout = 0.05

    disconnect_spy = AsyncMock()
    adapter.disconnect = disconnect_spy

    stuck_agent = MagicMock()
    runner._running_agents = {"session": stuck_agent}

    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop()

    stuck_agent.interrupt.assert_called_once_with("Gateway shutting down")
    disconnect_spy.assert_awaited_once()
    assert runner._shutdown_event.is_set() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gateway_stop_service_restart_sets_named_exit_code():
    """stop(restart=True, service_restart=True) records the service-restart exit code."""
    runner, adapter = make_restart_runner()
    adapter.disconnect = AsyncMock()

    stop_options = {"restart": True, "service_restart": True}
    with patch("gateway.status.remove_pid_file"):
        with patch("gateway.status.write_runtime_status"):
            await runner.stop(**stop_options)

    recorded_exit_code = runner._exit_code
    assert recorded_exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates():
    """Draining two agents publishes only a handful of runtime-status updates.

    The two agents are removed one after the other while
    ``_drain_active_agents`` is waiting; the number of
    ``_update_runtime_status`` calls must stay within 3..4.
    """
    runner, _adapter = make_restart_runner()
    runner._update_runtime_status = MagicMock()

    runner._running_agents = {agent_key: MagicMock() for agent_key in ("a", "b")}

    async def finish_agents():
        # First agent finishes, then the rest, with a pause between them.
        await asyncio.sleep(0.12)
        del runner._running_agents["a"]
        await asyncio.sleep(0.12)
        runner._running_agents.clear()

    finisher = asyncio.create_task(finish_agents())
    await runner._drain_active_agents(1.0)
    await finisher

    # Expected: one update at start, one when the count changes, one at the
    # end.  A fourth is tolerated if the loop sees the empty state before it
    # exits.
    update_calls = runner._update_runtime_status.call_count
    assert 3 <= update_calls <= 4
|
||||||
|
|||||||
160
tests/gateway/test_restart_drain.py
Normal file
160
tests/gateway/test_restart_drain.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import asyncio
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import gateway.run as gateway_run
|
||||||
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
|
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
from gateway.session import build_session_key
|
||||||
|
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_restart_command_while_busy_requests_drain_without_interrupt():
    """/restart on a busy session requests a restart and leaves the agent alone.

    Expects the draining reply text, no ``interrupt`` call on the running
    agent, and a single ``request_restart(detached=True, via_service=False)``.
    """
    runner, _adapter = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    source = make_restart_source()
    restart_event = MessageEvent(
        text="/restart",
        message_type=MessageType.TEXT,
        source=source,
        message_id="m1",
    )

    # Mark the session as busy with an agent registered under its key.
    busy_agent = MagicMock()
    busy_key = build_session_key(restart_event.source)
    runner._running_agents[busy_key] = busy_agent

    reply = await runner._handle_message(restart_event)

    assert reply == "⏳ Draining 1 active agent(s) before restart..."
    busy_agent.interrupt.assert_not_called()
    runner.request_restart.assert_called_once_with(detached=True, via_service=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
    """In "queue" busy-input mode a follow-up during drain is queued, not interrupting.

    Expects the event to land in the adapter's pending messages, the active
    session's event to stay unset, and a "queued for the next turn" notice
    to be sent.
    """
    runner, adapter = make_restart_runner()
    runner._busy_input_mode = "queue"
    runner._restart_requested = True
    runner._draining = True

    follow_up = MessageEvent(
        text="follow up",
        message_type=MessageType.TEXT,
        source=make_restart_source(),
        message_id="m2",
    )
    key = build_session_key(follow_up.source)
    # An entry in _active_sessions marks the session as currently busy.
    adapter._active_sessions[key] = asyncio.Event()

    await adapter.handle_message(follow_up)

    assert key in adapter._pending_messages
    queued_event = adapter._pending_messages[key]
    assert queued_event.text == "follow up"
    assert not adapter._active_sessions[key].is_set()
    queue_notices = [
        message for message in adapter.sent if "queued for the next turn" in message
    ]
    assert queue_notices
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_draining_rejects_new_session_messages():
    """While draining for a restart, a message from a new chat gets the refusal text."""
    runner, _adapter = make_restart_runner()
    runner._restart_requested = True
    runner._draining = True

    # "fresh" is a chat id with no running agent registered on the runner.
    newcomer_source = make_restart_source("fresh")
    newcomer_event = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=newcomer_source,
        message_id="m3",
    )

    reply = await runner._handle_message(newcomer_event)

    assert reply == "⏳ Gateway is restarting and is not accepting new work right now."
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
    """Busy-input mode resolves as environment variable, then config.yaml, then default.

    Checked in three steps: nothing configured ("interrupt"), config file
    only ("queue"), then the environment variable set on top ("interrupt").
    """
    env_name = "HERMES_GATEWAY_BUSY_INPUT_MODE"
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_name, raising=False)

    # Step 1: no environment variable and no config file.
    resolved = gateway_run.GatewayRunner._load_busy_input_mode()
    assert resolved == "interrupt"

    # Step 2: the config file supplies the value.
    config_file = tmp_path / "config.yaml"
    config_file.write_text("display:\n busy_input_mode: queue\n", encoding="utf-8")
    resolved = gateway_run.GatewayRunner._load_busy_input_mode()
    assert resolved == "queue"

    # Step 3: the environment variable is set while the config file remains.
    monkeypatch.setenv(env_name, "interrupt")
    resolved = gateway_run.GatewayRunner._load_busy_input_mode()
    assert resolved == "interrupt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
    tmp_path, monkeypatch, caplog
):
    """Drain timeout resolves as environment variable, then config.yaml, then default.

    Also expects an unparseable environment value to yield the default and
    to leave an "Invalid restart_drain_timeout" message in the captured log.
    """
    env_name = "HERMES_RESTART_DRAIN_TIMEOUT"
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_name, raising=False)

    # Nothing configured: the packaged default is returned.
    resolved = gateway_run.GatewayRunner._load_restart_drain_timeout()
    assert resolved == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    # Config file value is used when the environment variable is absent.
    config_file = tmp_path / "config.yaml"
    config_file.write_text("agent:\n restart_drain_timeout: 12\n", encoding="utf-8")
    resolved = gateway_run.GatewayRunner._load_restart_drain_timeout()
    assert resolved == 12.0

    # Environment variable takes precedence over the config file.
    monkeypatch.setenv(env_name, "7")
    resolved = gateway_run.GatewayRunner._load_restart_drain_timeout()
    assert resolved == 7.0

    # A non-numeric value yields the default and is reported in the log.
    monkeypatch.setenv(env_name, "invalid")
    resolved = gateway_run.GatewayRunner._load_restart_drain_timeout()
    assert resolved == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    assert "Invalid restart_drain_timeout" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_request_restart_is_idempotent():
    """A second request_restart() returns False and stop() is awaited only once."""
    runner, _adapter = make_restart_runner()
    runner.stop = AsyncMock()
    restart_options = {"detached": True, "via_service": False}

    first_outcome = runner.request_restart(**restart_options)
    assert first_outcome is True
    # Grab the task the first request added before issuing the duplicate.
    restart_task = next(iter(runner._background_tasks))

    second_outcome = runner.request_restart(**restart_options)
    assert second_outcome is False

    await restart_task

    runner.stop.assert_awaited_once_with(
        restart=True,
        detached_restart=True,
        service_restart=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
    """The detached restart helper spawns one setsid-wrapped bash command.

    ``shutil.which`` and ``subprocess.Popen`` are replaced, so no real
    process is started; the captured command line and keyword arguments are
    inspected instead.
    """
    runner, _adapter = make_restart_runner()
    popen_calls = []

    def fake_which(cmd):
        # Only "setsid" is reported as installed.
        if cmd == "setsid":
            return "/usr/bin/setsid"
        return None

    def fake_popen(cmd, **kwargs):
        popen_calls.append((cmd, kwargs))
        return MagicMock()

    monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
    monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
    monkeypatch.setattr(shutil, "which", fake_which)
    monkeypatch.setattr(subprocess, "Popen", fake_popen)

    await runner._launch_detached_restart_command()

    assert len(popen_calls) == 1
    cmd, kwargs = popen_calls[0]
    script = cmd[-1]
    assert cmd[:2] == ["/usr/bin/setsid", "bash"]
    assert "gateway restart" in script
    # 321 is the fake PID returned by the patched os.getpid.
    assert "kill -0 321" in script
    assert kwargs["start_new_session"] is True
    assert kwargs["stdout"] is subprocess.DEVNULL
    assert kwargs["stderr"] is subprocess.DEVNULL
|
||||||
@@ -127,6 +127,16 @@ async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
|
|||||||
runner._shutdown_event = MagicMock()
|
runner._shutdown_event = MagicMock()
|
||||||
runner.adapters = {}
|
runner.adapters = {}
|
||||||
runner._exit_reason = "test"
|
runner._exit_reason = "test"
|
||||||
|
runner._exit_code = None
|
||||||
|
runner._draining = False
|
||||||
|
runner._restart_requested = False
|
||||||
|
runner._restart_task_started = False
|
||||||
|
runner._restart_detached = False
|
||||||
|
runner._restart_via_service = False
|
||||||
|
runner._restart_drain_timeout = 0.0
|
||||||
|
runner._stop_task = None
|
||||||
|
runner._running_agents_ts = {}
|
||||||
|
runner._update_runtime_status = MagicMock()
|
||||||
|
|
||||||
agent1 = MagicMock()
|
agent1 = MagicMock()
|
||||||
agent1.session_id = "sess-a"
|
agent1.session_id = "sess-a"
|
||||||
|
|||||||
@@ -41,6 +41,15 @@ def _make_runner():
|
|||||||
runner._pending_approvals = {}
|
runner._pending_approvals = {}
|
||||||
runner._voice_mode = {}
|
runner._voice_mode = {}
|
||||||
runner._background_tasks = set()
|
runner._background_tasks = set()
|
||||||
|
runner._draining = False
|
||||||
|
runner._restart_requested = False
|
||||||
|
runner._restart_task_started = False
|
||||||
|
runner._restart_detached = False
|
||||||
|
runner._restart_via_service = False
|
||||||
|
runner._restart_drain_timeout = 0.0
|
||||||
|
runner._stop_task = None
|
||||||
|
runner._exit_code = None
|
||||||
|
runner._update_runtime_status = MagicMock()
|
||||||
runner._is_user_authorized = lambda _source: True
|
runner._is_user_authorized = lambda _source: True
|
||||||
runner.hooks = MagicMock()
|
runner.hooks = MagicMock()
|
||||||
runner.hooks.emit = AsyncMock()
|
runner.hooks.emit = AsyncMock()
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import hermes_cli.gateway as gateway_cli
|
import hermes_cli.gateway as gateway_cli
|
||||||
|
from gateway.restart import (
|
||||||
|
DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT,
|
||||||
|
GATEWAY_SERVICE_RESTART_EXIT_CODE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSystemdServiceRefresh:
|
class TestSystemdServiceRefresh:
|
||||||
@@ -74,7 +78,7 @@ class TestSystemdServiceRefresh:
|
|||||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||||
assert calls[:2] == [
|
assert calls[:2] == [
|
||||||
["systemctl", "--user", "daemon-reload"],
|
["systemctl", "--user", "daemon-reload"],
|
||||||
["systemctl", "--user", "restart", gateway_cli.get_service_name()],
|
["systemctl", "--user", "reload-or-restart", gateway_cli.get_service_name()],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -84,6 +88,8 @@ class TestGeneratedSystemdUnits:
|
|||||||
|
|
||||||
assert "ExecStart=" in unit
|
assert "ExecStart=" in unit
|
||||||
assert "ExecStop=" not in unit
|
assert "ExecStop=" not in unit
|
||||||
|
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||||
|
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||||
assert "TimeoutStopSec=60" in unit
|
assert "TimeoutStopSec=60" in unit
|
||||||
|
|
||||||
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
def test_user_unit_includes_resolved_node_directory_in_path(self, monkeypatch):
|
||||||
@@ -98,6 +104,8 @@ class TestGeneratedSystemdUnits:
|
|||||||
|
|
||||||
assert "ExecStart=" in unit
|
assert "ExecStart=" in unit
|
||||||
assert "ExecStop=" not in unit
|
assert "ExecStop=" not in unit
|
||||||
|
assert "ExecReload=/bin/kill -USR1 $MAINPID" in unit
|
||||||
|
assert f"RestartForceExitStatus={GATEWAY_SERVICE_RESTART_EXIT_CODE}" in unit
|
||||||
assert "TimeoutStopSec=60" in unit
|
assert "TimeoutStopSec=60" in unit
|
||||||
assert "WantedBy=multi-user.target" in unit
|
assert "WantedBy=multi-user.target" in unit
|
||||||
|
|
||||||
@@ -157,6 +165,31 @@ class TestGatewayStopCleanup:
|
|||||||
|
|
||||||
|
|
||||||
class TestLaunchdServiceRecovery:
|
class TestLaunchdServiceRecovery:
|
||||||
|
def test_get_restart_drain_timeout_prefers_env_then_config_then_default(self, monkeypatch):
    """CLI drain timeout resolves as environment variable, then raw config, then default.

    An unparseable environment value is expected to yield the default.
    """
    env_name = "HERMES_RESTART_DRAIN_TIMEOUT"

    def empty_config():
        return {}

    def config_with_timeout():
        return {"agent": {"restart_drain_timeout": 14}}

    # Nothing configured anywhere: the default applies.
    monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr(gateway_cli, "read_raw_config", empty_config)
    resolved = gateway_cli._get_restart_drain_timeout()
    assert resolved == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    # Raw config supplies the value.
    monkeypatch.setattr(gateway_cli, "read_raw_config", config_with_timeout)
    resolved = gateway_cli._get_restart_drain_timeout()
    assert resolved == 14.0

    # Environment variable takes precedence over the raw config.
    monkeypatch.setenv(env_name, "9")
    resolved = gateway_cli._get_restart_drain_timeout()
    assert resolved == 9.0

    # Non-numeric environment value yields the default.
    monkeypatch.setenv(env_name, "invalid")
    resolved = gateway_cli._get_restart_drain_timeout()
    assert resolved == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
|
||||||
|
|
||||||
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
def test_launchd_install_repairs_outdated_plist_without_force(self, tmp_path, monkeypatch):
|
||||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||||
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
plist_path.write_text("<plist>old content</plist>", encoding="utf-8")
|
||||||
@@ -234,6 +267,55 @@ class TestLaunchdServiceRecovery:
|
|||||||
["launchctl", "kickstart", target],
|
["launchctl", "kickstart", target],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_launchd_restart_drains_running_gateway_before_kickstart(self, monkeypatch):
    """When the self-restart request is declined, launchd_restart terminates then kickstarts.

    Expects exactly one non-forced ``terminate_pid`` call for the running PID
    followed by ``launchctl kickstart -k`` on the service target.
    """
    calls = []
    domain = gateway_cli._launchd_domain()
    label = gateway_cli.get_launchd_label()
    target = f"{domain}/{label}"

    def fake_drain_timeout():
        return 12.0

    def decline_self_restart(pid):
        return False

    def gateway_exited(timeout, force_after=None):
        return True

    def record_terminate(pid, force=False):
        calls.append(("term", pid, force))

    def fake_running_pid():
        return 321

    def fake_run(cmd, check=False, **kwargs):
        calls.append(cmd)
        return SimpleNamespace(returncode=0, stdout="", stderr="")

    monkeypatch.setattr(gateway_cli, "_get_restart_drain_timeout", fake_drain_timeout)
    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", decline_self_restart)
    monkeypatch.setattr(gateway_cli, "_wait_for_gateway_exit", gateway_exited)
    monkeypatch.setattr(gateway_cli, "terminate_pid", record_terminate)
    monkeypatch.setattr("gateway.status.get_running_pid", fake_running_pid)
    monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)

    gateway_cli.launchd_restart()

    expected_calls = [
        ("term", 321, False),
        ["launchctl", "kickstart", "-k", target],
    ]
    assert calls == expected_calls
|
||||||
|
|
||||||
|
def test_launchd_restart_self_requests_graceful_restart_without_kickstart(self, monkeypatch, capsys):
    """When the gateway accepts the self-restart request, launchctl is never run.

    Expects a single self-restart request for the running PID and
    "restart requested" in the printed output.
    """
    calls = []

    def fake_running_pid():
        return 321

    def accept_self_restart(pid):
        calls.append(("self", pid))
        return True

    def forbid_subprocess(*args, **kwargs):
        # Any subprocess invocation fails the test.
        raise AssertionError("launchctl should not run")

    monkeypatch.setattr("gateway.status.get_running_pid", fake_running_pid)
    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", accept_self_restart)
    monkeypatch.setattr(gateway_cli.subprocess, "run", forbid_subprocess)

    gateway_cli.launchd_restart()

    assert calls == [("self", 321)]
    printed = capsys.readouterr().out
    assert "restart requested" in printed.lower()
|
||||||
|
|
||||||
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
|
def test_launchd_stop_uses_bootout_not_kill(self, monkeypatch):
|
||||||
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
|
"""launchd_stop must bootout the service so KeepAlive doesn't respawn it."""
|
||||||
label = gateway_cli.get_launchd_label()
|
label = gateway_cli.get_launchd_label()
|
||||||
@@ -337,6 +419,31 @@ class TestGatewayServiceDetection:
|
|||||||
|
|
||||||
|
|
||||||
class TestGatewaySystemServiceRouting:
|
class TestGatewaySystemServiceRouting:
|
||||||
|
def test_systemd_restart_self_requests_graceful_restart_without_reload_or_restart(self, monkeypatch, capsys):
    """When the gateway accepts the self-restart request, systemctl is never run.

    Expects the unit refresh to happen first, then a single self-restart
    request for the running PID, and "restart requested" in the printed output.
    """
    calls = []

    def fake_select_scope(system=False):
        return False

    def record_refresh(system=False):
        calls.append(("refresh", system))

    def fake_running_pid():
        return 654

    def accept_self_restart(pid):
        calls.append(("self", pid))
        return True

    def forbid_subprocess(*args, **kwargs):
        # Any subprocess invocation fails the test.
        raise AssertionError("systemctl should not run")

    monkeypatch.setattr(gateway_cli, "_select_systemd_scope", fake_select_scope)
    monkeypatch.setattr(gateway_cli, "refresh_systemd_unit_if_needed", record_refresh)
    monkeypatch.setattr("gateway.status.get_running_pid", fake_running_pid)
    monkeypatch.setattr(gateway_cli, "_request_gateway_self_restart", accept_self_restart)
    monkeypatch.setattr(gateway_cli.subprocess, "run", forbid_subprocess)

    gateway_cli.systemd_restart()

    assert calls == [("refresh", False), ("self", 654)]
    printed = capsys.readouterr().out
    assert "restart requested" in printed.lower()
|
||||||
|
|
||||||
def test_gateway_install_passes_system_flags(self, monkeypatch):
|
def test_gateway_install_passes_system_flags(self, monkeypatch):
|
||||||
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
monkeypatch.setattr(gateway_cli, "supports_systemd_services", lambda: True)
|
||||||
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
monkeypatch.setattr(gateway_cli, "is_termux", lambda: False)
|
||||||
|
|||||||
Reference in New Issue
Block a user