mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 03:07:21 +08:00
Compare commits
1 Commits
fix/client
...
fix/ci-tes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1647dadba |
@@ -1835,15 +1835,9 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
# Every auxiliary LLM consumer should use these instead of manually
|
||||
# constructing clients and calling .chat.completions.create().
|
||||
|
||||
# Client cache: (provider, async_mode, base_url, api_key, api_mode, runtime_key) -> (client, default_model, loop)
|
||||
# NOTE: loop identity is NOT part of the key. On async cache hits we check
|
||||
# whether the cached loop is the *current* loop; if not, the stale entry is
|
||||
# replaced in-place. This bounds cache growth to one entry per unique
|
||||
# provider config rather than one per (config × event-loop), which previously
|
||||
# caused unbounded fd accumulation in long-running gateway processes (#10200).
|
||||
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
|
||||
_client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
_CLIENT_CACHE_MAX_SIZE = 64 # safety belt — evict oldest when exceeded
|
||||
|
||||
|
||||
def neuter_async_httpx_del() -> None:
|
||||
@@ -1976,49 +1970,39 @@ def _get_cached_client(
|
||||
Async clients (AsyncOpenAI) use httpx.AsyncClient internally, which
|
||||
binds to the event loop that was current when the client was created.
|
||||
Using such a client on a *different* loop causes deadlocks or
|
||||
RuntimeError. To prevent cross-loop issues, the cache validates on
|
||||
every async hit that the cached loop is the *current, open* loop.
|
||||
If the loop changed (e.g. a new gateway worker-thread loop), the stale
|
||||
entry is replaced in-place rather than creating an additional entry.
|
||||
|
||||
This keeps cache size bounded to one entry per unique provider config,
|
||||
preventing the fd-exhaustion that previously occurred in long-running
|
||||
gateways where recycled worker threads created unbounded entries (#10200).
|
||||
RuntimeError. To prevent cross-loop issues (especially in gateway
|
||||
mode where _run_async() may spawn fresh loops in worker threads), the
|
||||
cache key for async clients includes the current event loop's identity
|
||||
so each loop gets its own client instance.
|
||||
"""
|
||||
# Resolve the current event loop for async clients so we can validate
|
||||
# cached entries. Loop identity is NOT in the cache key — instead we
|
||||
# check at hit time whether the cached loop is still current and open.
|
||||
# This prevents unbounded cache growth from recycled worker-thread loops
|
||||
# while still guaranteeing we never reuse a client on the wrong loop
|
||||
# (which causes deadlocks, see #2681).
|
||||
# Include loop identity for async clients to prevent cross-loop reuse.
|
||||
# httpx.AsyncClient (inside AsyncOpenAI) is bound to the loop where it
|
||||
# was created — reusing it on a different loop causes deadlocks (#2681).
|
||||
loop_id = 0
|
||||
current_loop = None
|
||||
if async_mode:
|
||||
try:
|
||||
import asyncio as _aio
|
||||
current_loop = _aio.get_event_loop()
|
||||
loop_id = id(current_loop)
|
||||
except RuntimeError:
|
||||
pass
|
||||
runtime = _normalize_main_runtime(main_runtime)
|
||||
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key)
|
||||
cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id, runtime_key)
|
||||
with _client_cache_lock:
|
||||
if cache_key in _client_cache:
|
||||
cached_client, cached_default, cached_loop = _client_cache[cache_key]
|
||||
if async_mode:
|
||||
# Validate: the cached client must be bound to the CURRENT,
|
||||
# OPEN loop. If the loop changed or was closed, the httpx
|
||||
# transport inside is dead — force-close and replace.
|
||||
loop_ok = (
|
||||
cached_loop is not None
|
||||
and cached_loop is current_loop
|
||||
and not cached_loop.is_closed()
|
||||
)
|
||||
if loop_ok:
|
||||
# A cached async client whose loop has been closed will raise
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
# Stale — evict and fall through to create a new client.
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
effective = _compat_model(cached_client, model, cached_default)
|
||||
return cached_client, effective
|
||||
@@ -2038,12 +2022,6 @@ def _get_cached_client(
|
||||
bound_loop = current_loop
|
||||
with _client_cache_lock:
|
||||
if cache_key not in _client_cache:
|
||||
# Safety belt: if the cache has grown beyond the max, evict
|
||||
# the oldest entries (FIFO — dict preserves insertion order).
|
||||
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
||||
evict_key, evict_entry = next(iter(_client_cache.items()))
|
||||
_force_close_async_httpx(evict_entry[0])
|
||||
del _client_cache[evict_key]
|
||||
_client_cache[cache_key] = (client, default_model, bound_loop)
|
||||
else:
|
||||
client, default_model, _ = _client_cache[cache_key]
|
||||
|
||||
2
cli.py
2
cli.py
@@ -4100,8 +4100,6 @@ class HermesCLI:
|
||||
self.agent.flush_memories(self.conversation_history)
|
||||
except (Exception, KeyboardInterrupt):
|
||||
pass
|
||||
# Trigger memory extraction on the old session before session_id rotates.
|
||||
self.agent.commit_memory_session(self.conversation_history)
|
||||
self._notify_session_boundary("on_session_finalize")
|
||||
elif self.agent:
|
||||
# First session or empty history — still finalize the old session
|
||||
|
||||
@@ -10,7 +10,6 @@ runs at a time if multiple processes overlap.
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -771,11 +770,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
_cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None
|
||||
_POLL_INTERVAL = 5.0
|
||||
_cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
# Preserve scheduler-scoped ContextVar state (for example skill-declared
|
||||
# env passthrough registrations) when the cron run hops into the worker
|
||||
# thread used for inactivity timeout monitoring.
|
||||
_cron_context = contextvars.copy_context()
|
||||
_cron_future = _cron_pool.submit(_cron_context.run, agent.run_conversation, prompt)
|
||||
_cron_future = _cron_pool.submit(agent.run_conversation, prompt)
|
||||
_inactivity_timeout = False
|
||||
try:
|
||||
if _cron_inactivity_limit is None:
|
||||
|
||||
149
gateway/run.py
149
gateway/run.py
@@ -482,27 +482,6 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_session_key(session_key: str) -> "dict | None":
|
||||
"""Parse a session key into its component parts.
|
||||
|
||||
Session keys follow the format
|
||||
``agent:main:{platform}:{chat_type}:{chat_id}[:{thread_id}[:{user_id}]]``.
|
||||
Returns a dict with ``platform``, ``chat_type``, ``chat_id``, and
|
||||
optionally ``thread_id`` keys, or None if the key doesn't match.
|
||||
"""
|
||||
parts = session_key.split(":")
|
||||
if len(parts) >= 5 and parts[0] == "agent" and parts[1] == "main":
|
||||
result = {
|
||||
"platform": parts[2],
|
||||
"chat_type": parts[3],
|
||||
"chat_id": parts[4],
|
||||
}
|
||||
if len(parts) > 5:
|
||||
result["thread_id"] = parts[5]
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _format_gateway_process_notification(evt: dict) -> "str | None":
|
||||
"""Format a watch pattern event from completion_queue into a [SYSTEM:] message."""
|
||||
evt_type = evt.get("type", "completion")
|
||||
@@ -1510,11 +1489,12 @@ class GatewayRunner:
|
||||
notified: set = set()
|
||||
for session_key in active:
|
||||
# Parse platform + chat_id from the session key.
|
||||
_parsed = _parse_session_key(session_key)
|
||||
if not _parsed:
|
||||
# Format: agent:main:{platform}:{chat_type}:{chat_id}[:{extra}...]
|
||||
parts = session_key.split(":")
|
||||
if len(parts) < 5:
|
||||
continue
|
||||
platform_str = _parsed["platform"]
|
||||
chat_id = _parsed["chat_id"]
|
||||
platform_str = parts[2]
|
||||
chat_id = parts[4]
|
||||
|
||||
# Deduplicate: one notification per chat, even if multiple
|
||||
# sessions (different users/threads) share the same chat.
|
||||
@@ -1530,7 +1510,7 @@ class GatewayRunner:
|
||||
|
||||
# Include thread_id if present so the message lands in the
|
||||
# correct forum topic / thread.
|
||||
thread_id = _parsed.get("thread_id")
|
||||
thread_id = parts[5] if len(parts) > 5 else None
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
await adapter.send(chat_id, msg, metadata=metadata)
|
||||
@@ -3978,7 +3958,7 @@ class GatewayRunner:
|
||||
synth_text = _format_gateway_process_notification(evt)
|
||||
if synth_text:
|
||||
try:
|
||||
await self._inject_watch_notification(synth_text, evt)
|
||||
await self._inject_watch_notification(synth_text, event)
|
||||
except Exception as e2:
|
||||
logger.error("Watch notification injection error: %s", e2)
|
||||
except Exception as e:
|
||||
@@ -7472,75 +7452,14 @@ class GatewayRunner:
|
||||
return prefix
|
||||
return user_text
|
||||
|
||||
def _build_process_event_source(self, evt: dict):
|
||||
"""Resolve the canonical source for a synthetic background-process event.
|
||||
|
||||
Prefer the persisted session-store origin for the event's session key.
|
||||
Falling back to the currently active foreground event is what causes
|
||||
cross-topic bleed, so don't do that.
|
||||
"""
|
||||
from gateway.session import SessionSource
|
||||
|
||||
session_key = str(evt.get("session_key") or "").strip()
|
||||
derived_platform = ""
|
||||
derived_chat_type = ""
|
||||
derived_chat_id = ""
|
||||
|
||||
if session_key:
|
||||
try:
|
||||
self.session_store._ensure_loaded()
|
||||
entry = self.session_store._entries.get(session_key)
|
||||
if entry and getattr(entry, "origin", None):
|
||||
return entry.origin
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Synthetic process-event session-store lookup failed for %s: %s",
|
||||
session_key,
|
||||
exc,
|
||||
)
|
||||
|
||||
_parsed = _parse_session_key(session_key)
|
||||
if _parsed:
|
||||
derived_platform = _parsed["platform"]
|
||||
derived_chat_type = _parsed["chat_type"]
|
||||
derived_chat_id = _parsed["chat_id"]
|
||||
|
||||
platform_name = str(evt.get("platform") or derived_platform or "").strip().lower()
|
||||
chat_type = str(evt.get("chat_type") or derived_chat_type or "").strip().lower()
|
||||
chat_id = str(evt.get("chat_id") or derived_chat_id or "").strip()
|
||||
if not platform_name or not chat_type or not chat_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Synthetic process event has invalid platform metadata: %r",
|
||||
platform_name,
|
||||
)
|
||||
return None
|
||||
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
chat_type=chat_type,
|
||||
thread_id=str(evt.get("thread_id") or "").strip() or None,
|
||||
user_id=str(evt.get("user_id") or "").strip() or None,
|
||||
user_name=str(evt.get("user_name") or "").strip() or None,
|
||||
)
|
||||
|
||||
async def _inject_watch_notification(self, synth_text: str, evt: dict) -> None:
|
||||
async def _inject_watch_notification(self, synth_text: str, original_event) -> None:
|
||||
"""Inject a watch-pattern notification as a synthetic message event.
|
||||
|
||||
Routing must come from the queued watch event itself, not from whatever
|
||||
foreground message happened to be active when the queue was drained.
|
||||
Uses the source from the original user event to route the notification
|
||||
back to the correct chat/adapter.
|
||||
"""
|
||||
source = self._build_process_event_source(evt)
|
||||
source = getattr(original_event, "source", None)
|
||||
if not source:
|
||||
logger.warning(
|
||||
"Dropping watch notification with no routing metadata for process %s",
|
||||
evt.get("session_id", "unknown"),
|
||||
)
|
||||
return
|
||||
platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
||||
adapter = None
|
||||
@@ -7558,12 +7477,7 @@ class GatewayRunner:
|
||||
source=source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info(
|
||||
"Watch pattern notification — injecting for %s chat=%s thread=%s",
|
||||
platform_name,
|
||||
source.chat_id,
|
||||
source.thread_id,
|
||||
)
|
||||
logger.info("Watch pattern notification — injecting for %s", platform_name)
|
||||
await adapter.handle_message(synth_event)
|
||||
except Exception as e:
|
||||
logger.error("Watch notification injection error: %s", e)
|
||||
@@ -7633,42 +7547,33 @@ class GatewayRunner:
|
||||
f"Command: {session.command}\n"
|
||||
f"Output:\n{_out}]"
|
||||
)
|
||||
source = self._build_process_event_source({
|
||||
"session_id": session_id,
|
||||
"session_key": session_key,
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": thread_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
})
|
||||
if not source:
|
||||
logger.warning(
|
||||
"Dropping completion notification with no routing metadata for process %s",
|
||||
session_id,
|
||||
)
|
||||
break
|
||||
|
||||
adapter = None
|
||||
for p, a in self.adapters.items():
|
||||
if p == source.platform:
|
||||
if p.value == platform_name:
|
||||
adapter = a
|
||||
break
|
||||
if adapter and source.chat_id:
|
||||
if adapter and chat_id:
|
||||
try:
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
from gateway.config import Platform
|
||||
_platform_enum = Platform(platform_name)
|
||||
_source = SessionSource(
|
||||
platform=_platform_enum,
|
||||
chat_id=chat_id,
|
||||
thread_id=thread_id or None,
|
||||
user_id=user_id or None,
|
||||
user_name=user_name or None,
|
||||
)
|
||||
synth_event = MessageEvent(
|
||||
text=synth_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
source=_source,
|
||||
internal=True,
|
||||
)
|
||||
logger.info(
|
||||
"Process %s finished — injecting agent notification for session %s chat=%s thread=%s",
|
||||
session_id,
|
||||
session_key,
|
||||
source.chat_id,
|
||||
source.thread_id,
|
||||
"Process %s finished — injecting agent notification for session %s",
|
||||
session_id, session_key,
|
||||
)
|
||||
await adapter.handle_message(synth_event)
|
||||
except Exception as e:
|
||||
|
||||
@@ -167,6 +167,7 @@ def _resolve_runtime_from_pool_entry(
|
||||
api_mode = "chat_completions"
|
||||
elif provider == "copilot":
|
||||
api_mode = _copilot_runtime_api_mode(model_cfg, getattr(entry, "runtime_api_key", ""))
|
||||
base_url = base_url or PROVIDER_REGISTRY["copilot"].inference_base_url
|
||||
else:
|
||||
configured_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
# Honour model.base_url from config.yaml when the configured provider
|
||||
|
||||
@@ -10,9 +10,8 @@ lifecycle instead of read-only search endpoints.
|
||||
Config via environment variables (profile-scoped via each profile's .env):
|
||||
OPENVIKING_ENDPOINT — Server URL (default: http://127.0.0.1:1933)
|
||||
OPENVIKING_API_KEY — API key (required for authenticated servers)
|
||||
OPENVIKING_ACCOUNT — Tenant account (default: default)
|
||||
OPENVIKING_ACCOUNT — Tenant account (default: root)
|
||||
OPENVIKING_USER — Tenant user (default: default)
|
||||
OPENVIKING_AGENT — Tenant agent (default: hermes)
|
||||
|
||||
Capabilities:
|
||||
- Automatic memory extraction on session commit (6 categories)
|
||||
@@ -81,12 +80,11 @@ class _VikingClient:
|
||||
"""Thin HTTP client for the OpenViking REST API."""
|
||||
|
||||
def __init__(self, endpoint: str, api_key: str = "",
|
||||
account: str = "", user: str = "", agent: str = ""):
|
||||
account: str = "", user: str = ""):
|
||||
self._endpoint = endpoint.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._account = account or os.environ.get("OPENVIKING_ACCOUNT", "default")
|
||||
self._account = account or os.environ.get("OPENVIKING_ACCOUNT", "root")
|
||||
self._user = user or os.environ.get("OPENVIKING_USER", "default")
|
||||
self._agent = agent or os.environ.get("OPENVIKING_AGENT", "hermes")
|
||||
self._httpx = _get_httpx()
|
||||
if self._httpx is None:
|
||||
raise ImportError("httpx is required for OpenViking: pip install httpx")
|
||||
@@ -96,7 +94,6 @@ class _VikingClient:
|
||||
"Content-Type": "application/json",
|
||||
"X-OpenViking-Account": self._account,
|
||||
"X-OpenViking-User": self._user,
|
||||
"X-OpenViking-Agent": self._agent,
|
||||
}
|
||||
if self._api_key:
|
||||
h["X-API-Key"] = self._api_key
|
||||
@@ -285,44 +282,20 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
},
|
||||
{
|
||||
"key": "api_key",
|
||||
"description": "OpenViking API key (leave blank for local dev mode)",
|
||||
"description": "OpenViking API key",
|
||||
"secret": True,
|
||||
"env_var": "OPENVIKING_API_KEY",
|
||||
},
|
||||
{
|
||||
"key": "account",
|
||||
"description": "OpenViking tenant account ID ([default], used when local mode, OPENVIKING_API_KEY is empty)",
|
||||
"default": "default",
|
||||
"env_var": "OPENVIKING_ACCOUNT",
|
||||
},
|
||||
{
|
||||
"key": "user",
|
||||
"description": "OpenViking user ID within the account ([default], used when local mode, OPENVIKING_API_KEY is empty)",
|
||||
"default": "default",
|
||||
"env_var": "OPENVIKING_USER",
|
||||
},
|
||||
{
|
||||
"key": "agent",
|
||||
"description": "OpenViking agent ID within the account ([hermes], useful in multi-agent mode)",
|
||||
"default": "hermes",
|
||||
"env_var": "OPENVIKING_AGENT",
|
||||
},
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._endpoint = os.environ.get("OPENVIKING_ENDPOINT", _DEFAULT_ENDPOINT)
|
||||
self._api_key = os.environ.get("OPENVIKING_API_KEY", "")
|
||||
self._account = os.environ.get("OPENVIKING_ACCOUNT", "default")
|
||||
self._user = os.environ.get("OPENVIKING_USER", "default")
|
||||
self._agent = os.environ.get("OPENVIKING_AGENT", "hermes")
|
||||
self._session_id = session_id
|
||||
self._turn_count = 0
|
||||
|
||||
try:
|
||||
self._client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
self._client = _VikingClient(self._endpoint, self._api_key)
|
||||
if not self._client.health():
|
||||
logger.warning("OpenViking server at %s is not reachable", self._endpoint)
|
||||
self._client = None
|
||||
@@ -352,8 +325,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
"(abstract/overview/full), viking_browse to explore.\n"
|
||||
"Use viking_remember to store facts, viking_add_resource to index URLs/docs."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("OpenViking system_prompt_block failed: %s", e)
|
||||
except Exception:
|
||||
return (
|
||||
"# OpenViking Knowledge Base\n"
|
||||
f"Active. Endpoint: {self._endpoint}\n"
|
||||
@@ -379,10 +351,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
resp = client.post("/api/v1/search/find", {
|
||||
"query": query,
|
||||
"top_k": 5,
|
||||
@@ -417,10 +386,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
sid = self._session_id
|
||||
|
||||
# Add user message
|
||||
@@ -476,10 +442,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
||||
|
||||
def _write():
|
||||
try:
|
||||
client = _VikingClient(
|
||||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
client = _VikingClient(self._endpoint, self._api_key)
|
||||
# Add as a user message with memory context so the commit
|
||||
# picks it up as an explicit memory during extraction
|
||||
client.post(f"/api/v1/sessions/{self._session_id}/messages", {
|
||||
|
||||
144
run_agent.py
144
run_agent.py
@@ -754,7 +754,6 @@ class AIAgent:
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._interrupt_thread_signal_pending = False
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
@@ -2950,15 +2949,7 @@ class AIAgent:
|
||||
# Signal all tools to abort any in-flight operations immediately.
|
||||
# Scope the interrupt to this agent's execution thread so other
|
||||
# agents running in the same process (gateway) are not affected.
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
self._interrupt_thread_signal_pending = False
|
||||
else:
|
||||
# The interrupt arrived before run_conversation() finished
|
||||
# binding the agent to its execution thread. Defer the tool-level
|
||||
# interrupt signal until startup completes instead of targeting
|
||||
# the caller thread by mistake.
|
||||
self._interrupt_thread_signal_pending = True
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
@@ -2974,9 +2965,7 @@ class AIAgent:
|
||||
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
self._interrupt_thread_signal_pending = False
|
||||
if self._execution_thread_id is not None:
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
@@ -3051,18 +3040,6 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def commit_memory_session(self, messages: list = None) -> None:
|
||||
"""Trigger end-of-session extraction without tearing providers down.
|
||||
Called when session_id rotates (e.g. /new, context compression);
|
||||
providers keep their state and continue running under the old
|
||||
session_id — they just flush pending extraction now."""
|
||||
if not self._memory_manager:
|
||||
return
|
||||
try:
|
||||
self._memory_manager.on_session_end(messages or [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release all resources held by this agent instance.
|
||||
|
||||
@@ -5533,27 +5510,9 @@ class AIAgent:
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
_last_heartbeat = time.time()
|
||||
_HEARTBEAT_INTERVAL = 30.0 # seconds between gateway activity touches
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
|
||||
# Periodic heartbeat: touch the agent's activity tracker so the
|
||||
# gateway's inactivity monitor knows we're alive while waiting
|
||||
# for stream chunks. Without this, long thinking pauses (e.g.
|
||||
# reasoning models) or slow prefill on local providers (Ollama)
|
||||
# trigger false inactivity timeouts. The _call thread touches
|
||||
# activity on each chunk, but the gap between API call start
|
||||
# and first chunk can exceed the gateway timeout — especially
|
||||
# when the stale-stream timeout is disabled (local providers).
|
||||
_hb_now = time.time()
|
||||
if _hb_now - _last_heartbeat >= _HEARTBEAT_INTERVAL:
|
||||
_last_heartbeat = _hb_now
|
||||
_waiting_secs = int(_hb_now - last_chunk_time["t"])
|
||||
self._touch_activity(
|
||||
f"waiting for stream response ({_waiting_secs}s, no chunks yet)"
|
||||
)
|
||||
|
||||
# Detect stale streams: connections kept alive by SSE pings
|
||||
# but delivering no real chunks. Kill the client so the
|
||||
# inner retry loop can start a fresh connection.
|
||||
@@ -6867,8 +6826,6 @@ class AIAgent:
|
||||
try:
|
||||
# Propagate title to the new session with auto-numbering
|
||||
old_title = self._session_db.get_session_title(self.session_id)
|
||||
# Trigger memory extraction on the old session before it rotates.
|
||||
self.commit_memory_session(messages)
|
||||
self._session_db.end_session(self.session_id, "compression")
|
||||
old_session_id = self.session_id
|
||||
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
@@ -7170,22 +7127,8 @@ class AIAgent:
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
|
||||
results = [None] * num_tools
|
||||
|
||||
# Touch activity before launching workers so the gateway knows
|
||||
# we're executing tools (not stuck).
|
||||
self._current_tool = tool_names_str
|
||||
self._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}")
|
||||
|
||||
def _run_tool(index, tool_call, function_name, function_args):
|
||||
"""Worker function executed in a thread."""
|
||||
# Set the activity callback on THIS worker thread so
|
||||
# _wait_for_process (terminal commands) can fire heartbeats.
|
||||
# The callback is thread-local; the main thread's callback
|
||||
# is invisible to worker threads.
|
||||
try:
|
||||
from tools.environments.base import set_activity_callback
|
||||
set_activity_callback(self._touch_activity)
|
||||
except Exception:
|
||||
pass
|
||||
start = time.time()
|
||||
try:
|
||||
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id)
|
||||
@@ -7215,26 +7158,8 @@ class AIAgent:
|
||||
f = executor.submit(_run_tool, i, tc, name, args)
|
||||
futures.append(f)
|
||||
|
||||
# Wait for all to complete with periodic heartbeats so the
|
||||
# gateway's inactivity monitor doesn't kill us during long
|
||||
# concurrent tool batches.
|
||||
_conc_start = time.time()
|
||||
while True:
|
||||
done, not_done = concurrent.futures.wait(
|
||||
futures, timeout=30.0,
|
||||
)
|
||||
if not not_done:
|
||||
break
|
||||
_conc_elapsed = int(time.time() - _conc_start)
|
||||
_still_running = [
|
||||
parsed_calls[futures.index(f)][1]
|
||||
for f in not_done
|
||||
if f in futures
|
||||
]
|
||||
self._touch_activity(
|
||||
f"concurrent tools running ({_conc_elapsed}s, "
|
||||
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
||||
)
|
||||
# Wait for all to complete (exceptions are captured inside _run_tool)
|
||||
concurrent.futures.wait(futures)
|
||||
finally:
|
||||
if spinner:
|
||||
# Build a summary message for the spinner stop
|
||||
@@ -7466,16 +7391,6 @@ class AIAgent:
|
||||
old_text=function_args.get("old_text"),
|
||||
store=self._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if self._memory_manager and function_args.get("action") in ("add", "replace"):
|
||||
try:
|
||||
self._memory_manager.on_memory_write(
|
||||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
@@ -7952,7 +7867,6 @@ class AIAgent:
|
||||
self._thinking_prefill_retries = 0
|
||||
self._post_tool_empty_retried = False
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._mute_post_response = False
|
||||
self._unicode_sanitization_passes = 0
|
||||
|
||||
@@ -8140,7 +8054,6 @@ class AIAgent:
|
||||
self._empty_content_retries = 0
|
||||
self._thinking_prefill_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._mute_post_response = False
|
||||
# Re-estimate after compression
|
||||
_preflight_tokens = estimate_request_tokens_rough(
|
||||
@@ -8200,19 +8113,11 @@ class AIAgent:
|
||||
|
||||
# Record the execution thread so interrupt()/clear_interrupt() can
|
||||
# scope the tool-level interrupt signal to THIS agent's thread only.
|
||||
# Must be set before any thread-scoped interrupt syncing.
|
||||
# Must be set before clear_interrupt() which uses it.
|
||||
self._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Always clear stale per-thread state from a previous turn. If an
|
||||
# interrupt arrived before startup finished, preserve it and bind it
|
||||
# to this execution thread now instead of dropping it on the floor.
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
if self._interrupt_requested:
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
self._interrupt_thread_signal_pending = False
|
||||
else:
|
||||
self._interrupt_message = None
|
||||
self._interrupt_thread_signal_pending = False
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
||||
# External memory provider: prefetch once before the tool loop.
|
||||
# Reuse the cached result on every iteration to avoid re-calling
|
||||
@@ -10211,7 +10116,6 @@ class AIAgent:
|
||||
tc.function.name in _HOUSEKEEPING_TOOLS
|
||||
for tc in assistant_message.tool_calls
|
||||
)
|
||||
self._last_content_tools_all_housekeeping = _all_housekeeping
|
||||
if _all_housekeeping and self._has_stream_consumers():
|
||||
self._mute_post_response = True
|
||||
elif self.quiet_mode:
|
||||
@@ -10394,22 +10298,15 @@ class AIAgent:
|
||||
break
|
||||
|
||||
# If the previous turn already delivered real content alongside
|
||||
# HOUSEKEEPING tool calls (e.g. "You're welcome!" + memory save),
|
||||
# the model has nothing more to say. Use the earlier content
|
||||
# immediately instead of wasting API calls on retries.
|
||||
# NOTE: Only use this shortcut when ALL tools in that turn were
|
||||
# housekeeping (memory, todo, etc.). When substantive tools
|
||||
# were called (terminal, search_files, etc.), the content was
|
||||
# likely mid-task narration ("I'll scan the directory...") and
|
||||
# the empty follow-up means the model choked — let the
|
||||
# post-tool nudge below handle that instead of exiting early.
|
||||
# tool calls (e.g. "You're welcome!" + memory save), the model
|
||||
# has nothing more to say. Use the earlier content immediately
|
||||
# instead of wasting API calls on retries that won't help.
|
||||
fallback = getattr(self, '_last_content_with_tools', None)
|
||||
if fallback and getattr(self, '_last_content_tools_all_housekeeping', False):
|
||||
if fallback:
|
||||
_turn_exit_reason = "fallback_prior_turn_content"
|
||||
logger.info("Empty follow-up after tool calls — using prior turn content as final response")
|
||||
self._emit_status("↻ Empty response after tool calls — using earlier content as final answer")
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
self._empty_content_retries = 0
|
||||
# Do NOT modify the assistant message content — the
|
||||
# old code injected "Calling the X tools..." which
|
||||
@@ -10420,18 +10317,13 @@ class AIAgent:
|
||||
break
|
||||
|
||||
# ── Post-tool-call empty response nudge ───────────
|
||||
# The model returned empty after executing tool calls.
|
||||
# This covers two cases:
|
||||
# (a) No prior-turn content at all — model went silent
|
||||
# (b) Prior turn had content + SUBSTANTIVE tools (the
|
||||
# fallback above was skipped because the content
|
||||
# was mid-task narration, not a final answer)
|
||||
# The model returned empty after executing tool calls
|
||||
# but there's no prior-turn content to fall back on.
|
||||
# Instead of giving up, nudge the model to continue by
|
||||
# appending a user-level hint. This is the #9400 case:
|
||||
# weaker models (mimo-v2-pro, GLM-5, etc.) sometimes
|
||||
# return empty after tool results instead of continuing
|
||||
# to the next step. One retry with a nudge usually
|
||||
# fixes it.
|
||||
# weaker models (GLM-5, etc.) sometimes return empty
|
||||
# after tool results instead of continuing to the next
|
||||
# step. One retry with a nudge usually fixes it.
|
||||
_prior_was_tool = any(
|
||||
m.get("role") == "tool"
|
||||
for m in messages[-5:] # check recent messages
|
||||
@@ -10441,10 +10333,6 @@ class AIAgent:
|
||||
and not getattr(self, "_post_tool_empty_retried", False)
|
||||
):
|
||||
self._post_tool_empty_retried = True
|
||||
# Clear stale narration so it doesn't resurface
|
||||
# on a later empty response after the nudge.
|
||||
self._last_content_with_tools = None
|
||||
self._last_content_tools_all_housekeeping = False
|
||||
logger.info(
|
||||
"Empty response after tool calls — nudging model "
|
||||
"to continue processing"
|
||||
|
||||
@@ -197,8 +197,6 @@ AUTHOR_MAP = {
|
||||
"zhouboli@gmail.com": "zhouboli",
|
||||
"zqiao@microsoft.com": "tomqiaozc",
|
||||
"zzn+pa@zzn.im": "xinbenlv",
|
||||
"zaynjarvis@gmail.com": "ZaynJarvis",
|
||||
"zhiheng.liu@bytedance.com": "ZaynJarvis",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ class TestResolveVisionProviderClientModelNormalization:
|
||||
|
||||
assert provider == "zai"
|
||||
assert client is not None
|
||||
assert model == "glm-5.1"
|
||||
assert model == "glm-5v-turbo" # zai has dedicated vision model in _PROVIDER_VISION_MODELS
|
||||
|
||||
|
||||
class TestVisionPathApiMode:
|
||||
|
||||
@@ -695,102 +695,3 @@ class TestMemoryContextFencing:
|
||||
fence_end = combined.index("</memory-context>")
|
||||
assert "Alice" in combined[fence_start:fence_end]
|
||||
assert combined.index("weather") < fence_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIAgent.commit_memory_session — routes to MemoryManager.on_session_end
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _CommitRecorder(FakeMemoryProvider):
|
||||
"""Provider that records on_session_end calls for assertions."""
|
||||
|
||||
def __init__(self, name="recorder"):
|
||||
super().__init__(name)
|
||||
self.end_calls = []
|
||||
|
||||
def on_session_end(self, messages):
|
||||
self.end_calls.append(list(messages or []))
|
||||
|
||||
|
||||
class TestCommitMemorySessionRouting:
|
||||
def test_on_session_end_fans_out(self):
|
||||
mgr = MemoryManager()
|
||||
builtin = _CommitRecorder("builtin")
|
||||
external = _CommitRecorder("openviking")
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(external)
|
||||
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
mgr.on_session_end(msgs)
|
||||
|
||||
assert builtin.end_calls == [msgs]
|
||||
assert external.end_calls == [msgs]
|
||||
|
||||
def test_on_session_end_tolerates_failure(self):
|
||||
mgr = MemoryManager()
|
||||
builtin = FakeMemoryProvider("builtin")
|
||||
bad = _CommitRecorder("bad-provider")
|
||||
bad.on_session_end = lambda m: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(bad)
|
||||
|
||||
mgr.on_session_end([]) # must not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_memory_write bridge — must fire from both concurrent AND sequential paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnMemoryWriteBridge:
|
||||
"""Verify that MemoryManager.on_memory_write is called when built-in
|
||||
memory writes happen. This is a regression test for #10174 where the
|
||||
sequential tool execution path (_execute_tool_calls_sequential) was
|
||||
missing the bridge call, so single memory tool calls never notified
|
||||
external memory providers.
|
||||
"""
|
||||
|
||||
def test_on_memory_write_add(self):
|
||||
"""on_memory_write fires for 'add' actions."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write("add", "memory", "new fact")
|
||||
assert p.memory_writes == [("add", "memory", "new fact")]
|
||||
|
||||
def test_on_memory_write_replace(self):
|
||||
"""on_memory_write fires for 'replace' actions."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write("replace", "user", "updated pref")
|
||||
assert p.memory_writes == [("replace", "user", "updated pref")]
|
||||
|
||||
def test_on_memory_write_remove_not_bridged(self):
|
||||
"""The bridge intentionally skips 'remove' — only add/replace notify."""
|
||||
# This tests the contract that run_agent.py checks:
|
||||
# function_args.get("action") in ("add", "replace")
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
# Manager itself doesn't filter — run_agent.py does.
|
||||
# But providers should handle remove gracefully.
|
||||
mgr.on_memory_write("remove", "memory", "old fact")
|
||||
assert p.memory_writes == [("remove", "memory", "old fact")]
|
||||
|
||||
def test_on_memory_write_tolerates_provider_failure(self):
|
||||
"""If a provider's on_memory_write raises, others still get notified."""
|
||||
mgr = MemoryManager()
|
||||
bad = FakeMemoryProvider("builtin")
|
||||
bad.on_memory_write = MagicMock(side_effect=RuntimeError("boom"))
|
||||
good = FakeMemoryProvider("good")
|
||||
mgr.add_provider(bad)
|
||||
mgr.add_provider(good)
|
||||
|
||||
mgr.on_memory_write("add", "user", "test")
|
||||
# Good provider still received the call despite bad provider crashing
|
||||
assert good.memory_writes == [("add", "user", "test")]
|
||||
|
||||
@@ -8,8 +8,6 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
|
||||
from tools.env_passthrough import clear_env_passthrough
|
||||
from tools.credential_files import clear_credential_files
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
@@ -879,117 +877,6 @@ class TestRunJobPerJobOverrides:
|
||||
|
||||
|
||||
class TestRunJobSkillBacked:
|
||||
def test_run_job_preserves_skill_env_passthrough_into_worker_thread(self, tmp_path):
|
||||
job = {
|
||||
"id": "skill-env-job",
|
||||
"name": "skill env test",
|
||||
"prompt": "Use the skill.",
|
||||
"skill": "notion",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
def _skill_view(name):
|
||||
assert name == "notion"
|
||||
from tools.env_passthrough import register_env_passthrough
|
||||
|
||||
register_env_passthrough(["NOTION_API_KEY"])
|
||||
return json.dumps({"success": True, "content": "# notion\nUse Notion."})
|
||||
|
||||
def _run_conversation(prompt):
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
|
||||
assert "NOTION_API_KEY" in get_all_passthrough()
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", side_effect=_skill_view), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.side_effect = _run_conversation
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
try:
|
||||
success, output, final_response, error = run_job(job)
|
||||
finally:
|
||||
clear_env_passthrough()
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
|
||||
def test_run_job_preserves_credential_file_passthrough_into_worker_thread(self, tmp_path):
|
||||
"""copy_context() also propagates credential_files ContextVar."""
|
||||
job = {
|
||||
"id": "cred-env-job",
|
||||
"name": "cred file test",
|
||||
"prompt": "Use the skill.",
|
||||
"skill": "google-workspace",
|
||||
}
|
||||
|
||||
fake_db = MagicMock()
|
||||
|
||||
# Create a credential file so register_credential_file succeeds
|
||||
cred_dir = tmp_path / "credentials"
|
||||
cred_dir.mkdir()
|
||||
(cred_dir / "google_token.json").write_text('{"token": "t"}')
|
||||
|
||||
def _skill_view(name):
|
||||
assert name == "google-workspace"
|
||||
from tools.credential_files import register_credential_file
|
||||
|
||||
register_credential_file("credentials/google_token.json")
|
||||
return json.dumps({"success": True, "content": "# google-workspace\nUse Google."})
|
||||
|
||||
def _run_conversation(prompt):
|
||||
from tools.credential_files import _get_registered
|
||||
|
||||
registered = _get_registered()
|
||||
assert registered, "credential files must be visible in worker thread"
|
||||
assert any("google_token.json" in v for v in registered.values())
|
||||
return {"final_response": "ok"}
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("tools.credential_files._resolve_hermes_home", return_value=tmp_path), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "***",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("tools.skills_tool.skill_view", side_effect=_skill_view), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.run_conversation.side_effect = _run_conversation
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
try:
|
||||
success, output, final_response, error = run_job(job)
|
||||
finally:
|
||||
clear_credential_files()
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert final_response == "ok"
|
||||
|
||||
def test_run_job_loads_skill_and_disables_recursive_cron_tools(self, tmp_path):
|
||||
job = {
|
||||
"id": "skill-job",
|
||||
|
||||
66
tests/gateway/conftest.py
Normal file
66
tests/gateway/conftest.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Shared fixtures for gateway tests.
|
||||
|
||||
The ``_ensure_telegram_mock`` helper guarantees that a minimal mock of
|
||||
the ``telegram`` package is registered in :data:`sys.modules` **before**
|
||||
any test file triggers ``from gateway.platforms.telegram import ...``.
|
||||
|
||||
Without this, ``pytest-xdist`` workers that happen to collect
|
||||
``test_telegram_caption_merge.py`` (bare top-level import, no per-file
|
||||
mock) first will cache ``ChatType = None`` from the production
|
||||
ImportError fallback, causing 30+ downstream test failures wherever
|
||||
``ChatType.GROUP`` / ``ChatType.SUPERGROUP`` is accessed.
|
||||
|
||||
Individual test files may still call their own ``_ensure_telegram_mock``
|
||||
— it short-circuits when the mock is already present.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _ensure_telegram_mock() -> None:
|
||||
"""Install a comprehensive telegram mock in sys.modules.
|
||||
|
||||
Idempotent — skips when the real library is already imported.
|
||||
Uses ``sys.modules[name] = mod`` (overwrite) instead of
|
||||
``setdefault`` so it wins even if a partial/broken import
|
||||
already cached a module with ``ChatType = None``.
|
||||
"""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return # Real library is installed — nothing to mock
|
||||
|
||||
mod = MagicMock()
|
||||
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
mod.constants.ParseMode.MARKDOWN = "Markdown"
|
||||
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
mod.constants.ParseMode.HTML = "HTML"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
mod.constants.ChatType.GROUP = "group"
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
|
||||
# Real exception classes so ``except (NetworkError, ...)`` clauses
|
||||
# in production code don't blow up with TypeError.
|
||||
mod.error.NetworkError = type("NetworkError", (OSError,), {})
|
||||
mod.error.TimedOut = type("TimedOut", (OSError,), {})
|
||||
mod.error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
mod.error.Forbidden = type("Forbidden", (Exception,), {})
|
||||
mod.error.InvalidToken = type("InvalidToken", (Exception,), {})
|
||||
mod.error.RetryAfter = type("RetryAfter", (Exception,), {"retry_after": 1})
|
||||
mod.error.Conflict = type("Conflict", (Exception,), {})
|
||||
|
||||
# Update.ALL_TYPES used in start_polling()
|
||||
mod.Update.ALL_TYPES = []
|
||||
|
||||
for name in (
|
||||
"telegram",
|
||||
"telegram.ext",
|
||||
"telegram.constants",
|
||||
"telegram.request",
|
||||
):
|
||||
sys.modules[name] = mod
|
||||
sys.modules["telegram.error"] = mod.error
|
||||
|
||||
|
||||
# Run at collection time — before any test file's module-level imports.
|
||||
_ensure_telegram_mock()
|
||||
@@ -14,7 +14,7 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.run import GatewayRunner, _parse_session_key
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -45,7 +45,7 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
return runner
|
||||
|
||||
@@ -243,162 +243,3 @@ async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_watch_notification_routes_from_session_store_origin(monkeypatch, tmp_path):
|
||||
from gateway.session import SessionSource
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="123",
|
||||
user_name="Emiliyan",
|
||||
)
|
||||
)
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_watch",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
}
|
||||
|
||||
await runner._inject_watch_notification("[SYSTEM: Background process matched]", evt)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
synth_event = adapter.handle_message.await_args.args[0]
|
||||
assert synth_event.internal is True
|
||||
assert synth_event.source.platform == Platform.TELEGRAM
|
||||
assert synth_event.source.chat_id == "-100"
|
||||
assert synth_event.source.chat_type == "group"
|
||||
assert synth_event.source.thread_id == "42"
|
||||
assert synth_event.source.user_id == "123"
|
||||
assert synth_event.source.user_name == "Emiliyan"
|
||||
|
||||
|
||||
def test_build_process_event_source_falls_back_to_session_key_chat_type(monkeypatch, tmp_path):
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_watch",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
"platform": "telegram",
|
||||
"chat_id": "-100",
|
||||
"thread_id": "42",
|
||||
"user_id": "123",
|
||||
"user_name": "Emiliyan",
|
||||
}
|
||||
|
||||
source = runner._build_process_event_source(evt)
|
||||
|
||||
assert source is not None
|
||||
assert source.platform == Platform.TELEGRAM
|
||||
assert source.chat_id == "-100"
|
||||
assert source.chat_type == "group"
|
||||
assert source.thread_id == "42"
|
||||
assert source.user_id == "123"
|
||||
assert source.user_name == "Emiliyan"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_watch_notification_ignores_foreground_event_source(monkeypatch, tmp_path):
|
||||
"""Negative test: watch notification must NOT route to the foreground thread."""
|
||||
from gateway.session import SessionSource
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
# Session store has the process's original thread (thread 42)
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="proc_owner",
|
||||
user_name="alice",
|
||||
)
|
||||
)
|
||||
|
||||
# The evt dict carries the correct session_key — NOT a foreground event
|
||||
evt = {
|
||||
"session_id": "proc_cross_thread",
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
}
|
||||
|
||||
await runner._inject_watch_notification("[SYSTEM: watch match]", evt)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
synth_event = adapter.handle_message.await_args.args[0]
|
||||
# Must route to thread 42 (process origin), NOT some other thread
|
||||
assert synth_event.source.thread_id == "42"
|
||||
assert synth_event.source.user_id == "proc_owner"
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_empty_evt(monkeypatch, tmp_path):
|
||||
"""Missing session_key and no platform metadata → None (drop notification)."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
source = runner._build_process_event_source({"session_id": "proc_orphan"})
|
||||
assert source is None
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_invalid_platform(monkeypatch, tmp_path):
|
||||
"""Invalid platform string → None."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_bad",
|
||||
"platform": "not_a_real_platform",
|
||||
"chat_type": "dm",
|
||||
"chat_id": "123",
|
||||
}
|
||||
source = runner._build_process_event_source(evt)
|
||||
assert source is None
|
||||
|
||||
|
||||
def test_build_process_event_source_returns_none_for_short_session_key(monkeypatch, tmp_path):
|
||||
"""Session key with <5 parts doesn't parse, falls through to empty metadata → None."""
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
|
||||
evt = {
|
||||
"session_id": "proc_short",
|
||||
"session_key": "agent:main:telegram", # Too few parts
|
||||
}
|
||||
source = runner._build_process_event_source(evt)
|
||||
assert source is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_session_key helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_parse_session_key_valid():
|
||||
result = _parse_session_key("agent:main:telegram:group:-100")
|
||||
assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "-100"}
|
||||
|
||||
|
||||
def test_parse_session_key_with_extra_parts():
|
||||
"""Thread ID (6th part) is extracted; further parts are ignored."""
|
||||
result = _parse_session_key("agent:main:discord:group:chan123:thread456")
|
||||
assert result == {"platform": "discord", "chat_type": "group", "chat_id": "chan123", "thread_id": "thread456"}
|
||||
|
||||
|
||||
def test_parse_session_key_with_user_id_part():
|
||||
"""7th part (user_id) is ignored — only up to thread_id is extracted."""
|
||||
result = _parse_session_key("agent:main:telegram:group:chat1:thread42:user99")
|
||||
assert result == {"platform": "telegram", "chat_type": "group", "chat_id": "chat1", "thread_id": "thread42"}
|
||||
|
||||
|
||||
def test_parse_session_key_too_short():
|
||||
assert _parse_session_key("agent:main:telegram") is None
|
||||
assert _parse_session_key("") is None
|
||||
|
||||
|
||||
def test_parse_session_key_wrong_prefix():
|
||||
assert _parse_session_key("cron:main:telegram:dm:123") is None
|
||||
assert _parse_session_key("agent:cron:telegram:dm:123") is None
|
||||
|
||||
@@ -230,59 +230,6 @@ async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path)
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_uses_session_store_origin_for_group_topic(monkeypatch, tmp_path):
|
||||
import tools.process_registry as pr_module
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock(), handle_message=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
runner.session_store._entries["agent:main:telegram:group:-100:42"] = SimpleNamespace(
|
||||
origin=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
thread_id="42",
|
||||
user_id="user-42",
|
||||
user_name="alice",
|
||||
)
|
||||
)
|
||||
|
||||
watcher = {
|
||||
"session_id": "proc_test_internal",
|
||||
"check_interval": 0,
|
||||
"session_key": "agent:main:telegram:group:-100:42",
|
||||
"platform": "telegram",
|
||||
"chat_id": "-100",
|
||||
"thread_id": "42",
|
||||
"notify_on_complete": True,
|
||||
}
|
||||
|
||||
await runner._run_process_watcher(watcher)
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.internal is True
|
||||
assert event.source.platform == Platform.TELEGRAM
|
||||
assert event.source.chat_id == "-100"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.thread_id == "42"
|
||||
assert event.source.user_id == "user-42"
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
|
||||
"""A non-internal event with user_id=None should be silently dropped."""
|
||||
|
||||
@@ -613,6 +613,7 @@ class TestDetectVenvDir:
|
||||
# Not inside a virtualenv
|
||||
monkeypatch.setattr("sys.prefix", "/usr")
|
||||
monkeypatch.setattr("sys.base_prefix", "/usr")
|
||||
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
||||
monkeypatch.setattr(gateway_cli, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
dot_venv = tmp_path / ".venv"
|
||||
@@ -624,6 +625,7 @@ class TestDetectVenvDir:
|
||||
def test_falls_back_to_venv_directory(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.prefix", "/usr")
|
||||
monkeypatch.setattr("sys.base_prefix", "/usr")
|
||||
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
||||
monkeypatch.setattr(gateway_cli, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
venv = tmp_path / "venv"
|
||||
@@ -635,6 +637,7 @@ class TestDetectVenvDir:
|
||||
def test_prefers_dot_venv_over_venv(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.prefix", "/usr")
|
||||
monkeypatch.setattr("sys.base_prefix", "/usr")
|
||||
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
||||
monkeypatch.setattr(gateway_cli, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
(tmp_path / ".venv").mkdir()
|
||||
@@ -646,6 +649,7 @@ class TestDetectVenvDir:
|
||||
def test_returns_none_when_no_virtualenv(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.prefix", "/usr")
|
||||
monkeypatch.setattr("sys.base_prefix", "/usr")
|
||||
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
||||
monkeypatch.setattr(gateway_cli, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
result = gateway_cli._detect_venv_dir()
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestCleanupStaleAsyncClients:
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
|
||||
key = ("test_stale", True, "", "", "", ())
|
||||
key = ("test_stale", True, "", "", id(loop))
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
@@ -127,7 +127,7 @@ class TestCleanupStaleAsyncClients:
|
||||
loop = asyncio.new_event_loop() # NOT closed
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_live", True, "", "", "", ())
|
||||
key = ("test_live", True, "", "", id(loop))
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestCleanupStaleAsyncClients:
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_sync", False, "", "", "", ())
|
||||
key = ("test_sync", False, "", "", 0)
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", None)
|
||||
|
||||
@@ -160,131 +160,3 @@ class TestCleanupStaleAsyncClients:
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache bounded growth (#10200)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClientCacheBoundedGrowth:
|
||||
"""Verify the cache stays bounded when loops change (fix for #10200).
|
||||
|
||||
Previously, loop_id was part of the cache key, so every new event loop
|
||||
created a new entry for the same provider config. Now loop identity is
|
||||
validated at hit time and stale entries are replaced in-place.
|
||||
"""
|
||||
|
||||
def test_same_key_replaces_stale_loop_entry(self):
|
||||
"""When the loop changes, the old entry should be replaced, not duplicated."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
_get_cached_client,
|
||||
)
|
||||
|
||||
key = ("test_replace", True, "", "", "", ())
|
||||
|
||||
# Simulate a stale entry from a closed loop
|
||||
old_loop = asyncio.new_event_loop()
|
||||
old_loop.close()
|
||||
old_client = MagicMock()
|
||||
old_client._client = MagicMock()
|
||||
old_client._client.is_closed = False
|
||||
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (old_client, "old-model", old_loop)
|
||||
|
||||
try:
|
||||
# Now call _get_cached_client — should detect stale loop and evict
|
||||
with patch("agent.auxiliary_client.resolve_provider_client") as mock_resolve:
|
||||
mock_resolve.return_value = (MagicMock(), "new-model")
|
||||
client, model = _get_cached_client(
|
||||
"test_replace", async_mode=True,
|
||||
)
|
||||
# The old entry should have been replaced
|
||||
with _client_cache_lock:
|
||||
assert key in _client_cache, "Key should still exist (replaced)"
|
||||
entry = _client_cache[key]
|
||||
assert entry[1] == "new-model", "Should have the new model"
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_different_loops_do_not_grow_cache(self):
|
||||
"""Multiple event loops for the same provider should NOT create multiple entries."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
)
|
||||
|
||||
key = ("test_no_grow", True, "", "", "", ())
|
||||
|
||||
loops = []
|
||||
try:
|
||||
for i in range(5):
|
||||
loop = asyncio.new_event_loop()
|
||||
loops.append(loop)
|
||||
mock_client = MagicMock()
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
|
||||
# Close previous loop entries (simulating worker thread recycling)
|
||||
if i > 0:
|
||||
loops[i - 1].close()
|
||||
|
||||
with _client_cache_lock:
|
||||
# Simulate what _get_cached_client does: replace on loop mismatch
|
||||
if key in _client_cache:
|
||||
old_entry = _client_cache[key]
|
||||
del _client_cache[key]
|
||||
_client_cache[key] = (mock_client, f"model-{i}", loop)
|
||||
|
||||
# Only one entry should exist for this key
|
||||
with _client_cache_lock:
|
||||
count = sum(1 for k in _client_cache if k == key)
|
||||
assert count == 1, f"Expected 1 entry, got {count}"
|
||||
finally:
|
||||
for loop in loops:
|
||||
if not loop.is_closed():
|
||||
loop.close()
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_max_cache_size_eviction(self):
|
||||
"""Cache should not exceed _CLIENT_CACHE_MAX_SIZE."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
_CLIENT_CACHE_MAX_SIZE,
|
||||
)
|
||||
|
||||
# Save existing cache state
|
||||
with _client_cache_lock:
|
||||
saved = dict(_client_cache)
|
||||
_client_cache.clear()
|
||||
|
||||
try:
|
||||
# Fill to max + 5
|
||||
for i in range(_CLIENT_CACHE_MAX_SIZE + 5):
|
||||
mock_client = MagicMock()
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
key = (f"evict_test_{i}", False, "", "", "", ())
|
||||
with _client_cache_lock:
|
||||
# Inline the eviction logic (same as _get_cached_client)
|
||||
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
||||
evict_key = next(iter(_client_cache))
|
||||
del _client_cache[evict_key]
|
||||
_client_cache[key] = (mock_client, f"model-{i}", None)
|
||||
|
||||
with _client_cache_lock:
|
||||
assert len(_client_cache) <= _CLIENT_CACHE_MAX_SIZE, \
|
||||
f"Cache size {len(_client_cache)} exceeds max {_CLIENT_CACHE_MAX_SIZE}"
|
||||
# The earliest entries should have been evicted
|
||||
assert ("evict_test_0", False, "", "", "", ()) not in _client_cache
|
||||
# The latest entries should be present
|
||||
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", ()) in _client_cache
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.clear()
|
||||
_client_cache.update(saved)
|
||||
|
||||
@@ -28,8 +28,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._execution_thread_id = None
|
||||
agent._interrupt_thread_signal_pending = False
|
||||
agent._execution_thread_id = None # defaults to current thread in set_interrupt
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.quiet_mode = True
|
||||
@@ -47,17 +46,15 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
assert parent._interrupt_requested is True
|
||||
assert child._interrupt_requested is True
|
||||
assert child._interrupt_message == "new user message"
|
||||
assert is_interrupted() is False
|
||||
assert parent._interrupt_thread_signal_pending is True
|
||||
assert is_interrupted() is True
|
||||
|
||||
def test_child_clear_interrupt_at_start_clears_thread(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the
|
||||
bound execution thread's interrupt flag.
|
||||
per-thread interrupt flag for the current thread.
|
||||
"""
|
||||
child = self._make_bare_agent()
|
||||
child._interrupt_requested = True
|
||||
child._interrupt_message = "msg"
|
||||
child._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Interrupt for current thread is set
|
||||
set_interrupt(True)
|
||||
@@ -131,36 +128,6 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
child_thread.join(timeout=1)
|
||||
set_interrupt(False)
|
||||
|
||||
def test_prestart_interrupt_binds_to_execution_thread(self):
|
||||
"""An interrupt that arrives before startup should bind to the agent thread."""
|
||||
agent = self._make_bare_agent()
|
||||
barrier = threading.Barrier(2)
|
||||
result = {}
|
||||
|
||||
agent.interrupt("stop before start")
|
||||
assert agent._interrupt_requested is True
|
||||
assert agent._interrupt_thread_signal_pending is True
|
||||
assert is_interrupted() is False
|
||||
|
||||
def run_thread():
|
||||
from tools.interrupt import set_interrupt as _set_interrupt_for_test
|
||||
|
||||
agent._execution_thread_id = threading.current_thread().ident
|
||||
_set_interrupt_for_test(False, agent._execution_thread_id)
|
||||
if agent._interrupt_requested:
|
||||
_set_interrupt_for_test(True, agent._execution_thread_id)
|
||||
agent._interrupt_thread_signal_pending = False
|
||||
barrier.wait(timeout=5)
|
||||
result["thread_interrupted"] = is_interrupted()
|
||||
|
||||
t = threading.Thread(target=run_thread)
|
||||
t.start()
|
||||
barrier.wait(timeout=5)
|
||||
t.join(timeout=2)
|
||||
|
||||
assert result["thread_interrupted"] is True
|
||||
assert agent._interrupt_thread_signal_pending is False
|
||||
|
||||
|
||||
class TestPerThreadInterruptIsolation(unittest.TestCase):
|
||||
"""Verify that interrupting one agent does NOT affect another agent's thread.
|
||||
|
||||
@@ -9,6 +9,8 @@ def _build_agent(model_cfg, custom_providers=None, model="anthropic/claude-opus-
|
||||
if custom_providers is not None:
|
||||
cfg["custom_providers"] = custom_providers
|
||||
|
||||
base_url = model_cfg.get("base_url", "")
|
||||
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", return_value=cfg),
|
||||
patch("agent.model_metadata.get_model_context_length", return_value=128_000),
|
||||
@@ -21,6 +23,7 @@ def _build_agent(model_cfg, custom_providers=None, model="anthropic/claude-opus-
|
||||
agent = AIAgent(
|
||||
model=model,
|
||||
api_key="test-key-1234567890",
|
||||
base_url=base_url,
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
|
||||
@@ -805,7 +805,10 @@ class TestCodexReasoningPreflight:
|
||||
reasoning_items = [i for i in normalized if i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "abc123encrypted"
|
||||
assert reasoning_items[0]["id"] == "r_001"
|
||||
# Note: "id" is intentionally excluded from normalized output —
|
||||
# with store=False the API returns 404 on server-side id resolution.
|
||||
# The id is only used for local deduplication via seen_ids.
|
||||
assert "id" not in reasoning_items[0]
|
||||
assert reasoning_items[0]["summary"] == [{"type": "summary_text", "text": "Thinking about it"}]
|
||||
|
||||
def test_reasoning_item_without_id(self, monkeypatch):
|
||||
|
||||
@@ -46,9 +46,18 @@ def api_module(monkeypatch, tmp_path):
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
# Ensure the gws CLI code path is taken even when the binary isn't
|
||||
# installed (CI). Without this, calendar_list() falls through to the
|
||||
# Python SDK path which imports ``googleapiclient`` — not in deps.
|
||||
module._gws_binary = lambda: "/usr/bin/gws"
|
||||
# Bypass authentication check — no real token file in CI.
|
||||
module._ensure_authenticated = lambda: None
|
||||
return module
|
||||
|
||||
|
||||
_gws_installed = importlib.util.find_spec("shutil") and __import__("shutil").which("gws")
|
||||
|
||||
|
||||
def _write_token(path: Path, *, token="ya29.test", expiry=None, **extra):
|
||||
data = {
|
||||
"token": token,
|
||||
@@ -124,13 +133,14 @@ def test_bridge_main_injects_token_env(bridge_module, tmp_path):
|
||||
assert captured["cmd"] == ["gws", "gmail", "+triage"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _gws_installed, reason="gws CLI not installed")
|
||||
def test_api_calendar_list_uses_agenda_by_default(api_module):
|
||||
"""calendar list without dates uses +agenda helper."""
|
||||
captured = {}
|
||||
|
||||
def capture_run(cmd, **kwargs):
|
||||
captured["cmd"] = cmd
|
||||
return MagicMock(returncode=0)
|
||||
return MagicMock(returncode=0, stdout="{}", stderr="")
|
||||
|
||||
args = api_module.argparse.Namespace(
|
||||
start="", end="", max=25, calendar="primary", func=api_module.calendar_list,
|
||||
@@ -146,6 +156,7 @@ def test_api_calendar_list_uses_agenda_by_default(api_module):
|
||||
assert "--days" in gws_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _gws_installed, reason="gws CLI not installed")
|
||||
def test_api_calendar_list_respects_date_range(api_module):
|
||||
"""calendar list with --start/--end uses raw events list API."""
|
||||
captured = {}
|
||||
|
||||
@@ -296,7 +296,7 @@ def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5])
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
calls.append((method, url, json, timeout))
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestSendMatrix:
|
||||
session.put.assert_called_once()
|
||||
call_kwargs = session.put.call_args
|
||||
url = call_kwargs[0][0]
|
||||
assert url.startswith("https://matrix.example.com/_matrix/client/v3/rooms/!room:example.com/send/m.room.message/")
|
||||
assert url.startswith("https://matrix.example.com/_matrix/client/v3/rooms/%21room%3Aexample.com/send/m.room.message/")
|
||||
assert call_kwargs[1]["headers"]["Authorization"] == "Bearer syt_tok"
|
||||
payload = call_kwargs[1]["json"]
|
||||
assert payload["msgtype"] == "m.text"
|
||||
|
||||
@@ -92,25 +92,6 @@ class TestCheckWatchPatterns:
|
||||
assert "disk full" in evt["output"]
|
||||
assert evt["session_id"] == "proc_test_watch"
|
||||
|
||||
def test_match_carries_session_key_and_watcher_routing_metadata(self, registry):
|
||||
session = _make_session(watch_patterns=["ERROR"])
|
||||
session.session_key = "agent:main:telegram:group:-100:42"
|
||||
session.watcher_platform = "telegram"
|
||||
session.watcher_chat_id = "-100"
|
||||
session.watcher_user_id = "u123"
|
||||
session.watcher_user_name = "alice"
|
||||
session.watcher_thread_id = "42"
|
||||
|
||||
registry._check_watch_patterns(session, "ERROR: disk full\n")
|
||||
evt = registry.completion_queue.get_nowait()
|
||||
|
||||
assert evt["session_key"] == "agent:main:telegram:group:-100:42"
|
||||
assert evt["platform"] == "telegram"
|
||||
assert evt["chat_id"] == "-100"
|
||||
assert evt["user_id"] == "u123"
|
||||
assert evt["user_name"] == "alice"
|
||||
assert evt["thread_id"] == "42"
|
||||
|
||||
def test_multiple_patterns(self, registry):
|
||||
"""First matching pattern is reported."""
|
||||
session = _make_session(watch_patterns=["WARN", "ERROR"])
|
||||
|
||||
@@ -105,10 +105,6 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
if self._client_timeout_grace_seconds is not None:
|
||||
deadline = time.monotonic() + prepared.timeout + self._client_timeout_grace_seconds
|
||||
|
||||
_last_activity_touch = time.monotonic()
|
||||
_modal_exec_start = time.monotonic()
|
||||
_ACTIVITY_INTERVAL = 10.0 # match _wait_for_process cadence
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
try:
|
||||
@@ -132,22 +128,6 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
pass
|
||||
return self._timeout_result_for_modal(prepared.timeout)
|
||||
|
||||
# Periodic activity touch so the gateway knows we're alive
|
||||
_now = time.monotonic()
|
||||
if _now - _last_activity_touch >= _ACTIVITY_INTERVAL:
|
||||
_last_activity_touch = _now
|
||||
try:
|
||||
from tools.environments.base import _get_activity_callback
|
||||
_cb = _get_activity_callback()
|
||||
except Exception:
|
||||
_cb = None
|
||||
if _cb:
|
||||
try:
|
||||
_elapsed = int(_now - _modal_exec_start)
|
||||
_cb(f"modal command running ({_elapsed}s elapsed)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(self._poll_interval_seconds)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
|
||||
@@ -191,15 +191,9 @@ class ProcessRegistry:
|
||||
session._watch_disabled = True
|
||||
self.completion_queue.put({
|
||||
"session_id": session.id,
|
||||
"session_key": session.session_key,
|
||||
"command": session.command,
|
||||
"type": "watch_disabled",
|
||||
"suppressed": session._watch_suppressed,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"user_id": session.watcher_user_id,
|
||||
"user_name": session.watcher_user_name,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
"message": (
|
||||
f"Watch patterns disabled for process {session.id} — "
|
||||
f"too many matches ({session._watch_suppressed} suppressed). "
|
||||
@@ -225,17 +219,11 @@ class ProcessRegistry:
|
||||
|
||||
self.completion_queue.put({
|
||||
"session_id": session.id,
|
||||
"session_key": session.session_key,
|
||||
"command": session.command,
|
||||
"type": "watch_match",
|
||||
"pattern": matched_pattern,
|
||||
"output": output,
|
||||
"suppressed": suppressed,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"user_id": session.watcher_user_id,
|
||||
"user_name": session.watcher_user_name,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1384,10 +1384,14 @@ def terminal_tool(
|
||||
if pty_disabled_reason:
|
||||
result_data["pty_note"] = pty_disabled_reason
|
||||
|
||||
# Populate routing metadata on the session so that
|
||||
# watch-pattern and completion notifications can be
|
||||
# routed back to the correct chat/thread.
|
||||
if background and (notify_on_complete or watch_patterns):
|
||||
# Mark for agent notification on completion
|
||||
if notify_on_complete and background:
|
||||
proc_session.notify_on_complete = True
|
||||
result_data["notify_on_complete"] = True
|
||||
|
||||
# In gateway mode, auto-register a fast watcher so the
|
||||
# gateway can detect completion and trigger a new agent
|
||||
# turn. CLI mode uses the completion_queue directly.
|
||||
from gateway.session_context import get_session_env as _gse
|
||||
_gw_platform = _gse("HERMES_SESSION_PLATFORM", "")
|
||||
if _gw_platform:
|
||||
@@ -1400,26 +1404,16 @@ def terminal_tool(
|
||||
proc_session.watcher_user_id = _gw_user_id
|
||||
proc_session.watcher_user_name = _gw_user_name
|
||||
proc_session.watcher_thread_id = _gw_thread_id
|
||||
|
||||
# Mark for agent notification on completion
|
||||
if notify_on_complete and background:
|
||||
proc_session.notify_on_complete = True
|
||||
result_data["notify_on_complete"] = True
|
||||
|
||||
# In gateway mode, auto-register a fast watcher so the
|
||||
# gateway can detect completion and trigger a new agent
|
||||
# turn. CLI mode uses the completion_queue directly.
|
||||
if proc_session.watcher_platform:
|
||||
proc_session.watcher_interval = 5
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": 5,
|
||||
"session_key": session_key,
|
||||
"platform": proc_session.watcher_platform,
|
||||
"chat_id": proc_session.watcher_chat_id,
|
||||
"user_id": proc_session.watcher_user_id,
|
||||
"user_name": proc_session.watcher_user_name,
|
||||
"thread_id": proc_session.watcher_thread_id,
|
||||
"platform": _gw_platform,
|
||||
"chat_id": _gw_chat_id,
|
||||
"user_id": _gw_user_id,
|
||||
"user_name": _gw_user_name,
|
||||
"thread_id": _gw_thread_id,
|
||||
"notify_on_complete": True,
|
||||
})
|
||||
|
||||
|
||||
@@ -8,24 +8,20 @@
|
||||
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');
|
||||
|
||||
:root {
|
||||
/* Dark amber palette for light mode — readable on white (WCAG AA compliant)
|
||||
Current gold #FFD700 has only 1.4:1 contrast on white; these tones pass 4.5:1+ */
|
||||
--ifm-color-primary: #8B6508;
|
||||
--ifm-color-primary-dark: #7A5800;
|
||||
--ifm-color-primary-darker: #6E4F00;
|
||||
--ifm-color-primary-darkest: #5A4100;
|
||||
--ifm-color-primary-light: #9E7410;
|
||||
--ifm-color-primary-lighter: #B38319;
|
||||
--ifm-color-primary-lightest: #C89222;
|
||||
/* Gold/Amber palette from landing page */
|
||||
--ifm-color-primary: #FFD700;
|
||||
--ifm-color-primary-dark: #E6C200;
|
||||
--ifm-color-primary-darker: #D9B700;
|
||||
--ifm-color-primary-darkest: #B39600;
|
||||
--ifm-color-primary-light: #FFDD33;
|
||||
--ifm-color-primary-lighter: #FFE14D;
|
||||
--ifm-color-primary-lightest: #FFEB80;
|
||||
|
||||
--ifm-font-family-base: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
|
||||
--ifm-font-family-monospace: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', monospace;
|
||||
|
||||
--ifm-code-font-size: 90%;
|
||||
--ifm-heading-font-weight: 600;
|
||||
|
||||
--ifm-link-color: #7A5800;
|
||||
--ifm-link-hover-color: #5A4100;
|
||||
}
|
||||
|
||||
/* Dark mode — the PRIMARY mode, matches landing page */
|
||||
@@ -95,13 +91,6 @@
|
||||
padding-left: calc(var(--ifm-menu-link-padding-horizontal) - 3px);
|
||||
}
|
||||
|
||||
/* Light mode sidebar active */
|
||||
[data-theme='light'] .menu__link--active:not(.menu__link--sublist) {
|
||||
background-color: rgba(139, 101, 8, 0.08);
|
||||
border-left: 3px solid #8B6508;
|
||||
padding-left: calc(var(--ifm-menu-link-padding-horizontal) - 3px);
|
||||
}
|
||||
|
||||
/* Code blocks */
|
||||
[data-theme='dark'] .prism-code {
|
||||
background-color: #0a0a12 !important;
|
||||
@@ -178,16 +167,6 @@ pre.prism-code.language-ascii code {
|
||||
border-color: rgba(255, 215, 0, 0.06);
|
||||
}
|
||||
|
||||
/* Light mode table styling */
|
||||
[data-theme='light'] table th {
|
||||
background-color: rgba(139, 101, 8, 0.06);
|
||||
border-color: rgba(139, 101, 8, 0.15);
|
||||
}
|
||||
|
||||
[data-theme='light'] table td {
|
||||
border-color: rgba(139, 101, 8, 0.10);
|
||||
}
|
||||
|
||||
/* Footer */
|
||||
.footer {
|
||||
border-top: 1px solid rgba(255, 215, 0, 0.08);
|
||||
@@ -198,16 +177,11 @@ pre.prism-code.language-ascii code {
|
||||
transition: color 0.2s;
|
||||
}
|
||||
|
||||
[data-theme='dark'] .footer a:hover {
|
||||
.footer a:hover {
|
||||
color: #FFD700;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
[data-theme='light'] .footer a:hover {
|
||||
color: #7A5800;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
/* Scrollbar */
|
||||
[data-theme='dark'] ::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
|
||||
Reference in New Issue
Block a user