mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:37:13 +08:00
Compare commits
12 Commits
fix/api-se
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f7e5db456 | ||
|
|
f008ee1019 | ||
|
|
60fdb58ce4 | ||
|
|
18d28c63a7 | ||
|
|
3c57eaf744 | ||
|
|
2d232c9991 | ||
|
|
0375b2a0d7 | ||
|
|
08fa326bb0 | ||
|
|
bde45f5a2a | ||
|
|
716e616d28 | ||
|
|
bdccdd67a1 | ||
|
|
148f46620f |
@@ -231,7 +231,7 @@ class KawaiiSpinner:
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
@@ -239,12 +239,26 @@ class KawaiiSpinner:
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
# Optional callable to route all output through (e.g. a no-op for silent
|
||||
# background agents). When set, bypasses self._out entirely so that
|
||||
# agents with _print_fn overridden remain fully silent.
|
||||
self._print_fn = print_fn
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
"""Write to the stdout captured at spinner creation time.
|
||||
|
||||
If a print_fn was supplied at construction, all output is routed through
|
||||
it instead — allowing callers to silence the spinner with a no-op lambda.
|
||||
"""
|
||||
if self._print_fn is not None:
|
||||
try:
|
||||
self._print_fn(text)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
|
||||
@@ -688,6 +688,12 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
# Ctrl+C always interrupts regardless of this setting.
|
||||
busy_input_mode: interrupt
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
|
||||
79
cli.py
79
cli.py
@@ -205,6 +205,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": True,
|
||||
"busy_input_mode": "interrupt",
|
||||
|
||||
"skin": "default",
|
||||
},
|
||||
@@ -1035,13 +1036,18 @@ class HermesCLI:
|
||||
self.config = CLI_CONFIG
|
||||
self.compact = compact if compact is not None else CLI_CONFIG["display"].get("compact", False)
|
||||
# tool_progress: "off", "new", "all", "verbose" (from config.yaml display section)
|
||||
self.tool_progress_mode = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise to string.
|
||||
_raw_tp = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
self.tool_progress_mode = "off" if _raw_tp is False else str(_raw_tp)
|
||||
# resume_display: "full" (show history) | "minimal" (one-liner only)
|
||||
self.resume_display = CLI_CONFIG["display"].get("resume_display", "full")
|
||||
# bell_on_complete: play terminal bell (\a) when agent finishes a response
|
||||
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
|
||||
# show_reasoning: display model thinking/reasoning before the response
|
||||
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
|
||||
# busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn)
|
||||
_bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt")
|
||||
self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt"
|
||||
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
@@ -1329,7 +1335,12 @@ class HermesCLI:
|
||||
def _build_status_bar_text(self, width: Optional[int] = None) -> str:
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = width or shutil.get_terminal_size((80, 24)).columns
|
||||
if width is None:
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
@@ -1359,7 +1370,16 @@ class HermesCLI:
|
||||
return []
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
# Use prompt_toolkit's own terminal width when running inside the
|
||||
# TUI — shutil.get_terminal_size() can return stale or fallback
|
||||
# values (especially on SSH) that differ from what prompt_toolkit
|
||||
# actually renders, causing the fragments to overflow to a second
|
||||
# line and produce duplicated status bar rows over long sessions.
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
duration_label = snapshot["duration"]
|
||||
|
||||
if width < 52:
|
||||
@@ -3722,17 +3742,17 @@ class HermesCLI:
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "queue":
|
||||
if not self._agent_running:
|
||||
_cprint(" /queue only works while Hermes is busy. Just type your message normally.")
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
self._pending_input.put(payload)
|
||||
if self._agent_running:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -6112,16 +6132,22 @@ class HermesCLI:
|
||||
# Bundle text + images as a tuple when images are present
|
||||
payload = (text, images) if images else text
|
||||
if self._agent_running and not (text and text.startswith("/")):
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
if self.busy_input_mode == "queue":
|
||||
# Queue for the next turn instead of interrupting
|
||||
self._pending_input.put(payload)
|
||||
preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]"
|
||||
_cprint(f" Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}")
|
||||
else:
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
event.app.current_buffer.reset(append_to_history=True)
|
||||
@@ -6894,6 +6920,15 @@ class HermesCLI:
|
||||
Window(
|
||||
content=FormattedTextControl(lambda: cli_ref._get_status_bar_fragments()),
|
||||
height=1,
|
||||
# Prevent fragments that overflow the terminal width from
|
||||
# wrapping onto a second line, which causes the status bar to
|
||||
# appear duplicated (one full + one partial row) during long
|
||||
# sessions, especially on SSH where shutil.get_terminal_size
|
||||
# may return stale values. _get_status_bar_fragments now reads
|
||||
# width from prompt_toolkit's own output object, so fragments
|
||||
# will always fit; wrap_lines=False is the belt-and-suspenders
|
||||
# guard against any future width mismatch.
|
||||
wrap_lines=False,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._status_bar_visible),
|
||||
)
|
||||
|
||||
@@ -366,14 +366,20 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
|
||||
Uses _resolve_runtime_agent_kwargs() to pick up model, api_key,
|
||||
base_url, etc. from config.yaml / env vars.
|
||||
base_url, etc. from config.yaml / env vars. Toolsets are resolved
|
||||
from config.yaml platform_toolsets.api_server (same as all other
|
||||
gateway platforms), falling back to the hermes-api-server default.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server"))
|
||||
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
agent = AIAgent(
|
||||
@@ -383,6 +389,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt or None,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
session_id=session_id,
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
|
||||
@@ -8,6 +8,7 @@ and implement the required methods.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -329,6 +330,24 @@ class SendResult:
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically
|
||||
|
||||
|
||||
# Error substrings that indicate a transient network failure worth retrying.
# Entries are matched with ``in`` against a lower-cased error message, so every
# entry must itself be lower-case or it can never match.
_RETRYABLE_ERROR_PATTERNS = (
    "connecterror",
    "connectionerror",
    "connectionreset",
    "connectionrefused",
    "timeout",
    "timed out",
    "network",
    "broken pipe",
    "remotedisconnected",
    "eoferror",
    "readtimeout",
    "writetimeout",
)
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
@@ -833,6 +852,91 @@ class BasePlatformAdapter(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
def _is_retryable_error(error: Optional[str]) -> bool:
    """Return True if the error text looks like a transient network failure.

    The check is a case-insensitive substring match against the entries of
    ``_RETRYABLE_ERROR_PATTERNS``.

    Args:
        error: Error message reported for a failed send, or ``None``.

    Returns:
        ``False`` for ``None`` / empty input, otherwise whether any known
        transient-failure pattern occurs in the message.
    """
    if not error:
        return False
    # Coerce defensively: the annotation says ``str``, but a caller that
    # stores an exception object in the error slot would otherwise crash
    # here on ``.lower()`` instead of being classified.
    lowered = str(error).lower()
    return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
||||
|
||||
async def _send_with_retry(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Any = None,
    max_retries: int = 2,
    base_delay: float = 2.0,
) -> "SendResult":
    """Send a message, retrying transient failures and falling back to plain text.

    Flow:
      1. Call ``self.send`` once; return immediately on success.
      2. If the failure is transient (``result.retryable`` is set or the error
         text matches ``_is_retryable_error``), retry up to ``max_retries``
         times with exponential backoff plus up to one second of jitter.
      3. If every retry fails with a transient error, send the user a short
         delivery-failure notice (best effort) and return the last failed
         result.
      4. Otherwise (non-transient failure, or the error turned non-transient
         during the retries) re-send the content prefixed with a
         "formatting failed" banner and truncated to 3500 characters.

    Args:
        chat_id: Destination chat identifier, passed through to ``send``.
        content: Message body.
        reply_to: Optional id of the message being replied to.
        metadata: Opaque value forwarded unchanged to ``send``.
        max_retries: Number of retries after the initial attempt.
        base_delay: Backoff base in seconds; retry ``n`` waits
            ``base_delay * 2 ** (n - 1)`` plus jitter.

    Returns:
        The first successful result, the last failed result when retries are
        exhausted, or the result of the plain-text fallback.
    """
    result = await self.send(
        chat_id=chat_id,
        content=content,
        reply_to=reply_to,
        metadata=metadata,
    )
    if result.success:
        return result

    error_str = result.error or ""
    is_network = result.retryable or self._is_retryable_error(error_str)

    if is_network:
        # Retry with exponential backoff for transient errors.
        for attempt in range(1, max_retries + 1):
            delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
            logger.warning(
                "[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
                self.name, attempt, max_retries, delay, error_str,
            )
            await asyncio.sleep(delay)
            result = await self.send(
                chat_id=chat_id,
                content=content,
                reply_to=reply_to,
                metadata=metadata,
            )
            if result.success:
                logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
                return result
            error_str = result.error or ""
            if not (result.retryable or self._is_retryable_error(error_str)):
                # Error switched to non-transient — leave the loop via
                # ``break`` so the ``else`` below is skipped and we fall
                # through to the plain-text fallback.
                break
        else:
            # Loop finished without ``break``: all retries were transient
            # failures. Tell the user so they do not wait indefinitely.
            logger.error(
                "[%s] Failed to deliver response after %d retries: %s",
                self.name, max_retries, error_str,
            )
            notice = (
                "\u26a0\ufe0f Message delivery failed after multiple attempts. "
                "Please try again \u2014 your request was processed but the response could not be sent."
            )
            try:
                notice_result = await self.send(
                    chat_id=chat_id,
                    content=notice,
                    reply_to=reply_to,
                    metadata=metadata,
                )
                # ``send`` reports failure through its result as well as by
                # raising, so check both paths.
                if not getattr(notice_result, "success", True):
                    logger.debug(
                        "[%s] Could not send delivery-failure notice: %s",
                        self.name, getattr(notice_result, "error", None),
                    )
            except Exception as notify_err:
                logger.debug(
                    "[%s] Could not send delivery-failure notice: %s",
                    self.name, notify_err,
                )
            return result

    # Non-network / post-retry formatting failure: try plain text as fallback.
    logger.warning(
        "[%s] Send failed: %s — trying plain-text fallback", self.name, error_str
    )
    fallback_result = await self.send(
        chat_id=chat_id,
        content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
        reply_to=reply_to,
        metadata=metadata,
    )
    if not fallback_result.success:
        logger.error(
            "[%s] Fallback send also failed: %s", self.name, fallback_result.error
        )
    return fallback_result
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -982,26 +1086,13 @@ class BasePlatformAdapter(ABC):
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
result = await self._send_with_retry(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
@@ -551,9 +551,20 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
import nio
|
||||
|
||||
while not self._closing:
|
||||
try:
|
||||
await self._client.sync(timeout=30000)
|
||||
resp = await self._client.sync(timeout=30000)
|
||||
if isinstance(resp, nio.SyncError):
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning(
|
||||
"Matrix: sync returned %s: %s — retrying in 5s",
|
||||
type(resp).__name__,
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
|
||||
@@ -573,6 +573,10 @@ class GatewayRunner:
|
||||
session_id=old_session_id,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
# Fully silence the flush agent — quiet_mode only suppresses init
|
||||
# messages; tool call output still leaks to the terminal through
|
||||
# _safe_print → _print_fn. Set a no-op to prevent that.
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
# Build conversation history from transcript
|
||||
msgs = [
|
||||
@@ -954,12 +958,20 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
"No user allowlists configured. All unauthorized users will be denied. "
|
||||
@@ -2175,6 +2187,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
_hyg_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
@@ -3885,6 +3898,7 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
compressed, _ = await loop.run_in_executor(
|
||||
@@ -4799,9 +4813,14 @@ class GatewayRunner:
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key))
|
||||
|
||||
# Tool progress mode from config.yaml: "all", "new", "verbose", "off"
|
||||
# Falls back to env vars for backward compatibility
|
||||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
progress_mode = (
|
||||
user_config.get("display", {}).get("tool_progress")
|
||||
_raw_tp
|
||||
or os.getenv("HERMES_TOOL_PROGRESS_MODE")
|
||||
or "all"
|
||||
)
|
||||
@@ -5128,7 +5147,25 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
    """Schedule delivery of a background-review notice on the gateway loop.

    Hands ``_status_adapter.send(...)`` to ``_loop_for_step`` via
    ``asyncio.run_coroutine_threadsafe`` and returns without waiting for the
    send to finish. Does nothing when no status adapter is available.

    Args:
        message: Text to deliver to ``_status_chat_id``.
    """
    if not _status_adapter:
        return

    def _log_delivery_failure(fut) -> None:
        # Without this callback an exception raised by ``send`` on the loop
        # is stored on the future and never surfaced anywhere.
        if fut.cancelled():
            return
        exc = fut.exception()
        if exc is not None:
            logger.debug("background_review_callback error: %s", exc)

    try:
        future = asyncio.run_coroutine_threadsafe(
            _status_adapter.send(
                _status_chat_id,
                message,
                metadata=_status_thread_metadata,
            ),
            _loop_for_step,
        )
        future.add_done_callback(_log_delivery_failure)
    except Exception as _e:
        # Scheduling itself failed; log and carry on, as before.
        logger.debug("background_review_callback error: %s", _e)
|
||||
|
||||
agent.background_review_callback = _bg_review_send
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
# Capture the full tool definitions for transcript logging
|
||||
|
||||
@@ -955,13 +955,17 @@ class SessionStore:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
role=role,
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
reasoning=msg.get("reasoning") if role == "assistant" else None,
|
||||
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
@@ -264,6 +264,7 @@ DEFAULT_CONFIG = {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full",
|
||||
"busy_input_mode": "interrupt",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
|
||||
@@ -2968,6 +2968,95 @@ def setup_tools(config: dict, first_install: bool = False):
|
||||
tools_command(first_install=first_install, config=config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-Migration Section Skip Logic
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]:
    """Return a short summary if a setup section is already configured, else None.

    Used after OpenClaw migration to detect which sections can be skipped.
    ``get_env_value`` is the module-level import from hermes_cli.config
    so that test patches on ``setup_mod.get_env_value`` take effect.

    Args:
        config: Loaded configuration mapping.
        section_key: One of ``"model"``, ``"terminal"``, ``"agent"``,
            ``"gateway"`` or ``"tools"``; any other key yields ``None``.

    Returns:
        A human-readable summary string, or ``None`` when the section has
        nothing configured and must be run.
    """
    if section_key == "model":
        has_key = bool(
            get_env_value("OPENROUTER_API_KEY")
            or get_env_value("OPENAI_API_KEY")
            or get_env_value("ANTHROPIC_API_KEY")
        )
        if not has_key:
            # Check for OAuth providers. Best effort: any failure while
            # importing or querying auth is treated as "no provider".
            try:
                from hermes_cli.auth import get_active_provider

                if get_active_provider():
                    has_key = True
            except Exception:
                pass
        if not has_key:
            return None
        model = config.get("model")
        if isinstance(model, str) and model.strip():
            return model.strip()
        if isinstance(model, dict):
            return str(model.get("default") or model.get("model") or "configured")
        return "configured"

    if section_key == "terminal":
        # A section that is present but empty in YAML loads as None, so
        # ``config.get("terminal", {})`` alone would raise on ``.get``.
        terminal = config.get("terminal")
        if not isinstance(terminal, dict):
            terminal = {}
        backend = terminal.get("backend", "local")
        return f"backend: {backend}"

    if section_key == "agent":
        agent = config.get("agent")
        if not isinstance(agent, dict):
            agent = {}
        max_turns = agent.get("max_turns", 90)
        return f"max turns: {max_turns}"

    if section_key == "gateway":
        platform_probes = (
            ("TELEGRAM_BOT_TOKEN", "Telegram"),
            ("DISCORD_BOT_TOKEN", "Discord"),
            ("SLACK_BOT_TOKEN", "Slack"),
            ("WHATSAPP_PHONE_NUMBER_ID", "WhatsApp"),
            ("SIGNAL_ACCOUNT", "Signal"),
        )
        platforms = [label for env, label in platform_probes if get_env_value(env)]
        if platforms:
            return ", ".join(platforms)
        return None  # No platforms configured — section must run

    if section_key == "tools":
        tool_probes = (
            ("ELEVENLABS_API_KEY", "TTS/ElevenLabs"),
            ("BROWSERBASE_API_KEY", "Browser"),
            ("FIRECRAWL_API_KEY", "Firecrawl"),
        )
        tools = [label for env, label in tool_probes if get_env_value(env)]
        if tools:
            return ", ".join(tools)
        return None

    return None
|
||||
|
||||
|
||||
def _skip_configured_section(
    config: dict, section_key: str, label: str
) -> bool:
    """Show an already-configured section summary and offer to skip.

    Args:
        config: Loaded configuration mapping, passed to
            ``_get_section_config_summary``.
        section_key: Section identifier understood by
            ``_get_section_config_summary``.
        label: Human-readable section name shown to the user.

    Returns:
        True if the user chose to skip, False if the section should run
        (either nothing is configured yet or the user asked to reconfigure).
    """
    summary = _get_section_config_summary(config, section_key)
    if not summary:
        # Nothing configured: the section must run, no prompt needed.
        return False
    print()
    print_success(f"  {label}: {summary}")
    # The prompt asks "Reconfigure?", so skipping is the negation of the answer.
    return not prompt_yes_no(f"  Reconfigure {label.lower()}?", default=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenClaw Migration
|
||||
# =============================================================================
|
||||
@@ -3039,7 +3128,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=False,
|
||||
overwrite=True,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -3195,6 +3284,8 @@ def run_setup_wizard(args):
|
||||
)
|
||||
)
|
||||
|
||||
migration_ran = False
|
||||
|
||||
if is_existing:
|
||||
# ── Returning User Menu ──
|
||||
print()
|
||||
@@ -3264,7 +3355,8 @@ def run_setup_wizard(args):
|
||||
return
|
||||
|
||||
# Offer OpenClaw migration before configuration begins
|
||||
if _offer_openclaw_migration(hermes_home):
|
||||
migration_ran = _offer_openclaw_migration(hermes_home)
|
||||
if migration_ran:
|
||||
# Reload config in case migration wrote to it
|
||||
config = load_config()
|
||||
|
||||
@@ -3277,20 +3369,31 @@ def run_setup_wizard(args):
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
if migration_ran:
|
||||
print()
|
||||
print_info("Settings were imported from OpenClaw.")
|
||||
print_info("Each section below will show what was imported — press Enter to keep,")
|
||||
print_info("or choose to reconfigure if needed.")
|
||||
|
||||
# Section 1: Model & Provider
|
||||
setup_model_provider(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")):
|
||||
setup_model_provider(config)
|
||||
|
||||
# Section 2: Terminal Backend
|
||||
setup_terminal_backend(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")):
|
||||
setup_terminal_backend(config)
|
||||
|
||||
# Section 3: Agent Settings
|
||||
setup_agent_settings(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")):
|
||||
setup_agent_settings(config)
|
||||
|
||||
# Section 4: Messaging Platforms
|
||||
setup_gateway(config)
|
||||
if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")):
|
||||
setup_gateway(config)
|
||||
|
||||
# Section 5: Tools
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
if not (migration_ran and _skip_configured_section(config, "tools", "Tools")):
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
|
||||
# Save and show summary
|
||||
save_config(config)
|
||||
|
||||
@@ -134,6 +134,7 @@ PLATFORMS = {
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
}
|
||||
|
||||
|
||||
|
||||
36
run_agent.py
36
run_agent.py
@@ -486,6 +486,7 @@ class AIAgent:
|
||||
# instead of going directly to stdout where patch_stdout's StdoutProxy
|
||||
# would mangle the escape sequences. None = use builtins.print.
|
||||
self._print_fn = None
|
||||
self.background_review_callback = None # Optional sync callback for gateway delivery
|
||||
self.skip_context_files = skip_context_files
|
||||
self.pass_session_id = pass_session_id
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
@@ -1525,6 +1526,12 @@ class AIAgent:
|
||||
if actions:
|
||||
summary = " · ".join(dict.fromkeys(actions))
|
||||
self._safe_print(f" 💾 {summary}")
|
||||
_bg_cb = self.background_review_callback
|
||||
if _bg_cb:
|
||||
try:
|
||||
_bg_cb(f"💾 {summary}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
@@ -4127,6 +4134,25 @@ class AIAgent:
|
||||
or is_native_anthropic
|
||||
)
|
||||
|
||||
# Update context compressor limits for the fallback model.
|
||||
# Without this, compression decisions use the primary model's
|
||||
# context window (e.g. 200K) instead of the fallback's (e.g. 32K),
|
||||
# causing oversized sessions to overflow the fallback.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
fb_context_length = get_model_context_length(
|
||||
self.model, base_url=self.base_url,
|
||||
api_key=self.api_key, provider=self.provider,
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
self.context_compressor.api_key = self.api_key
|
||||
self.context_compressor.provider = self.provider
|
||||
self.context_compressor.context_length = fb_context_length
|
||||
self.context_compressor.threshold_tokens = int(
|
||||
fb_context_length * self.context_compressor.threshold_percent
|
||||
)
|
||||
|
||||
self._emit_status(
|
||||
f"🔄 Primary model failed — switching to fallback: "
|
||||
f"{fb_model} via {fb_provider}"
|
||||
@@ -5080,7 +5106,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
|
||||
try:
|
||||
@@ -5121,7 +5147,7 @@ class AIAgent:
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
print(f" {cute_msg}")
|
||||
self._safe_print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
@@ -5306,7 +5332,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
self._delegate_spinner = spinner
|
||||
_delegate_result = None
|
||||
@@ -5336,7 +5362,7 @@ class AIAgent:
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
if len(preview) > 30:
|
||||
preview = preview[:27] + "..."
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots')
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner.start()
|
||||
_spinner_result = None
|
||||
try:
|
||||
@@ -6019,7 +6045,7 @@ class AIAgent:
|
||||
# Raw KawaiiSpinner only when no streaming consumers
|
||||
# (would conflict with streamed token output)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type, print_fn=self._print_fn)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
|
||||
46
tests/gateway/test_allowlist_startup_check.py
Normal file
46
tests/gateway/test_allowlist_startup_check.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Tests for the startup allowlist warning check in gateway/run.py."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _would_warn():
    """Re-implement the startup allowlist warning decision for testing.

    Reads only ``os.environ``. Returns True (the warning would fire) when no
    ``*_ALLOWED_USERS`` variable holds a non-empty value and no
    ``*_ALLOW_ALL_USERS`` variable is set to "true", "1" or "yes"
    (case-insensitive); returns False otherwise.
    """
    allowlist_vars = (
        "TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
        "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
        "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
        "EMAIL_ALLOWED_USERS",
        "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
        "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
        "GATEWAY_ALLOWED_USERS",
    )
    allow_all_vars = (
        "GATEWAY_ALLOW_ALL_USERS",
        "TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
        "WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
        "SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
        "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
        "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS",
    )
    truthy = ("true", "1", "yes")

    # Any configured allowlist suppresses the warning.
    for name in allowlist_vars:
        if os.getenv(name):
            return False

    # So does any explicit "allow everyone" opt-in, global or per-platform.
    for name in allow_all_vars:
        if os.getenv(name, "").lower() in truthy:
            return False

    return True
|
||||
|
||||
|
||||
class TestAllowlistStartupCheck:
    """Environment scenarios for the startup allowlist warning decision.

    Each test replaces ``os.environ`` wholesale (``clear=True``) so that
    variables from the developer's shell cannot influence the outcome.
    """

    @patch.dict(os.environ, {}, clear=True)
    def test_no_config_emits_warning(self):
        """An empty environment triggers the warning."""
        assert _would_warn() is True

    @patch.dict(os.environ, {"SIGNAL_GROUP_ALLOWED_USERS": "user1"}, clear=True)
    def test_signal_group_allowed_users_suppresses_warning(self):
        """A Signal group allowlist alone counts as a configured allowlist."""
        assert _would_warn() is False

    @patch.dict(os.environ, {"TELEGRAM_ALLOW_ALL_USERS": "true"}, clear=True)
    def test_telegram_allow_all_users_suppresses_warning(self):
        """A per-platform allow-all flag suppresses the warning."""
        assert _would_warn() is False

    @patch.dict(os.environ, {"GATEWAY_ALLOW_ALL_USERS": "yes"}, clear=True)
    def test_gateway_allow_all_users_suppresses_warning(self):
        """The global allow-all flag suppresses the warning."""
        assert _would_warn() is False
|
||||
129
tests/gateway/test_api_server_toolset.py
Normal file
129
tests/gateway/test_api_server_toolset.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tests for hermes-api-server toolset and API server tool availability."""
|
||||
import os
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from toolsets import resolve_toolset, get_toolset, validate_toolset
|
||||
|
||||
|
||||
class TestHermesApiServerToolset:
    """Membership checks for the ``hermes-api-server`` toolset definition."""

    @staticmethod
    def _resolved_tools():
        """Return whatever ``resolve_toolset`` yields for the toolset under test."""
        return resolve_toolset("hermes-api-server")

    def test_toolset_exists(self):
        """``get_toolset`` knows the toolset by name."""
        definition = get_toolset("hermes-api-server")
        assert definition is not None

    def test_toolset_validates(self):
        """``validate_toolset`` accepts the toolset name."""
        assert validate_toolset("hermes-api-server")

    def test_toolset_includes_web_tools(self):
        """Both web tools are part of the resolved toolset."""
        tools = self._resolved_tools()
        assert "web_search" in tools
        assert "web_extract" in tools

    def test_toolset_includes_core_tools(self):
        """Terminal, file, vision, code, delegation and planning tools are present."""
        tools = self._resolved_tools()
        core_tools = (
            "terminal", "process",
            "read_file", "write_file", "patch", "search_files",
            "vision_analyze", "image_generate",
            "execute_code", "delegate_task",
            "todo", "memory", "session_search", "cronjob",
        )
        for tool in core_tools:
            assert tool in tools, f"Missing expected tool: {tool}"

    def test_toolset_includes_browser_tools(self):
        """All browser automation tools are present."""
        tools = self._resolved_tools()
        browser_tools = (
            "browser_navigate", "browser_snapshot", "browser_click",
            "browser_type", "browser_scroll", "browser_back",
            "browser_press", "browser_close",
        )
        for tool in browser_tools:
            assert tool in tools, f"Missing browser tool: {tool}"

    def test_toolset_includes_homeassistant_tools(self):
        """All Home Assistant tools are present."""
        tools = self._resolved_tools()
        ha_tools = ("ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service")
        for tool in ha_tools:
            assert tool in tools, f"Missing HA tool: {tool}"

    def test_toolset_excludes_clarify(self):
        """``clarify`` is not offered through this toolset."""
        assert "clarify" not in self._resolved_tools()

    def test_toolset_excludes_send_message(self):
        """``send_message`` is not offered through this toolset."""
        assert "send_message" not in self._resolved_tools()

    def test_toolset_excludes_text_to_speech(self):
        """``text_to_speech`` is not offered through this toolset."""
        assert "text_to_speech" not in self._resolved_tools()
|
||||
|
||||
|
||||
class TestApiServerPlatformConfig:
    """The ``PLATFORMS`` table in ``hermes_cli.tools_config`` registers the API server."""

    def test_platforms_dict_includes_api_server(self):
        """``api_server`` is a known platform whose default toolset is hermes-api-server."""
        from hermes_cli.tools_config import PLATFORMS

        assert "api_server" in PLATFORMS
        entry = PLATFORMS["api_server"]
        assert entry["default_toolset"] == "hermes-api-server"
|
||||
|
||||
|
||||
class TestApiServerAdapterToolset:
    """How ``APIServerAdapter._create_agent`` chooses the agent's toolsets."""

    @staticmethod
    def _agent_kwargs_for(gateway_config):
        """Run ``_create_agent`` with ``gateway_config`` and return the AIAgent kwargs.

        The gateway's runtime-kwargs, model and config resolvers and the
        ``AIAgent`` class are all patched, so no real agent is constructed.
        Asserts that ``AIAgent`` was instantiated exactly once.
        """
        from gateway.platforms.api_server import APIServerAdapter
        from gateway.config import PlatformConfig

        adapter = APIServerAdapter(PlatformConfig())
        runtime_kwargs = {"api_key": "test-key", "base_url": None,
                          "provider": None, "api_mode": None,
                          "command": None, "args": []}

        with patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime_kwargs), \
             patch("gateway.run._resolve_gateway_model", return_value="test/model"), \
             patch("gateway.run._load_gateway_config", return_value=gateway_config), \
             patch("run_agent.AIAgent", return_value=MagicMock()) as agent_cls:
            adapter._create_agent()

        agent_cls.assert_called_once()
        return agent_cls.call_args.kwargs

    @patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
    def test_create_agent_reads_config_toolsets(self):
        """API server resolves toolsets from config like all other platforms."""
        # No platform_toolsets override — should fall back to hermes-api-server default
        agent_kwargs = self._agent_kwargs_for({})

        toolsets = agent_kwargs.get("enabled_toolsets")
        assert isinstance(toolsets, list)
        assert len(toolsets) > 0
        assert agent_kwargs.get("platform") == "api_server"

    @patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
    def test_create_agent_respects_config_override(self):
        """User can override API server toolsets via platform_toolsets in config.yaml."""
        # User overrides with just web and terminal
        agent_kwargs = self._agent_kwargs_for(
            {"platform_toolsets": {"api_server": ["web", "terminal"]}}
        )

        toolsets = agent_kwargs.get("enabled_toolsets")
        assert sorted(toolsets) == ["terminal", "web"]
|
||||
@@ -7,11 +7,21 @@ Verifies that:
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _mock_dotenv(monkeypatch):
    """Register a stub ``dotenv`` module in ``sys.modules`` for every test.

    gateway.run imports dotenv at module level; the stub lets these tests run
    without the package installed. ``monkeypatch`` undoes the registration
    when the test finishes.
    """
    def _load_dotenv_noop(*args, **kwargs):
        # Accept any call signature and do nothing.
        return None

    stub = types.ModuleType("dotenv")
    stub.load_dotenv = _load_dotenv_noop
    monkeypatch.setitem(sys.modules, "dotenv", stub)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
@@ -57,105 +67,151 @@ class TestCronSessionBypass:
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
def _make_flush_context(monkeypatch, memory_dir=None):
    """Build a runner plus a mocked agent for memory-flush tests.

    A fake ``run_agent`` module is registered in ``sys.modules`` (via
    ``monkeypatch``, so it is removed after the test) whose ``AIAgent`` is a
    ``MagicMock`` returning ``tmp_agent``. The runner's session store is
    primed to return ``_TRANSCRIPT_4_MSGS`` from ``load_transcript``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.
        memory_dir: Optional value handed back unchanged as the third tuple
            element; this helper does not read or create it.

    Returns:
        ``(runner, tmp_agent, memory_dir)``.
    """
    tmp_agent = MagicMock()
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = MagicMock(return_value=tmp_agent)
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    runner = _make_runner()
    runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
    return runner, tmp_agent, memory_dir
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path, monkeypatch):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch, memory_dir)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
def test_flush_works_without_memory_files(self, tmp_path, monkeypatch):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
def test_empty_memory_files_no_injection(self, tmp_path, monkeypatch):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushAgentSilenced:
    """The flush agent must not produce any terminal output."""

    def test_print_fn_set_to_noop(self, tmp_path, monkeypatch):
        """_print_fn on the flush agent must be a no-op so tool output never leaks."""
        runner = _make_runner()
        runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS

        captured_agent = {}

        def _fake_ai_agent(*args, **kwargs):
            # Hand out a mock agent and remember it so the test can inspect it.
            agent = MagicMock()
            captured_agent["instance"] = agent
            return agent

        fake_run_agent = types.ModuleType("run_agent")
        fake_run_agent.AIAgent = _fake_ai_agent
        monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
            patch("gateway.run._resolve_gateway_model", return_value="test-model"),
            patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=tmp_path)}),
        ):
            runner._flush_memories_for_session("session_silent")

        agent = captured_agent["instance"]
        print_fn = agent._print_fn
        # A MagicMock auto-creates any attribute on first access, so an
        # "is not None" check would pass even if the gateway never assigned
        # _print_fn. An explicitly assigned plain callable is not a mock.
        # NOTE(review): assumes the gateway assigns _print_fn as an attribute
        # on the agent instance — confirm against gateway/run.py.
        assert not isinstance(print_fn, MagicMock), "_print_fn should be overridden to suppress output"
        assert callable(print_fn), "_print_fn should be overridden to suppress output"
        # Confirm it accepts a message without raising.
        print_fn("should be silenced")

    def test_kawaii_spinner_respects_print_fn(self):
        """KawaiiSpinner must route all output through print_fn when supplied."""
        import contextlib
        import io

        from agent.display import KawaiiSpinner

        written = []
        spinner = KawaiiSpinner("test", print_fn=lambda *a, **kw: written.append(a))
        spinner._write("hello")
        assert written == [("hello",)], "spinner should route through print_fn"

        # A no-op print_fn must produce no output to stdout. The spinner is
        # built inside the redirect so the stdout it captures at construction
        # is the buffer under inspection; redirect_stdout restores the real
        # stdout even if the spinner raises.
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            silent_spinner = KawaiiSpinner("silent", print_fn=lambda *a, **kw: None)
            silent_spinner._write("should not appear")
            silent_spinner.stop("done")
        assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout"
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
def test_core_instructions_present(self, monkeypatch):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
231
tests/gateway/test_send_retry.py
Normal file
231
tests/gateway/test_send_retry.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Tests for BasePlatformAdapter._send_with_retry and _is_retryable_error.
|
||||
|
||||
Verifies that:
|
||||
- Transient network errors trigger retry with backoff
|
||||
- Permanent errors fall back to plain-text immediately (no retry)
|
||||
- User receives a delivery-failure notice when all retries are exhausted
|
||||
- Successful sends on retry return success
|
||||
- SendResult.retryable flag is respected
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult, _RETRYABLE_ERROR_PATTERNS
|
||||
from gateway.platforms.base import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal concrete adapter for testing (no real network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
    """Minimal concrete adapter whose ``send`` replays scripted results.

    Tests preload ``_send_results`` with the ``SendResult`` objects that
    successive ``send`` calls should return; once the queue is empty every
    further call succeeds. Each call is recorded in ``_send_calls``.
    """

    def __init__(self):
        super().__init__(PlatformConfig(), Platform.TELEGRAM)
        # Scripted outcomes, consumed front-to-back by send().
        self._send_results = []
        # One (chat_id, content) tuple per send() invocation.
        self._send_calls = []

    def _next_result(self) -> SendResult:
        """Pop the next scripted result, defaulting to a success when none remain."""
        queued = self._send_results
        return queued.pop(0) if queued else SendResult(success=True, message_id="ok")

    async def send(self, chat_id, content, reply_to=None, metadata=None, **kwargs) -> SendResult:
        """Record the call and return the next scripted result."""
        self._send_calls.append((chat_id, content))
        return self._next_result()

    async def connect(self) -> bool:
        """Report a successful connection without doing anything."""
        return True

    async def disconnect(self) -> None:
        """No-op."""
        pass

    async def send_typing(self, chat_id, metadata=None) -> None:
        """No-op."""
        pass

    async def get_chat_info(self, chat_id):
        """Return a fixed direct-chat description echoing ``chat_id``."""
        return {"name": "test", "type": "direct", "chat_id": chat_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_retryable_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsRetryableError:
    """How ``_is_retryable_error`` classifies error values."""

    def test_none_is_not_retryable(self):
        """``None`` is not treated as retryable."""
        verdict = _StubAdapter._is_retryable_error(None)
        assert not verdict

    def test_empty_string_is_not_retryable(self):
        """An empty error string is not treated as retryable."""
        verdict = _StubAdapter._is_retryable_error("")
        assert not verdict

    @pytest.mark.parametrize("pattern", _RETRYABLE_ERROR_PATTERNS)
    def test_known_pattern_is_retryable(self, pattern):
        """Every listed pattern is recognised when embedded in an error message."""
        message = f"httpx.{pattern.title()}: connection dropped"
        assert _StubAdapter._is_retryable_error(message)

    def test_permission_error_not_retryable(self):
        """A "Forbidden" error is not retryable."""
        message = "Forbidden: bot was blocked by the user"
        assert not _StubAdapter._is_retryable_error(message)

    def test_bad_request_not_retryable(self):
        """A "Bad Request" formatting error is not retryable."""
        message = "Bad Request: can't parse entities"
        assert not _StubAdapter._is_retryable_error(message)

    def test_case_insensitive(self):
        """An upper-cased error name is still recognised."""
        message = "CONNECTERROR: host unreachable"
        assert _StubAdapter._is_retryable_error(message)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — success on first attempt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetrySuccess:
    """``_send_with_retry`` when the very first send succeeds."""

    @pytest.mark.asyncio
    async def test_success_first_attempt(self):
        """A successful first send is returned after exactly one call."""
        adapter = _StubAdapter()
        adapter._send_results = [SendResult(success=True, message_id="123")]

        outcome = await adapter._send_with_retry("chat1", "hello")

        assert outcome.success
        assert len(adapter._send_calls) == 1

    @pytest.mark.asyncio
    async def test_returns_message_id(self):
        """The message id from the underlying send is passed through."""
        adapter = _StubAdapter()
        adapter._send_results = [SendResult(success=True, message_id="abc")]

        outcome = await adapter._send_with_retry("chat1", "hi")

        assert outcome.message_id == "abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — network error with successful retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryNetworkRetry:
    """``_send_with_retry`` when early attempts fail with retryable errors."""

    @staticmethod
    async def _deliver(adapter, content, max_retries):
        """Call ``_send_with_retry`` for chat1 with ``asyncio.sleep`` mocked out."""
        with patch("asyncio.sleep", new_callable=AsyncMock):
            return await adapter._send_with_retry(
                "chat1", content, max_retries=max_retries, base_delay=0
            )

    @pytest.mark.asyncio
    async def test_retries_on_connect_error_and_succeeds(self):
        """One connect error followed by a success needs two sends."""
        adapter = _StubAdapter()
        adapter._send_results = [
            SendResult(success=False, error="httpx.ConnectError: connection refused"),
            SendResult(success=True, message_id="ok"),
        ]

        outcome = await self._deliver(adapter, "hello", max_retries=2)

        assert outcome.success
        assert len(adapter._send_calls) == 2  # initial + 1 retry

    @pytest.mark.asyncio
    async def test_retries_on_timeout_and_succeeds(self):
        """Two timeouts followed by a success need three sends."""
        adapter = _StubAdapter()
        adapter._send_results = [
            SendResult(success=False, error="ReadTimeout: request timed out"),
            SendResult(success=False, error="ReadTimeout: request timed out"),
            SendResult(success=True, message_id="ok"),
        ]

        outcome = await self._deliver(adapter, "hello", max_retries=3)

        assert outcome.success
        assert len(adapter._send_calls) == 3

    @pytest.mark.asyncio
    async def test_retryable_flag_respected(self):
        """SendResult.retryable=True should trigger retry even if error string doesn't match."""
        adapter = _StubAdapter()
        adapter._send_results = [
            SendResult(success=False, error="internal platform error", retryable=True),
            SendResult(success=True, message_id="ok"),
        ]

        outcome = await self._deliver(adapter, "hello", max_retries=2)

        assert outcome.success
        assert len(adapter._send_calls) == 2

    @pytest.mark.asyncio
    async def test_network_to_nonnetwork_transition_falls_back_to_plaintext(self):
        """If error switches from network to formatting mid-retry, fall through to plain-text fallback."""
        adapter = _StubAdapter()
        adapter._send_results = [
            SendResult(success=False, error="httpx.ConnectError: host unreachable"),
            SendResult(success=False, error="Bad Request: can't parse entities"),
            SendResult(success=True, message_id="fallback_ok"),  # plain-text fallback
        ]

        outcome = await self._deliver(adapter, "**bold**", max_retries=2)

        assert outcome.success
        # 3 calls: initial (network) + 1 retry (non-network, breaks loop) + plain-text fallback
        assert len(adapter._send_calls) == 3
        _, last_content = adapter._send_calls[-1]
        assert "plain text" in last_content.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — all retries exhausted → user notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryExhausted:
    """``_send_with_retry`` when every attempt fails with a network error."""

    @pytest.mark.asyncio
    async def test_sends_user_notice_after_exhaustion(self):
        """Once retries are spent, one further send carries a delivery-failure notice."""
        adapter = _StubAdapter()
        network_err = SendResult(success=False, error="httpx.ConnectError: host unreachable")
        # initial + 2 retries + notice attempt
        adapter._send_results = [network_err, network_err, network_err, SendResult(success=True)]
        with patch("asyncio.sleep", new_callable=AsyncMock):
            result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
        # Result is the last failed one (before notice)
        assert not result.success
        # 4 total calls: 1 initial + 2 retries + 1 notice
        assert len(adapter._send_calls) == 4
        # The notice content should mention delivery failure. Matching on the
        # lower-cased text already covers "Message delivery failed".
        notice_content = adapter._send_calls[-1][1]
        assert "delivery failed" in notice_content.lower()

    @pytest.mark.asyncio
    async def test_notice_send_exception_doesnt_propagate(self):
        """If the notice itself throws, _send_with_retry should not raise."""
        adapter = _StubAdapter()
        network_err = SendResult(success=False, error="ConnectError")
        call_count = [0]

        async def send_with_notice_failure(chat_id, content, **kwargs):
            # First three calls (initial + 2 retries) fail with a network
            # error; the fourth call, the notice, raises.
            call_count[0] += 1
            if call_count[0] > 3:
                raise RuntimeError("notice send also failed")
            return network_err

        # Replacing send() bypasses the stub's scripted result queue entirely.
        adapter.send = send_with_notice_failure
        with patch("asyncio.sleep", new_callable=AsyncMock):
            result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
        assert not result.success  # still failed, but no exception raised
        # The raising notice send must actually have been attempted, otherwise
        # this test would pass without exercising the swallowed exception.
        assert call_count[0] == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — non-network failure → plain-text fallback (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryFallback:
    """``_send_with_retry`` when the first failure is not a network error."""

    @pytest.mark.asyncio
    async def test_non_network_error_falls_back_immediately(self):
        """A formatting error skips the retry loop and resends as plain text."""
        adapter = _StubAdapter()
        formatting_failure = SendResult(success=False, error="Bad Request: can't parse entities")
        fallback_success = SendResult(success=True, message_id="fallback_ok")
        adapter._send_results = [formatting_failure, fallback_success]

        with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
            outcome = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)

        # No sleep — no retry loop for non-network errors
        sleep_mock.assert_not_called()
        assert outcome.success
        assert len(adapter._send_calls) == 2
        # Fallback content should be plain-text notice
        _, fallback_content = adapter._send_calls[1]
        assert "plain text" in fallback_content.lower()

    @pytest.mark.asyncio
    async def test_fallback_failure_logged_but_not_raised(self):
        """When the fallback also fails, a failed result is returned without raising."""
        adapter = _StubAdapter()
        adapter._send_results = [
            SendResult(success=False, error="Forbidden: bot blocked")
            for _ in range(2)
        ]

        with patch("asyncio.sleep", new_callable=AsyncMock):
            outcome = await adapter._send_with_retry("chat1", "hello", max_retries=2)

        assert not outcome.success
        assert len(adapter._send_calls) == 2  # original + fallback only
|
||||
@@ -859,3 +859,46 @@ class TestLastPromptTokens:
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
)
|
||||
|
||||
|
||||
class TestRewriteTranscriptPreservesReasoning:
    """rewrite_transcript must not drop reasoning fields from SQLite."""

    def test_reasoning_survives_rewrite(self, tmp_path):
        """All three reasoning fields read back unchanged after a rewrite."""
        from hermes_state import SessionDB

        def reasoning_fields():
            # Fresh objects on every call so the stored data and the expected
            # data never share list or dict instances.
            return {
                "reasoning": "I need to think step by step.",
                "reasoning_details": [{"type": "summary", "text": "step by step"}],
                "codex_reasoning_items": [{"id": "r1", "type": "reasoning"}],
            }

        def stored_fields(messages):
            # Project the first message onto the three reasoning keys.
            first = messages[0]
            return {key: first.get(key) for key in reasoning_fields()}

        session_id = "reasoning-test"
        db = SessionDB(db_path=tmp_path / "test.db")
        db.create_session(session_id=session_id, source="cli")

        # Insert a message WITH all three reasoning fields
        db.append_message(
            session_id=session_id,
            role="assistant",
            content="The answer is 42.",
            **reasoning_fields(),
        )

        # Verify all three were stored
        before = db.get_messages_as_conversation(session_id)
        assert stored_fields(before) == reasoning_fields()

        # Now simulate /retry: build the SessionStore and call rewrite_transcript
        config = GatewayConfig()
        with patch("gateway.session.SessionStore._ensure_loaded"):
            store = SessionStore(sessions_dir=tmp_path, config=config)
            store._db = db
            store._loaded = True

        # rewrite_transcript receives the messages that load_transcript returned
        store.rewrite_transcript(session_id, before)

        # Load again — all three reasoning fields must survive
        after = db.get_messages_as_conversation(session_id)
        assert stored_fields(after) == reasoning_fields()
|
||||
|
||||
@@ -94,7 +94,7 @@ class TestOfferOpenclawMigration:
|
||||
fake_mod.Migrator.assert_called_once()
|
||||
call_kwargs = fake_mod.Migrator.call_args[1]
|
||||
assert call_kwargs["execute"] is True
|
||||
assert call_kwargs["overwrite"] is False
|
||||
assert call_kwargs["overwrite"] is True
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
assert call_kwargs["preset_name"] == "full"
|
||||
fake_migrator.migrate.assert_called_once()
|
||||
@@ -285,3 +285,182 @@ class TestSetupWizardOpenclawIntegration:
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_section_config_summary / _skip_configured_section — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSectionConfigSummary:
    """Unit tests for ``setup_mod._get_section_config_summary``.

    Each test stubs ``setup_mod.get_env_value`` so the summary depends only on
    the config dict and the fake environment the test supplies.
    """

    @staticmethod
    def _summarize(config, section, env=None):
        """Return the summary for *section* with ``get_env_value`` patched.

        When *env* is ``None`` the stub answers ``""`` for every call.
        Otherwise lookups are served from the *env* mapping, falling back to
        ``""`` for keys it does not contain.
        """
        if env is None:
            stub = patch.object(setup_mod, "get_env_value", return_value="")
        else:
            stub = patch.object(
                setup_mod,
                "get_env_value",
                side_effect=lambda key: env.get(key, ""),
            )
        with stub:
            return setup_mod._get_section_config_summary(config, section)

    def test_model_returns_none_without_api_key(self):
        assert self._summarize({}, "model") is None

    def test_model_returns_summary_with_api_key(self):
        summary = self._summarize(
            {"model": "openai/gpt-4"},
            "model",
            env={"OPENROUTER_API_KEY": "sk-xxx"},
        )
        assert summary == "openai/gpt-4"

    def test_model_returns_dict_default_key(self):
        # The model entry may be a dict; its "default" value is what is reported.
        summary = self._summarize(
            {"model": {"default": "claude-opus-4", "provider": "anthropic"}},
            "model",
            env={"OPENAI_API_KEY": "sk-xxx"},
        )
        assert summary == "claude-opus-4"

    def test_terminal_always_returns(self):
        summary = self._summarize({"terminal": {"backend": "docker"}}, "terminal")
        assert summary == "backend: docker"

    def test_agent_always_returns(self):
        summary = self._summarize({"agent": {"max_turns": 120}}, "agent")
        assert summary == "max turns: 120"

    def test_gateway_returns_none_without_tokens(self):
        assert self._summarize({}, "gateway") is None

    def test_gateway_lists_platforms(self):
        tokens = {
            "TELEGRAM_BOT_TOKEN": "tok123",
            "DISCORD_BOT_TOKEN": "disc456",
        }
        summary = self._summarize({}, "gateway", env=tokens)
        assert "Telegram" in summary
        assert "Discord" in summary

    def test_tools_returns_none_without_keys(self):
        assert self._summarize({}, "tools") is None

    def test_tools_lists_configured(self):
        summary = self._summarize({}, "tools", env={"BROWSERBASE_API_KEY": "key"})
        assert "Browser" in summary
|
||||
|
||||
|
||||
class TestSkipConfiguredSection:
    """Unit tests for ``setup_mod._skip_configured_section``."""

    @staticmethod
    def _skip_configured_model(prompt_answer):
        """Run the helper for a configured model section.

        ``get_env_value`` reports an OpenRouter key and ``prompt_yes_no`` is
        stubbed to return *prompt_answer*.
        """
        env = {"OPENROUTER_API_KEY": "sk-xxx"}
        with (
            patch.object(
                setup_mod,
                "get_env_value",
                side_effect=lambda key: env.get(key, ""),
            ),
            patch.object(setup_mod, "prompt_yes_no", return_value=prompt_answer),
        ):
            return setup_mod._skip_configured_section(
                {"model": "openai/gpt-4"}, "model", "Model"
            )

    def test_returns_false_when_not_configured(self):
        env_stub = patch.object(setup_mod, "get_env_value", return_value="")
        with env_stub:
            skipped = setup_mod._skip_configured_section({}, "model", "Model")
        assert skipped is False

    def test_returns_true_when_user_skips(self):
        # prompt_yes_no answering False yields a skip.
        assert self._skip_configured_model(False) is True

    def test_returns_false_when_user_wants_reconfig(self):
        # prompt_yes_no answering True means the section is not skipped.
        assert self._skip_configured_model(True) is False
|
||||
|
||||
|
||||
class TestSetupWizardSkipsConfiguredSections:
    """After migration, already-configured sections should offer skip."""

    def test_sections_skipped_when_migration_imported_settings(self, tmp_path):
        """When migration ran and an API key exists, the model section is skippable.

        Mirrors the real flow: ``get_env_value`` yields "" while the wizard
        performs its is_existing check (before migration) and only starts
        returning a key once the fake migration has run.
        """
        args = _first_time_args()

        # Flipped by fake_migration; the API key only "exists" afterwards.
        migrated = False

        def env_side(key):
            return "sk-xxx" if migrated and key == "OPENROUTER_API_KEY" else ""

        def fake_migration(hermes_home):
            nonlocal migrated
            migrated = True
            return True

        config_after_migration = {"model": "openai/gpt-4"}

        with (
            patch.object(setup_mod, "ensure_hermes_home"),
            # First load sees an empty config, the reload sees the migrated one.
            patch.object(
                setup_mod,
                "load_config",
                side_effect=[{}, config_after_migration],
            ),
            patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
            patch.object(setup_mod, "get_env_value", side_effect=env_side),
            patch.object(setup_mod, "is_interactive_stdin", return_value=True),
            patch("hermes_cli.auth.get_active_provider", return_value=None),
            patch("builtins.input", return_value=""),
            # Migration succeeds and flips the env_side flag
            patch.object(
                setup_mod,
                "_offer_openclaw_migration",
                side_effect=fake_migration,
            ),
            # User says No to all reconfig prompts
            patch.object(setup_mod, "prompt_yes_no", return_value=False),
            patch.object(setup_mod, "setup_model_provider") as model_step,
            patch.object(setup_mod, "setup_terminal_backend") as terminal_step,
            patch.object(setup_mod, "setup_agent_settings") as agent_step,
            patch.object(setup_mod, "setup_gateway") as gateway_step,
            patch.object(setup_mod, "setup_tools") as tools_step,
            patch.object(setup_mod, "save_config"),
            patch.object(setup_mod, "_print_setup_summary"),
        ):
            setup_mod.run_setup_wizard(args)

        # Model has API key → skip offered, user said No → section NOT called
        model_step.assert_not_called()
        # Terminal/agent always have a summary → skip offered, user said No
        terminal_step.assert_not_called()
        agent_step.assert_not_called()
        # Gateway has no tokens (env_side returns "" for gateway keys) → section runs
        gateway_step.assert_called_once()
        # Tools have no keys → section runs
        tools_step.assert_called_once()
|
||||
|
||||
@@ -96,6 +96,59 @@ class TestVerboseAndToolProgress:
|
||||
assert cli.tool_progress_mode in ("off", "new", "all", "verbose")
|
||||
|
||||
|
||||
class TestBusyInputMode:
    """Tests for the ``busy_input_mode`` setting and the ``/queue`` command."""

    @staticmethod
    def _route_busy_enter(cli, text):
        """Simulate what handle_enter does for non-command input while busy.

        Queue mode puts *text* on ``_pending_input``; any other mode puts it
        on ``_interrupt_queue``.
        """
        if cli.busy_input_mode == "queue":
            destination = cli._pending_input
        else:
            destination = cli._interrupt_queue
        destination.put(text)

    def test_default_busy_input_mode_is_interrupt(self):
        assert _make_cli().busy_input_mode == "interrupt"

    def test_busy_input_mode_queue_is_honored(self):
        overrides = {"display": {"busy_input_mode": "queue"}}
        assert _make_cli(config_overrides=overrides).busy_input_mode == "queue"

    def test_unknown_busy_input_mode_falls_back_to_interrupt(self):
        overrides = {"display": {"busy_input_mode": "bogus"}}
        assert _make_cli(config_overrides=overrides).busy_input_mode == "interrupt"

    def test_queue_command_works_while_busy(self):
        """When agent is running, /queue should still put the prompt in _pending_input."""
        busy_cli = _make_cli()
        busy_cli._agent_running = True
        busy_cli.process_command("/queue follow up")
        queued = busy_cli._pending_input.get_nowait()
        assert queued == "follow up"

    def test_queue_command_works_while_idle(self):
        """When agent is idle, /queue should still queue (not reject)."""
        idle_cli = _make_cli()
        idle_cli._agent_running = False
        idle_cli.process_command("/queue follow up")
        queued = idle_cli._pending_input.get_nowait()
        assert queued == "follow up"

    def test_queue_mode_routes_busy_enter_to_pending(self):
        """In queue mode, Enter while busy should go to _pending_input, not _interrupt_queue."""
        queue_cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
        queue_cli._agent_running = True
        self._route_busy_enter(queue_cli, "follow up")
        assert queue_cli._pending_input.get_nowait() == "follow up"
        assert queue_cli._interrupt_queue.empty()

    def test_interrupt_mode_routes_busy_enter_to_interrupt(self):
        """In interrupt mode (default), Enter while busy goes to _interrupt_queue."""
        default_cli = _make_cli()
        default_cli._agent_running = True
        self._route_busy_enter(default_cli, "redirect")
        assert default_cli._interrupt_queue.get_nowait() == "redirect"
        assert default_cli._pending_input.empty()
|
||||
|
||||
|
||||
class TestSingleQueryState:
|
||||
def test_voice_and_interrupt_state_initialized_before_run(self):
|
||||
"""Single-query mode calls chat() without going through run()."""
|
||||
|
||||
@@ -182,3 +182,94 @@ class TestCLIUsageReport:
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for glm-5" in output
|
||||
|
||||
|
||||
class TestStatusBarWidthSource:
    """Ensure status bar fragments don't overflow the terminal width."""

    def _make_wide_cli(self):
        """Return a CLI with an attached agent and a visible status bar.

        The CLI comes from ``_make_cli()`` and is passed through
        ``_attach_agent`` with fixed, large token and usage figures.
        """
        cli_obj = _attach_agent(
            _make_cli(),
            prompt_tokens=100_000,
            completion_tokens=5_000,
            total_tokens=105_000,
            api_calls=20,
            context_tokens=100_000,
            context_length=200_000,
        )
        cli_obj._status_bar_visible = True
        return cli_obj

    def test_fragments_fit_within_announced_width(self):
        """Total fragment text must fit the announced width (plus 4 chars of padding)."""
        from unittest.mock import MagicMock, patch

        cli_obj = self._make_wide_cli()

        for width in (40, 52, 76, 80, 120, 200):
            # prompt_toolkit reports the terminal size via app.output.get_size().
            mock_app = MagicMock()
            mock_app.output.get_size.return_value = MagicMock(columns=width)

            with patch("prompt_toolkit.application.get_app", return_value=mock_app):
                frags = cli_obj._get_status_bar_fragments()

            total_text = "".join(text for _, text in frags)
            assert len(total_text) <= width + 4, (  # +4 for minor padding chars
                f"At width={width}, fragment total {len(total_text)} chars overflows "
                f"({total_text!r})"
            )

    def test_fragments_use_pt_width_over_shutil(self):
        """When prompt_toolkit reports a width, shutil.get_terminal_size must not be used."""
        from unittest.mock import MagicMock, patch

        cli_obj = self._make_wide_cli()

        mock_app = MagicMock()
        mock_app.output.get_size.return_value = MagicMock(columns=120)

        with patch("prompt_toolkit.application.get_app", return_value=mock_app), \
                patch("shutil.get_terminal_size") as mock_shutil:
            cli_obj._get_status_bar_fragments()

        mock_shutil.assert_not_called()

    def test_fragments_fall_back_to_shutil_when_no_app(self):
        """Outside a TUI context (no running app), shutil must be used as fallback."""
        from unittest.mock import MagicMock, patch

        cli_obj = self._make_wide_cli()

        # get_app raising stands in for "no running prompt_toolkit app".
        with patch("prompt_toolkit.application.get_app", side_effect=Exception("no app")), \
                patch("shutil.get_terminal_size", return_value=MagicMock(columns=100)) as mock_shutil:
            frags = cli_obj._get_status_bar_fragments()

        mock_shutil.assert_called()
        assert len(frags) > 0

    def test_build_status_bar_text_uses_pt_width(self):
        """_build_status_bar_text() must also prefer prompt_toolkit width."""
        from unittest.mock import MagicMock, patch

        cli_obj = self._make_wide_cli()

        mock_app = MagicMock()
        mock_app.output.get_size.return_value = MagicMock(columns=80)

        with patch("prompt_toolkit.application.get_app", return_value=mock_app), \
                patch("shutil.get_terminal_size") as mock_shutil:
            text = cli_obj._build_status_bar_text()  # no explicit width

        mock_shutil.assert_not_called()
        assert isinstance(text, str)
        assert len(text) > 0

    def test_explicit_width_skips_pt_lookup(self):
        """An explicit width= argument must bypass both PT and shutil lookups."""
        from unittest.mock import patch

        cli_obj = self._make_wide_cli()

        with patch("prompt_toolkit.application.get_app") as mock_get_app, \
                patch("shutil.get_terminal_size") as mock_shutil:
            text = cli_obj._build_status_bar_text(width=100)

        mock_get_app.assert_not_called()
        mock_shutil.assert_not_called()
        assert len(text) > 0
|
||||
|
||||
89
tests/test_compressor_fallback_update.py
Normal file
89
tests/test_compressor_fallback_update.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests that _try_activate_fallback updates the context compressor."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent_with_compressor() -> AIAgent:
    """Return a bare ``AIAgent`` (``__init__`` skipped) wired with a compressor.

    The agent and its ``ContextCompressor`` both start out configured for the
    primary model; a fallback model is configured but not yet activated.
    """
    # Settings shared by the agent and the compressor built for it.
    primary = {
        "model": "primary-model",
        "provider": "openrouter",
        "base_url": "https://openrouter.ai/api/v1",
        "api_key": "sk-primary",
    }

    agent = AIAgent.__new__(AIAgent)
    for attr, value in primary.items():
        setattr(agent, attr, value)
    agent.api_mode = "chat_completions"
    agent.client = MagicMock()
    agent.quiet_mode = True

    # Fallback is configured but has not been switched to yet.
    agent._fallback_activated = False
    agent._fallback_model = {"provider": "openai", "model": "gpt-4o"}

    agent.context_compressor = ContextCompressor(
        threshold_percent=0.50,
        quiet_mode=True,
        **primary,
    )
    return agent
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
def test_compressor_updated_on_fallback(mock_ctx_len, mock_resolve):
    """After fallback activation, the compressor must reflect the fallback model."""
    agent = _make_agent_with_compressor()
    assert agent.context_compressor.model == "primary-model"

    # The resolver hands back the fallback client (second tuple item is None).
    fallback_client = MagicMock(
        base_url="https://api.openai.com/v1",
        api_key="sk-fallback",
    )
    mock_resolve.return_value = (fallback_client, None)

    def is_direct_openai(url):
        return "api.openai.com" in url

    def ignore_status(msg):
        return None

    agent._is_direct_openai_url = is_direct_openai
    agent._emit_status = ignore_status

    activated = agent._try_activate_fallback()

    assert activated is True
    assert agent._fallback_activated is True

    compressor = agent.context_compressor
    expected = {
        "model": "gpt-4o",
        "base_url": "https://api.openai.com/v1",
        "api_key": "sk-fallback",
        "provider": "openai",
        "context_length": 128_000,
    }
    for field, value in expected.items():
        assert getattr(compressor, field) == value
    # The threshold is checked against the patched 128k context length.
    assert compressor.threshold_tokens == int(128_000 * compressor.threshold_percent)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
def test_compressor_not_present_does_not_crash(mock_ctx_len, mock_resolve):
    """If the agent has no compressor, fallback should still succeed."""
    agent = _make_agent_with_compressor()
    agent.context_compressor = None  # drop the compressor the factory attached

    fallback_client = MagicMock(
        base_url="https://api.openai.com/v1",
        api_key="sk-fallback",
    )
    mock_resolve.return_value = (fallback_client, None)

    def is_direct_openai(url):
        return "api.openai.com" in url

    def ignore_status(msg):
        return None

    agent._is_direct_openai_url = is_direct_openai
    agent._emit_status = ignore_status

    assert agent._try_activate_fallback() is True
|
||||
36
toolsets.py
36
toolsets.py
@@ -248,6 +248,42 @@ TOOLSETS = {
|
||||
],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-api-server": {
|
||||
"description": "OpenAI-compatible API server — full agent tools accessible via HTTP (no interactive UI tools like clarify or send_message)",
|
||||
"tools": [
|
||||
# Web
|
||||
"web_search", "web_extract",
|
||||
# Terminal + process management
|
||||
"terminal", "process",
|
||||
# File manipulation
|
||||
"read_file", "write_file", "patch", "search_files",
|
||||
# Vision + image generation
|
||||
"vision_analyze", "image_generate",
|
||||
# MoA
|
||||
"mixture_of_agents",
|
||||
# Skills
|
||||
"skills_list", "skill_view", "skill_manage",
|
||||
# Browser automation
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision", "browser_console",
|
||||
# Planning & memory
|
||||
"todo", "memory",
|
||||
# Session history search
|
||||
"session_search",
|
||||
# Code execution + delegation
|
||||
"execute_code", "delegate_task",
|
||||
# Cronjob management
|
||||
"cronjob",
|
||||
# Home Assistant smart home control (gated on HASS_TOKEN via check_fn)
|
||||
"ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service",
|
||||
# Honcho memory tools (gated on honcho being active via check_fn)
|
||||
"honcho_context", "honcho_profile", "honcho_search", "honcho_conclude",
|
||||
],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-cli": {
|
||||
"description": "Full interactive CLI toolset - all default tools plus cronjob management",
|
||||
|
||||
Reference in New Issue
Block a user