mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-21 09:31:31 +08:00
Compare commits
30 Commits
opencode-p
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f0ee9245e | ||
|
|
d50e5be500 | ||
|
|
cc54818d26 | ||
|
|
f374ae4c61 | ||
|
|
8fd9fafc84 | ||
|
|
26d6083624 | ||
|
|
470c3ea51a | ||
|
|
388241f798 | ||
|
|
67ae7a79df | ||
|
|
6b0022bb7b | ||
|
|
0109547fa2 | ||
|
|
c66c688727 | ||
|
|
988ecc7420 | ||
|
|
7165eff901 | ||
|
|
714e4941b8 | ||
|
|
23addf48d3 | ||
|
|
4d99305345 | ||
|
|
a933079564 | ||
|
|
0ed28ab80c | ||
|
|
28380e7aed | ||
|
|
970042deab | ||
|
|
9bb83d1298 | ||
|
|
69f85a4dce | ||
|
|
3659e1f0c2 | ||
|
|
21c2d32471 | ||
|
|
f66b3fe76b | ||
|
|
9aa82d4807 | ||
|
|
9b2fb1cc2e | ||
|
|
29c98e8f83 | ||
|
|
9e0fc62650 |
@@ -22,6 +22,9 @@ from acp.schema import (
|
||||
InitializeResponse,
|
||||
ListSessionsResponse,
|
||||
LoadSessionResponse,
|
||||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
@@ -93,6 +96,71 @@ class HermesACPAgent(acp.Agent):
|
||||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
mcp_servers: list[McpServerStdio | McpServerHttp | McpServerSse] | None,
|
||||
) -> None:
|
||||
"""Register ACP-provided MCP servers and refresh the agent tool surface."""
|
||||
if not mcp_servers:
|
||||
return
|
||||
|
||||
try:
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
config_map: dict[str, dict] = {}
|
||||
for server in mcp_servers:
|
||||
name = server.name
|
||||
if isinstance(server, McpServerStdio):
|
||||
config = {
|
||||
"command": server.command,
|
||||
"args": list(server.args),
|
||||
"env": {item.name: item.value for item in server.env},
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"url": server.url,
|
||||
"headers": {item.name: item.value for item in server.headers},
|
||||
}
|
||||
config_map[name] = config
|
||||
|
||||
await asyncio.to_thread(register_mcp_servers, config_map)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to register ACP MCP servers",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from model_tools import get_tool_definitions
|
||||
|
||||
enabled_toolsets = getattr(state.agent, "enabled_toolsets", None) or ["hermes-acp"]
|
||||
disabled_toolsets = getattr(state.agent, "disabled_toolsets", None)
|
||||
state.agent.tools = get_tool_definitions(
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
quiet_mode=True,
|
||||
)
|
||||
state.agent.valid_tool_names = {
|
||||
tool["function"]["name"] for tool in state.agent.tools or []
|
||||
}
|
||||
invalidate = getattr(state.agent, "_invalidate_system_prompt", None)
|
||||
if callable(invalidate):
|
||||
invalidate()
|
||||
logger.info(
|
||||
"Session %s: refreshed tool surface after ACP MCP registration (%d tools)",
|
||||
state.session_id,
|
||||
len(state.agent.tools or []),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to refresh tool surface after ACP MCP registration",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ---- ACP lifecycle ------------------------------------------------------
|
||||
|
||||
async def initialize(
|
||||
@@ -149,6 +217,7 @@ class HermesACPAgent(acp.Agent):
|
||||
**kwargs: Any,
|
||||
) -> NewSessionResponse:
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
|
||||
@@ -163,6 +232,7 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("load_session: session %s not found", session_id)
|
||||
return None
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
return LoadSessionResponse()
|
||||
|
||||
@@ -177,6 +247,7 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("resume_session: session %s not found, creating new", session_id)
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
|
||||
@@ -200,6 +271,8 @@ class HermesACPAgent(acp.Agent):
|
||||
) -> ForkSessionResponse:
|
||||
state = self.session_manager.fork_session(session_id, cwd=cwd)
|
||||
new_id = state.session_id if state else ""
|
||||
if state is not None:
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Forked session %s -> %s", session_id, new_id)
|
||||
return ForkSessionResponse(session_id=new_id)
|
||||
|
||||
|
||||
@@ -301,8 +301,6 @@ Update the summary using this exact structure. PRESERVE all existing information
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write the summary in the same language the user was using in the conversation.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
else:
|
||||
# First compaction: summarize from scratch
|
||||
@@ -341,8 +339,6 @@ Use this exact structure:
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write the summary in the same language the user was using in the conversation.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
try:
|
||||
|
||||
61
cli.py
61
cli.py
@@ -3052,10 +3052,54 @@ class HermesCLI:
|
||||
print(f" Config File: {config_path} {config_status}")
|
||||
print()
|
||||
|
||||
def _list_recent_sessions(self, limit: int = 10) -> list[dict[str, Any]]:
|
||||
"""Return recent CLI sessions for in-chat browsing/resume affordances."""
|
||||
if not self._session_db:
|
||||
return []
|
||||
try:
|
||||
sessions = self._session_db.list_sessions_rich(
|
||||
source="cli",
|
||||
exclude_sources=["tool"],
|
||||
limit=limit,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
return [s for s in sessions if s.get("id") != self.session_id]
|
||||
|
||||
def _show_recent_sessions(self, *, reason: str = "history", limit: int = 10) -> bool:
|
||||
"""Render recent sessions inline from the active chat TUI.
|
||||
|
||||
Returns True when something was shown, False if no session list was available.
|
||||
"""
|
||||
sessions = self._list_recent_sessions(limit=limit)
|
||||
if not sessions:
|
||||
return False
|
||||
|
||||
from hermes_cli.main import _relative_time
|
||||
|
||||
print()
|
||||
if reason == "history":
|
||||
print("(._.) No messages in the current chat yet — here are recent sessions you can resume:")
|
||||
else:
|
||||
print(" Recent sessions:")
|
||||
print()
|
||||
print(f" {'Title':<32} {'Preview':<40} {'Last Active':<13} {'ID'}")
|
||||
print(f" {'─' * 32} {'─' * 40} {'─' * 13} {'─' * 24}")
|
||||
for session in sessions:
|
||||
title = (session.get("title") or "—")[:30]
|
||||
preview = (session.get("preview") or "")[:38]
|
||||
last_active = _relative_time(session.get("last_active"))
|
||||
print(f" {title:<32} {preview:<40} {last_active:<13} {session['id']}")
|
||||
print()
|
||||
print(" Use /resume <session id or title> to continue where you left off.")
|
||||
print()
|
||||
return True
|
||||
|
||||
def show_history(self):
|
||||
"""Display conversation history."""
|
||||
if not self.conversation_history:
|
||||
print("(._.) No conversation history yet.")
|
||||
if not self._show_recent_sessions(reason="history"):
|
||||
print("(._.) No conversation history yet.")
|
||||
return
|
||||
|
||||
preview_limit = 400
|
||||
@@ -3180,6 +3224,8 @@ class HermesCLI:
|
||||
|
||||
if not target:
|
||||
_cprint(" Usage: /resume <session_id_or_title>")
|
||||
if self._show_recent_sessions(reason="resume"):
|
||||
return
|
||||
_cprint(" Tip: Use /history or `hermes sessions list` to find sessions.")
|
||||
return
|
||||
|
||||
@@ -4970,11 +5016,18 @@ class HermesCLI:
|
||||
return # mcp_servers unchanged (some other section was edited)
|
||||
|
||||
self._config_mcp_servers = new_mcp
|
||||
# Notify user and reload
|
||||
# Notify user and reload. Run in a separate thread with a hard
|
||||
# timeout so a hung MCP server cannot block the process_loop
|
||||
# indefinitely (which would freeze the entire TUI).
|
||||
print()
|
||||
print("🔄 MCP server config changed — reloading connections...")
|
||||
with self._busy_command(self._slow_command_status("/reload-mcp")):
|
||||
self._reload_mcp()
|
||||
_reload_thread = threading.Thread(
|
||||
target=self._reload_mcp, daemon=True
|
||||
)
|
||||
_reload_thread.start()
|
||||
_reload_thread.join(timeout=30)
|
||||
if _reload_thread.is_alive():
|
||||
print(" ⚠️ MCP reload timed out (30s). Some servers may not have reconnected.")
|
||||
|
||||
def _reload_mcp(self):
|
||||
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
|
||||
|
||||
@@ -9,6 +9,7 @@ runs at a time if multiple processes overlap.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -443,8 +444,30 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
session_db=_session_db,
|
||||
)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
# Run the agent with a timeout so a hung API call or tool doesn't
|
||||
# block the cron ticker thread indefinitely. Default 10 minutes;
|
||||
# override via env var. Uses a separate thread because
|
||||
# run_conversation is synchronous.
|
||||
_cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600))
|
||||
_cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
_cron_future = _cron_pool.submit(agent.run_conversation, prompt)
|
||||
try:
|
||||
result = _cron_future.result(timeout=_cron_timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.error(
|
||||
"Job '%s' timed out after %.0fs — interrupting agent",
|
||||
job_name, _cron_timeout,
|
||||
)
|
||||
if hasattr(agent, "interrupt"):
|
||||
agent.interrupt("Cron job timed out")
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
raise TimeoutError(
|
||||
f"Cron job '{job_name}' timed out after "
|
||||
f"{int(_cron_timeout // 60)} minutes"
|
||||
)
|
||||
finally:
|
||||
_cron_pool.shutdown(wait=False)
|
||||
|
||||
final_response = result.get("final_response", "") or ""
|
||||
# Use a separate variable for log display; keep final_response clean
|
||||
# for delivery logic (empty response = no delivery).
|
||||
|
||||
@@ -76,14 +76,13 @@ Open Zed settings (`Cmd+,` on macOS or `Ctrl+,` on Linux) and add to your
|
||||
|
||||
```json
|
||||
{
|
||||
"acp": {
|
||||
"agents": [
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"registry_dir": "/path/to/hermes-agent/acp_registry"
|
||||
}
|
||||
]
|
||||
}
|
||||
"agent_servers": {
|
||||
"hermes-agent": {
|
||||
"type": "custom",
|
||||
"command": "hermes",
|
||||
"args": ["acp"],
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -563,6 +563,18 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
|
||||
whatsapp_cfg = yaml_cfg.get("whatsapp", {})
|
||||
if isinstance(whatsapp_cfg, dict):
|
||||
if "require_mention" in whatsapp_cfg and not os.getenv("WHATSAPP_REQUIRE_MENTION"):
|
||||
os.environ["WHATSAPP_REQUIRE_MENTION"] = str(whatsapp_cfg["require_mention"]).lower()
|
||||
if "mention_patterns" in whatsapp_cfg and not os.getenv("WHATSAPP_MENTION_PATTERNS"):
|
||||
os.environ["WHATSAPP_MENTION_PATTERNS"] = json.dumps(whatsapp_cfg["mention_patterns"])
|
||||
frc = whatsapp_cfg.get("free_response_chats")
|
||||
if frc is not None and not os.getenv("WHATSAPP_FREE_RESPONSE_CHATS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["WHATSAPP_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
|
||||
@@ -1046,6 +1046,13 @@ class BasePlatformAdapter(ABC):
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Mark session as active BEFORE spawning background task to close
|
||||
# the race window where a second message arriving before the task
|
||||
# starts would also pass the _active_sessions check and spawn a
|
||||
# duplicate task. (grammY sequentialize / aiogram EventIsolation
|
||||
# pattern — set the guard synchronously, not inside the task.)
|
||||
self._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Spawn background task to process this message
|
||||
task = asyncio.create_task(self._process_message_background(event, session_key))
|
||||
try:
|
||||
@@ -1092,8 +1099,10 @@ class BasePlatformAdapter(ABC):
|
||||
if getattr(result, "success", False):
|
||||
delivery_succeeded = True
|
||||
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
# Reuse the interrupt event set by handle_message() (which marks
|
||||
# the session active before spawning this task to prevent races).
|
||||
# Fall back to a new Event only if the entry was removed externally.
|
||||
interrupt_event = self._active_sessions.get(session_key) or asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
@@ -1106,9 +1115,12 @@ class BasePlatformAdapter(ABC):
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
# Send response if any
|
||||
# Send response if any. A None/empty response is normal when
|
||||
# streaming already delivered the text (already_sent=True) or
|
||||
# when the message was queued behind an active agent. Log at
|
||||
# DEBUG to avoid noisy warnings for expected behavior.
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
logger.debug("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
@@ -900,7 +900,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception:
|
||||
pass # best-effort truncation
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
# Flood control / RetryAfter — back off and retry once
|
||||
# Flood control / RetryAfter — short waits are retried inline,
|
||||
# long waits return a failure immediately so streaming can fall back
|
||||
# to a normal final send instead of leaving a truncated partial.
|
||||
retry_after = getattr(e, "retry_after", None)
|
||||
if retry_after is not None or "retry after" in err_str:
|
||||
wait = retry_after if retry_after else 1.0
|
||||
@@ -908,6 +910,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"[%s] Telegram flood control, waiting %.1fs",
|
||||
self.name, wait,
|
||||
)
|
||||
if wait > 5.0:
|
||||
return SendResult(success=False, error=f"flood_control:{wait}")
|
||||
await asyncio.sleep(wait)
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
|
||||
@@ -16,9 +16,11 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
@@ -138,12 +140,137 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
get_hermes_dir("platforms/whatsapp/session", "whatsapp/session")
|
||||
))
|
||||
self._reply_prefix: Optional[str] = config.extra.get("reply_prefix")
|
||||
self._mention_patterns = self._compile_mention_patterns()
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() in ("true", "1", "yes", "on")
|
||||
return bool(configured)
|
||||
return os.getenv("WHATSAPP_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _whatsapp_free_response_chats(self) -> set[str]:
|
||||
raw = self.config.extra.get("free_response_chats")
|
||||
if raw is None:
|
||||
raw = os.getenv("WHATSAPP_FREE_RESPONSE_CHATS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
return {part.strip() for part in str(raw).split(",") if part.strip()}
|
||||
|
||||
def _compile_mention_patterns(self):
|
||||
patterns = self.config.extra.get("mention_patterns")
|
||||
if patterns is None:
|
||||
raw = os.getenv("WHATSAPP_MENTION_PATTERNS", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
patterns = json.loads(raw)
|
||||
except Exception:
|
||||
patterns = [part.strip() for part in raw.splitlines() if part.strip()]
|
||||
if not patterns:
|
||||
patterns = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
if patterns is None:
|
||||
return []
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if not isinstance(patterns, list):
|
||||
logger.warning("[%s] whatsapp mention_patterns must be a list or string; got %s", self.name, type(patterns).__name__)
|
||||
return []
|
||||
|
||||
compiled = []
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, str) or not pattern.strip():
|
||||
continue
|
||||
try:
|
||||
compiled.append(re.compile(pattern, re.IGNORECASE))
|
||||
except re.error as exc:
|
||||
logger.warning("[%s] Invalid WhatsApp mention pattern %r: %s", self.name, pattern, exc)
|
||||
if compiled:
|
||||
logger.info("[%s] Loaded %d WhatsApp mention pattern(s)", self.name, len(compiled))
|
||||
return compiled
|
||||
|
||||
@staticmethod
|
||||
def _normalize_whatsapp_id(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
normalized = str(value).strip()
|
||||
if ":" in normalized and "@" in normalized:
|
||||
normalized = normalized.replace(":", "@", 1)
|
||||
return normalized
|
||||
|
||||
def _bot_ids_from_message(self, data: Dict[str, Any]) -> set[str]:
|
||||
bot_ids = set()
|
||||
for candidate in data.get("botIds") or []:
|
||||
normalized = self._normalize_whatsapp_id(candidate)
|
||||
if normalized:
|
||||
bot_ids.add(normalized)
|
||||
return bot_ids
|
||||
|
||||
def _message_is_reply_to_bot(self, data: Dict[str, Any]) -> bool:
|
||||
quoted_participant = self._normalize_whatsapp_id(data.get("quotedParticipant"))
|
||||
if not quoted_participant:
|
||||
return False
|
||||
return quoted_participant in self._bot_ids_from_message(data)
|
||||
|
||||
def _message_mentions_bot(self, data: Dict[str, Any]) -> bool:
|
||||
bot_ids = self._bot_ids_from_message(data)
|
||||
if not bot_ids:
|
||||
return False
|
||||
mentioned_ids = {
|
||||
nid
|
||||
for candidate in (data.get("mentionedIds") or [])
|
||||
if (nid := self._normalize_whatsapp_id(candidate))
|
||||
}
|
||||
if mentioned_ids & bot_ids:
|
||||
return True
|
||||
|
||||
body = str(data.get("body") or "")
|
||||
lower_body = body.lower()
|
||||
for bot_id in bot_ids:
|
||||
bare_id = bot_id.split("@", 1)[0].lower()
|
||||
if bare_id and (f"@{bare_id}" in lower_body or bare_id in lower_body):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _message_matches_mention_patterns(self, data: Dict[str, Any]) -> bool:
|
||||
if not self._mention_patterns:
|
||||
return False
|
||||
body = str(data.get("body") or "")
|
||||
return any(pattern.search(body) for pattern in self._mention_patterns)
|
||||
|
||||
def _clean_bot_mention_text(self, text: str, data: Dict[str, Any]) -> str:
|
||||
if not text:
|
||||
return text
|
||||
bot_ids = self._bot_ids_from_message(data)
|
||||
cleaned = text
|
||||
for bot_id in bot_ids:
|
||||
bare_id = bot_id.split("@", 1)[0]
|
||||
if bare_id:
|
||||
cleaned = re.sub(rf"@{re.escape(bare_id)}\b[,:\-]*\s*", "", cleaned)
|
||||
return cleaned.strip() or text
|
||||
|
||||
def _should_process_message(self, data: Dict[str, Any]) -> bool:
|
||||
if not data.get("isGroup"):
|
||||
return True
|
||||
chat_id = str(data.get("chatId") or "")
|
||||
if chat_id in self._whatsapp_free_response_chats():
|
||||
return True
|
||||
if not self._whatsapp_require_mention():
|
||||
return True
|
||||
body = str(data.get("body") or "").strip()
|
||||
if body.startswith("/"):
|
||||
return True
|
||||
if self._message_is_reply_to_bot(data):
|
||||
return True
|
||||
if self._message_mentions_bot(data):
|
||||
return True
|
||||
return self._message_matches_mention_patterns(data)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
@@ -687,6 +814,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
if not self._should_process_message(data):
|
||||
return None
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if data.get("hasMedia"):
|
||||
@@ -768,6 +898,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# the message text so the agent can read it inline.
|
||||
# Cap at 100KB to match Telegram/Discord/Slack behaviour.
|
||||
body = data.get("body", "")
|
||||
if data.get("isGroup"):
|
||||
body = self._clean_bot_mention_text(body, data)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if msg_type == MessageType.DOCUMENT and cached_urls:
|
||||
for doc_path in cached_urls:
|
||||
|
||||
166
gateway/run.py
166
gateway/run.py
@@ -303,6 +303,43 @@ def _resolve_runtime_agent_kwargs() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _build_media_placeholder(event) -> str:
|
||||
"""Build a text placeholder for media-only events so they aren't dropped.
|
||||
|
||||
When a photo/document is queued during active processing and later
|
||||
dequeued, only .text is extracted. If the event has no caption,
|
||||
the media would be silently lost. This builds a placeholder that
|
||||
the vision enrichment pipeline will replace with a real description.
|
||||
"""
|
||||
parts = []
|
||||
media_urls = getattr(event, "media_urls", None) or []
|
||||
media_types = getattr(event, "media_types", None) or []
|
||||
for i, url in enumerate(media_urls):
|
||||
mtype = media_types[i] if i < len(media_types) else ""
|
||||
if mtype.startswith("image/") or getattr(event, "message_type", None) == MessageType.PHOTO:
|
||||
parts.append(f"[User sent an image: {url}]")
|
||||
elif mtype.startswith("audio/"):
|
||||
parts.append(f"[User sent audio: {url}]")
|
||||
else:
|
||||
parts.append(f"[User sent a file: {url}]")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
|
||||
"""Consume and return the text of a pending queued message.
|
||||
|
||||
Preserves media context for captionless photo/document events by
|
||||
building a placeholder so the message isn't silently dropped.
|
||||
"""
|
||||
event = adapter.get_pending_message(session_key)
|
||||
if not event:
|
||||
return None
|
||||
text = event.text
|
||||
if not text and getattr(event, "media_urls", None):
|
||||
text = _build_media_placeholder(event)
|
||||
return text
|
||||
|
||||
|
||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||
"""Check if a command matches a known-but-inactive skill.
|
||||
|
||||
@@ -411,10 +448,14 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
class GatewayRunner:
|
||||
"""
|
||||
Main gateway controller.
|
||||
|
||||
|
||||
Manages the lifecycle of all platform adapters and routes
|
||||
messages to/from the agent.
|
||||
"""
|
||||
|
||||
# Class-level defaults so partial construction in tests doesn't
|
||||
# blow up on attribute access.
|
||||
_running_agents_ts: Dict[str, float] = {}
|
||||
|
||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||
self.config = config or load_gateway_config()
|
||||
@@ -446,6 +487,7 @@ class GatewayRunner:
|
||||
# Track running agents per session for interrupt support
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
@@ -1698,6 +1740,20 @@ class GatewayRunner:
|
||||
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
|
||||
# let the adapter-level batching/queueing logic absorb them.
|
||||
_quick_key = self._session_key_for_source(source)
|
||||
|
||||
# Staleness eviction: if an entry has been in _running_agents for
|
||||
# longer than the agent timeout, it's a leaked lock from a hung or
|
||||
# crashed handler. Evict it so the session isn't permanently stuck.
|
||||
_STALE_TTL = float(os.getenv("HERMES_AGENT_TIMEOUT", 600)) + 60 # timeout + 1 min grace
|
||||
_stale_ts = self._running_agents_ts.get(_quick_key, 0)
|
||||
if _quick_key in self._running_agents and _stale_ts and (time.time() - _stale_ts) > _STALE_TTL:
|
||||
logger.warning(
|
||||
"Evicting stale _running_agents entry for %s (age: %.0fs)",
|
||||
_quick_key[:30], time.time() - _stale_ts,
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
@@ -2023,6 +2079,7 @@ class GatewayRunner:
|
||||
# "already running" guard and spin up a duplicate agent for the
|
||||
# same session — corrupting the transcript.
|
||||
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
||||
self._running_agents_ts[_quick_key] = time.time()
|
||||
|
||||
try:
|
||||
return await self._handle_message_with_agent(event, source, _quick_key)
|
||||
@@ -2033,6 +2090,7 @@ class GatewayRunner:
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
@@ -2303,7 +2361,18 @@ class GatewayRunner:
|
||||
# 85% * 1.4 = 119% of context — which exceeds the model's limit
|
||||
# and prevented hygiene from ever firing for ~200K models (GLM-5).
|
||||
|
||||
_needs_compress = _approx_tokens >= _compress_token_threshold
|
||||
# Hard safety valve: force compression if message count is
|
||||
# extreme, regardless of token estimates. This breaks the
|
||||
# death spiral where API disconnects prevent token data
|
||||
# collection, which prevents compression, which causes more
|
||||
# disconnects. 400 messages is well above normal sessions
|
||||
# but catches runaway growth before it becomes unrecoverable.
|
||||
# (#2153)
|
||||
_HARD_MSG_LIMIT = 400
|
||||
_needs_compress = (
|
||||
_approx_tokens >= _compress_token_threshold
|
||||
or _msg_count >= _HARD_MSG_LIMIT
|
||||
)
|
||||
|
||||
if _needs_compress:
|
||||
logger.info(
|
||||
@@ -5384,11 +5453,13 @@ class GatewayRunner:
|
||||
progress_lines = [] # Accumulated tool lines
|
||||
progress_msg_id = None # ID of the progress message to edit
|
||||
can_edit = True # False once an edit fails (platform doesn't support it)
|
||||
_last_edit_ts = 0.0 # Throttle edits to avoid Telegram flood control
|
||||
_PROGRESS_EDIT_INTERVAL = 1.5 # Minimum seconds between edits
|
||||
|
||||
while True:
|
||||
try:
|
||||
raw = progress_queue.get_nowait()
|
||||
|
||||
|
||||
# Handle dedup messages: update last line with repeat counter
|
||||
if isinstance(raw, tuple) and len(raw) == 3 and raw[0] == "__dedup__":
|
||||
_, base_msg, count = raw
|
||||
@@ -5399,6 +5470,19 @@ class GatewayRunner:
|
||||
msg = raw
|
||||
progress_lines.append(msg)
|
||||
|
||||
# Throttle edits: batch rapid tool updates into fewer
|
||||
# API calls to avoid hitting Telegram flood control.
|
||||
# (grammY auto-retry pattern: proactively rate-limit
|
||||
# instead of reacting to 429s.)
|
||||
_now = time.monotonic()
|
||||
_remaining = _PROGRESS_EDIT_INTERVAL - (_now - _last_edit_ts)
|
||||
if _remaining > 0:
|
||||
# Wait out the throttle interval, then loop back to
|
||||
# drain any additional queued messages before sending
|
||||
# a single batched edit.
|
||||
await asyncio.sleep(_remaining)
|
||||
continue
|
||||
|
||||
if can_edit and progress_msg_id is not None:
|
||||
# Try to edit the existing progress message
|
||||
full_text = "\n".join(progress_lines)
|
||||
@@ -5408,8 +5492,15 @@ class GatewayRunner:
|
||||
content=full_text,
|
||||
)
|
||||
if not result.success:
|
||||
# Platform doesn't support editing — stop trying,
|
||||
# send just this new line as a separate message
|
||||
_err = (getattr(result, "error", "") or "").lower()
|
||||
if "flood" in _err or "retry after" in _err:
|
||||
# Flood control hit — disable further edits,
|
||||
# switch to sending new messages only for
|
||||
# important updates. Don't block 23s.
|
||||
logger.info(
|
||||
"[%s] Progress edits disabled due to flood control",
|
||||
adapter.name,
|
||||
)
|
||||
can_edit = False
|
||||
await adapter.send(chat_id=source.chat_id, content=msg, metadata=_progress_metadata)
|
||||
else:
|
||||
@@ -5423,6 +5514,8 @@ class GatewayRunner:
|
||||
if result.success and result.message_id:
|
||||
progress_msg_id = result.message_id
|
||||
|
||||
_last_edit_ts = time.monotonic()
|
||||
|
||||
# Restore typing indicator
|
||||
await asyncio.sleep(0.3)
|
||||
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
||||
@@ -5468,15 +5561,25 @@ class GatewayRunner:
|
||||
_loop_for_step = asyncio.get_event_loop()
|
||||
_hooks_ref = self.hooks
|
||||
|
||||
def _step_callback_sync(iteration: int, tool_names: list) -> None:
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
try:
|
||||
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
||||
# keys. Normalise to keep "tool_names" backward-compatible for
|
||||
# user-authored hooks that do ', '.join(tool_names)'.
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_hooks_ref.emit("agent:step", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
"user_id": source.user_id,
|
||||
"session_id": session_id,
|
||||
"iteration": iteration,
|
||||
"tool_names": tool_names,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
}),
|
||||
_loop_for_step,
|
||||
)
|
||||
@@ -5933,9 +6036,38 @@ class GatewayRunner:
|
||||
interrupt_monitor = asyncio.create_task(monitor_for_interrupt())
|
||||
|
||||
try:
|
||||
# Run in thread pool to not block
|
||||
# Run in thread pool to not block. Cap total execution time
|
||||
# so a hung API call or runaway tool doesn't permanently lock
|
||||
# the session. Default 10 minutes; override with env var.
|
||||
_agent_timeout = float(os.getenv("HERMES_AGENT_TIMEOUT", 600))
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, run_sync)
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, run_sync),
|
||||
timeout=_agent_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Agent execution timed out after %.0fs for session %s",
|
||||
_agent_timeout, session_key,
|
||||
)
|
||||
# Interrupt the agent if it's still running so the thread
|
||||
# pool worker is freed.
|
||||
_timed_out_agent = agent_holder[0]
|
||||
if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"):
|
||||
_timed_out_agent.interrupt("Execution timed out")
|
||||
response = {
|
||||
"final_response": (
|
||||
f"⏱️ Request timed out after {int(_agent_timeout // 60)} minutes. "
|
||||
"The agent may have been stuck on a tool or API call.\n"
|
||||
"Try again, or use /reset to start fresh."
|
||||
),
|
||||
"messages": result_holder[0].get("messages", []) if result_holder[0] else [],
|
||||
"api_calls": 0,
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": 0,
|
||||
"failed": True,
|
||||
}
|
||||
|
||||
# Track fallback model state: if the agent switched to a
|
||||
# fallback model during this run, persist it so /model shows
|
||||
@@ -5963,18 +6095,12 @@ class GatewayRunner:
|
||||
pending = None
|
||||
if result and adapter and session_key:
|
||||
if result.get("interrupted"):
|
||||
# Interrupted — consume the interrupt message
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if not pending and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
else:
|
||||
# Normal completion — check for /queue'd messages that were
|
||||
# stored without triggering an interrupt.
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if pending:
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
if pending:
|
||||
@@ -6050,6 +6176,8 @@ class GatewayRunner:
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||
|
||||
@@ -174,12 +174,12 @@ class GatewayStreamConsumer:
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
else:
|
||||
# Edit not supported by this adapter — stop streaming,
|
||||
# let the normal send path handle the final response.
|
||||
# Without this guard, adapters like Signal/Email would
|
||||
# flood the chat with a new message every edit_interval.
|
||||
# If an edit fails mid-stream (especially Telegram flood control),
|
||||
# stop progressive edits and let the normal final send path deliver
|
||||
# the complete answer instead of leaving the user with a partial.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._edit_supported = False
|
||||
self._already_sent = False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
|
||||
@@ -258,8 +258,11 @@ def _system_service_identity(run_as_user: str | None = None) -> tuple[str, str,
|
||||
username = (run_as_user or os.getenv("SUDO_USER") or os.getenv("USER") or os.getenv("LOGNAME") or getpass.getuser()).strip()
|
||||
if not username:
|
||||
raise ValueError("Could not determine which user the gateway service should run as")
|
||||
if username == "root" and not run_as_user:
|
||||
raise ValueError("Refusing to install the gateway system service as root; pass --run-as-user root to override (e.g. in LXC containers)")
|
||||
if username == "root":
|
||||
raise ValueError("Refusing to install the gateway system service as root; pass --run-as USER")
|
||||
print_warning("Installing gateway service to run as root.")
|
||||
print_info(" This is fine for LXC/container environments but not recommended on bare-metal hosts.")
|
||||
|
||||
try:
|
||||
user_info = pwd.getpwnam(username)
|
||||
@@ -321,9 +324,9 @@ def install_linux_gateway_from_setup(force: bool = False) -> tuple[str | None, b
|
||||
while True:
|
||||
run_as_user = prompt(" Run the system gateway service as which user?", default="")
|
||||
run_as_user = (run_as_user or "").strip()
|
||||
if run_as_user and run_as_user != "root":
|
||||
if run_as_user:
|
||||
break
|
||||
print_error(" Enter a non-root username.")
|
||||
print_error(" Enter a username.")
|
||||
|
||||
systemd_install(force=force, system=True, run_as_user=run_as_user)
|
||||
return scope, True
|
||||
|
||||
@@ -2682,6 +2682,20 @@ def _stash_local_changes_if_needed(git_cmd: list[str], cwd: Path) -> Optional[st
|
||||
if not status.stdout.strip():
|
||||
return None
|
||||
|
||||
# If the index has unmerged entries (e.g. from an interrupted merge/rebase),
|
||||
# git stash will fail with "needs merge / could not write index". Clear the
|
||||
# conflict state with `git reset` so the stash can proceed. Working-tree
|
||||
# changes are preserved; only the index conflict markers are dropped.
|
||||
unmerged = subprocess.run(
|
||||
git_cmd + ["ls-files", "--unmerged"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if unmerged.stdout.strip():
|
||||
print("→ Clearing unmerged index entries from a previous conflict...")
|
||||
subprocess.run(git_cmd + ["reset"], cwd=cwd, capture_output=True)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stash_name = datetime.now(timezone.utc).strftime("hermes-update-autostash-%Y%m%d-%H%M%S")
|
||||
@@ -2835,6 +2849,231 @@ def _restore_stashed_changes(
|
||||
print(" Review `git diff` / `git status` if Hermes behaves unexpectedly.")
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Fork detection and upstream management for `hermes update`
|
||||
# =========================================================================
|
||||
|
||||
OFFICIAL_REPO_URLS = {
|
||||
"https://github.com/NousResearch/hermes-agent.git",
|
||||
"git@github.com:NousResearch/hermes-agent.git",
|
||||
"https://github.com/NousResearch/hermes-agent",
|
||||
"git@github.com:NousResearch/hermes-agent",
|
||||
}
|
||||
OFFICIAL_REPO_URL = "https://github.com/NousResearch/hermes-agent.git"
|
||||
SKIP_UPSTREAM_PROMPT_FILE = ".skip_upstream_prompt"
|
||||
|
||||
|
||||
def _get_origin_url(git_cmd: list[str], cwd: Path) -> Optional[str]:
|
||||
"""Get the URL of the origin remote, or None if not set."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
git_cmd + ["remote", "get-url", "origin"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _is_fork(origin_url: Optional[str]) -> bool:
|
||||
"""Check if the origin remote points to a fork (not the official repo)."""
|
||||
if not origin_url:
|
||||
return False
|
||||
# Normalize URL for comparison (strip trailing .git if present)
|
||||
normalized = origin_url.rstrip("/")
|
||||
if normalized.endswith(".git"):
|
||||
normalized = normalized[:-4]
|
||||
for official in OFFICIAL_REPO_URLS:
|
||||
official_normalized = official.rstrip("/")
|
||||
if official_normalized.endswith(".git"):
|
||||
official_normalized = official_normalized[:-4]
|
||||
if normalized == official_normalized:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _has_upstream_remote(git_cmd: list[str], cwd: Path) -> bool:
|
||||
"""Check if an 'upstream' remote already exists."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
git_cmd + ["remote", "get-url", "upstream"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _add_upstream_remote(git_cmd: list[str], cwd: Path) -> bool:
|
||||
"""Add the official repo as the 'upstream' remote. Returns True on success."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
git_cmd + ["remote", "add", "upstream", OFFICIAL_REPO_URL],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _count_commits_between(git_cmd: list[str], cwd: Path, base: str, head: str) -> int:
|
||||
"""Count commits on `head` that are not on `base`. Returns -1 on error."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
git_cmd + ["rev-list", "--count", f"{base}..{head}"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return int(result.stdout.strip())
|
||||
except Exception:
|
||||
pass
|
||||
return -1
|
||||
|
||||
|
||||
def _should_skip_upstream_prompt() -> bool:
|
||||
"""Check if user previously declined to add upstream."""
|
||||
from hermes_constants import get_hermes_home
|
||||
return (get_hermes_home() / SKIP_UPSTREAM_PROMPT_FILE).exists()
|
||||
|
||||
|
||||
def _mark_skip_upstream_prompt():
|
||||
"""Create marker file to skip future upstream prompts."""
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
(get_hermes_home() / SKIP_UPSTREAM_PROMPT_FILE).touch()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _sync_fork_with_upstream(git_cmd: list[str], cwd: Path) -> bool:
|
||||
"""Attempt to push updated main to origin (sync fork).
|
||||
|
||||
Returns True if push succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
git_cmd + ["push", "origin", "main", "--force-with-lease"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _sync_with_upstream_if_needed(git_cmd: list[str], cwd: Path) -> None:
|
||||
"""Check if fork is behind upstream and sync if safe.
|
||||
|
||||
This implements the fork upstream sync logic:
|
||||
- If upstream remote doesn't exist, ask user if they want to add it
|
||||
- Compare origin/main with upstream/main
|
||||
- If origin/main is strictly behind upstream/main, pull from upstream
|
||||
- Try to sync fork back to origin if possible
|
||||
"""
|
||||
has_upstream = _has_upstream_remote(git_cmd, cwd)
|
||||
|
||||
if not has_upstream:
|
||||
# Check if user previously declined
|
||||
if _should_skip_upstream_prompt():
|
||||
return
|
||||
|
||||
# Ask user if they want to add upstream
|
||||
print()
|
||||
print("ℹ Your fork is not tracking the official Hermes repository.")
|
||||
print(" This means you may miss updates from NousResearch/hermes-agent.")
|
||||
print()
|
||||
try:
|
||||
response = input("Add official repo as 'upstream' remote? [Y/n]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print()
|
||||
response = "n"
|
||||
|
||||
if response in ("", "y", "yes"):
|
||||
print("→ Adding upstream remote...")
|
||||
if _add_upstream_remote(git_cmd, cwd):
|
||||
print(" ✓ Added upstream: https://github.com/NousResearch/hermes-agent.git")
|
||||
has_upstream = True
|
||||
else:
|
||||
print(" ✗ Failed to add upstream remote. Skipping upstream sync.")
|
||||
return
|
||||
else:
|
||||
print(" Skipped. Run 'git remote add upstream https://github.com/NousResearch/hermes-agent.git' to add later.")
|
||||
_mark_skip_upstream_prompt()
|
||||
return
|
||||
|
||||
# Fetch upstream
|
||||
print()
|
||||
print("→ Fetching upstream...")
|
||||
try:
|
||||
subprocess.run(
|
||||
git_cmd + ["fetch", "upstream", "--quiet"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ✗ Failed to fetch upstream. Skipping upstream sync.")
|
||||
return
|
||||
|
||||
# Compare origin/main with upstream/main
|
||||
origin_ahead = _count_commits_between(git_cmd, cwd, "upstream/main", "origin/main")
|
||||
upstream_ahead = _count_commits_between(git_cmd, cwd, "origin/main", "upstream/main")
|
||||
|
||||
if origin_ahead < 0 or upstream_ahead < 0:
|
||||
print(" ✗ Could not compare branches. Skipping upstream sync.")
|
||||
return
|
||||
|
||||
# If origin/main has commits not on upstream, don't trample
|
||||
if origin_ahead > 0:
|
||||
print()
|
||||
print(f"ℹ Your fork has {origin_ahead} commit(s) not on upstream.")
|
||||
print(" Skipping upstream sync to preserve your changes.")
|
||||
print(" If you want to merge upstream changes, run:")
|
||||
print(" git pull upstream main")
|
||||
return
|
||||
|
||||
# If upstream is not ahead, fork is up to date
|
||||
if upstream_ahead == 0:
|
||||
print(" ✓ Fork is up to date with upstream")
|
||||
return
|
||||
|
||||
# origin/main is strictly behind upstream/main (can fast-forward)
|
||||
print()
|
||||
print(f"→ Fork is {upstream_ahead} commit(s) behind upstream")
|
||||
print("→ Pulling from upstream...")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
git_cmd + ["pull", "--ff-only", "upstream", "main"],
|
||||
cwd=cwd,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
print(" ✗ Failed to pull from upstream. You may need to resolve conflicts manually.")
|
||||
return
|
||||
|
||||
print(" ✓ Updated from upstream")
|
||||
|
||||
# Try to sync fork back to origin
|
||||
print("→ Syncing fork...")
|
||||
if _sync_fork_with_upstream(git_cmd, cwd):
|
||||
print(" ✓ Fork synced with upstream")
|
||||
else:
|
||||
print(" ℹ Got updates from upstream but couldn't push to fork (no write access?)")
|
||||
print(" Your local repo is updated, but your fork on GitHub may be behind.")
|
||||
|
||||
|
||||
def _invalidate_update_cache():
|
||||
"""Delete the update-check cache for ALL profiles so no banner
|
||||
reports a stale "commits behind" count after a successful update.
|
||||
@@ -2971,6 +3210,20 @@ def cmd_update(args):
|
||||
cwd=PROJECT_ROOT, check=False, capture_output=True
|
||||
)
|
||||
|
||||
# Build git command once — reused for fork detection and the update itself.
|
||||
git_cmd = ["git"]
|
||||
if sys.platform == "win32":
|
||||
git_cmd = ["git", "-c", "windows.appendAtomically=false"]
|
||||
|
||||
# Detect if we're updating from a fork (before any branch logic)
|
||||
origin_url = _get_origin_url(git_cmd, PROJECT_ROOT)
|
||||
is_fork = _is_fork(origin_url)
|
||||
|
||||
if is_fork:
|
||||
print("⚠ Updating from fork:")
|
||||
print(f" {origin_url}")
|
||||
print()
|
||||
|
||||
if use_zip_update:
|
||||
# ZIP-based update for Windows when git is broken
|
||||
_update_via_zip(args)
|
||||
@@ -2978,9 +3231,6 @@ def cmd_update(args):
|
||||
|
||||
# Fetch and pull
|
||||
try:
|
||||
git_cmd = ["git"]
|
||||
if sys.platform == "win32":
|
||||
git_cmd = ["git", "-c", "windows.appendAtomically=false"]
|
||||
|
||||
print("→ Fetching updates...")
|
||||
fetch_result = subprocess.run(
|
||||
@@ -3111,6 +3361,10 @@ def cmd_update(args):
|
||||
removed = _clear_bytecode_cache(PROJECT_ROOT)
|
||||
if removed:
|
||||
print(f" ✓ Cleared {removed} stale __pycache__ director{'y' if removed == 1 else 'ies'}")
|
||||
|
||||
# Fork upstream sync logic (only for main branch on forks)
|
||||
if is_fork and branch == "main":
|
||||
_sync_with_upstream_if_needed(git_cmd, PROJECT_ROOT)
|
||||
|
||||
# Reinstall Python dependencies. Prefer .[all], but if one optional extra
|
||||
# breaks on this machine, keep base deps and reinstall the remaining extras
|
||||
@@ -3269,8 +3523,8 @@ def cmd_update(args):
|
||||
from gateway.status import get_running_pid, remove_pid_file
|
||||
from hermes_cli.gateway import (
|
||||
get_service_name, get_launchd_plist_path, is_macos, is_linux,
|
||||
refresh_launchd_plist_if_needed,
|
||||
_ensure_user_systemd_env, get_systemd_linger_status,
|
||||
launchd_restart, _ensure_user_systemd_env,
|
||||
get_systemd_linger_status,
|
||||
)
|
||||
import signal as _signal
|
||||
|
||||
@@ -3374,26 +3628,15 @@ def cmd_update(args):
|
||||
print(" System services may require root. Try:")
|
||||
print(f" sudo systemctl restart {_gw_service_name}")
|
||||
elif has_launchd_service:
|
||||
# Refresh the plist first (picks up --replace and other
|
||||
# changes from the update we just pulled).
|
||||
refresh_launchd_plist_if_needed()
|
||||
# Explicit stop+start — don't rely on KeepAlive respawn
|
||||
# after a manual SIGTERM, which would race with the
|
||||
# PID file cleanup.
|
||||
# Use the shared launchd restart helper so we wait for the
|
||||
# old gateway process to fully exit before starting the new
|
||||
# one. This avoids stop/start races during self-update.
|
||||
print("→ Restarting gateway service...")
|
||||
_launchd_label = get_launchd_label()
|
||||
stop = subprocess.run(
|
||||
["launchctl", "stop", _launchd_label],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
start = subprocess.run(
|
||||
["launchctl", "start", _launchd_label],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
if start.returncode == 0:
|
||||
print("✓ Gateway restarted via launchd.")
|
||||
else:
|
||||
print(f"⚠ Gateway restart failed: {start.stderr.strip()}")
|
||||
try:
|
||||
launchd_restart()
|
||||
except subprocess.CalledProcessError as e:
|
||||
stderr = (getattr(e, "stderr", "") or "").strip()
|
||||
print(f"⚠ Gateway restart failed: {stderr}")
|
||||
print(" Try manually: hermes gateway restart")
|
||||
elif existing_pid:
|
||||
try:
|
||||
|
||||
@@ -28,7 +28,7 @@ GITHUB_MODELS_CATALOG_URL = COPILOT_MODELS_URL
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus-preview:free", "free"),
|
||||
("qwen/qwen3.6-plus:free", "free"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
@@ -59,7 +59,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
"qwen/qwen3.6-plus-preview:free",
|
||||
"qwen/qwen3.6-plus:free",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"openai/gpt-5.4",
|
||||
|
||||
@@ -561,7 +561,7 @@ def _get_platform_tools(
|
||||
# MCP servers are expected to be available on all platforms by default.
|
||||
# If the platform explicitly lists one or more MCP server names, treat that
|
||||
# as an allowlist. Otherwise include every globally enabled MCP server.
|
||||
mcp_servers = config.get("mcp_servers", {})
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
enabled_mcp_servers = {
|
||||
name
|
||||
for name, server_cfg in mcp_servers.items()
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
@@ -108,6 +109,9 @@ CONCLUDE_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
ALL_TOOL_SCHEMAS = [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -124,6 +128,34 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
self._sync_thread: Optional[threading.Thread] = None
|
||||
|
||||
# B1: recall_mode — set during initialize from config
|
||||
self._recall_mode = "hybrid" # "context", "tools", or "hybrid"
|
||||
|
||||
# B4: First-turn context baking
|
||||
self._first_turn_context: Optional[str] = None
|
||||
self._first_turn_lock = threading.Lock()
|
||||
|
||||
# B5: Cost-awareness turn counting and cadence
|
||||
self._turn_count = 0
|
||||
self._injection_frequency = "every-turn" # or "first-turn"
|
||||
self._context_cadence = 1 # minimum turns between context API calls
|
||||
self._dialectic_cadence = 1 # minimum turns between dialectic API calls
|
||||
self._reasoning_level_cap: Optional[str] = None # "minimal", "low", "mid", "high"
|
||||
self._last_context_turn = -999
|
||||
self._last_dialectic_turn = -999
|
||||
|
||||
# B2: peer_memory_mode gating (stub)
|
||||
self._suppress_memory = False
|
||||
self._suppress_user_profile = False
|
||||
|
||||
# Port #1957: lazy session init for tools-only mode
|
||||
self._session_initialized = False
|
||||
self._lazy_init_kwargs: Optional[dict] = None
|
||||
self._lazy_init_session_id: Optional[str] = None
|
||||
|
||||
# Port #4053: cron guard — when True, plugin is fully inactive
|
||||
self._cron_skipped = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "honcho"
|
||||
@@ -133,6 +165,7 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
try:
|
||||
from plugins.memory.honcho.client import HonchoClientConfig
|
||||
cfg = HonchoClientConfig.from_global_config()
|
||||
# Port #2645: baseUrl-only verification — api_key OR base_url suffices
|
||||
return cfg.enabled and bool(cfg.api_key or cfg.base_url)
|
||||
except Exception:
|
||||
return False
|
||||
@@ -158,8 +191,22 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
"""Initialize Honcho session manager."""
|
||||
"""Initialize Honcho session manager.
|
||||
|
||||
Handles: cron guard, recall_mode, session name resolution,
|
||||
peer memory mode, SOUL.md ai_peer sync, memory file migration,
|
||||
and pre-warming context at init.
|
||||
"""
|
||||
try:
|
||||
# ----- Port #4053: cron guard -----
|
||||
agent_context = kwargs.get("agent_context", "")
|
||||
platform = kwargs.get("platform", "cli")
|
||||
if agent_context in ("cron", "flush") or platform == "cron":
|
||||
logger.debug("Honcho skipped: cron/flush context (agent_context=%s, platform=%s)",
|
||||
agent_context, platform)
|
||||
self._cron_skipped = True
|
||||
return
|
||||
|
||||
from plugins.memory.honcho.client import HonchoClientConfig, get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
@@ -169,20 +216,78 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
return
|
||||
|
||||
self._config = cfg
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# Build session key from kwargs or session_id
|
||||
platform = kwargs.get("platform", "cli")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
if user_id:
|
||||
self._session_key = f"{platform}:{user_id}"
|
||||
else:
|
||||
self._session_key = session_id
|
||||
# ----- B1: recall_mode from config -----
|
||||
self._recall_mode = cfg.recall_mode # "context", "tools", or "hybrid"
|
||||
logger.debug("Honcho recall_mode: %s", self._recall_mode)
|
||||
|
||||
# ----- B5: cost-awareness config -----
|
||||
try:
|
||||
raw = cfg.raw or {}
|
||||
self._injection_frequency = raw.get("injectionFrequency", "every-turn")
|
||||
self._context_cadence = int(raw.get("contextCadence", 1))
|
||||
self._dialectic_cadence = int(raw.get("dialecticCadence", 1))
|
||||
cap = raw.get("reasoningLevelCap")
|
||||
if cap and cap in ("minimal", "low", "mid", "high"):
|
||||
self._reasoning_level_cap = cap
|
||||
except Exception as e:
|
||||
logger.debug("Honcho cost-awareness config parse error: %s", e)
|
||||
|
||||
# ----- Port #1969: aiPeer sync from SOUL.md -----
|
||||
try:
|
||||
hermes_home = kwargs.get("hermes_home", "")
|
||||
if hermes_home and not cfg.raw.get("aiPeer"):
|
||||
soul_path = Path(hermes_home) / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
soul_text = soul_path.read_text(encoding="utf-8").strip()
|
||||
if soul_text:
|
||||
# Try YAML frontmatter: "name: Foo"
|
||||
first_line = soul_text.split("\n")[0].strip()
|
||||
if first_line.startswith("---"):
|
||||
# Look for name: in frontmatter
|
||||
for line in soul_text.split("\n")[1:]:
|
||||
line = line.strip()
|
||||
if line == "---":
|
||||
break
|
||||
if line.lower().startswith("name:"):
|
||||
name_val = line.split(":", 1)[1].strip().strip("\"'")
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md: %s", name_val)
|
||||
break
|
||||
elif first_line.startswith("# "):
|
||||
# Markdown heading: "# AgentName"
|
||||
name_val = first_line[2:].strip()
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md heading: %s", name_val)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho SOUL.md ai_peer sync failed: %s", e)
|
||||
|
||||
# ----- B2: peer_memory_mode gating (stub) -----
|
||||
try:
|
||||
ai_mode = cfg.peer_memory_mode(cfg.ai_peer)
|
||||
user_mode = cfg.peer_memory_mode(cfg.peer_name or "user")
|
||||
# "honcho" means Honcho owns memory; suppress built-in
|
||||
self._suppress_memory = (ai_mode == "honcho")
|
||||
self._suppress_user_profile = (user_mode == "honcho")
|
||||
logger.debug("Honcho peer_memory_mode: ai=%s (suppress_memory=%s), user=%s (suppress_user_profile=%s)",
|
||||
ai_mode, self._suppress_memory, user_mode, self._suppress_user_profile)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho peer_memory_mode check failed: %s", e)
|
||||
|
||||
# ----- Port #1957: lazy session init for tools-only mode -----
|
||||
if self._recall_mode == "tools":
|
||||
# Defer actual session creation until first tool call
|
||||
self._lazy_init_kwargs = kwargs
|
||||
self._lazy_init_session_id = session_id
|
||||
# Still need a client reference for _ensure_session
|
||||
self._config = cfg
|
||||
logger.debug("Honcho tools-only mode — deferring session init until first tool call")
|
||||
return
|
||||
|
||||
# ----- Eager init (context or hybrid mode) -----
|
||||
self._do_session_init(cfg, session_id, **kwargs)
|
||||
|
||||
except ImportError:
|
||||
logger.debug("honcho-ai package not installed — plugin inactive")
|
||||
@@ -190,19 +295,180 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.warning("Honcho init failed: %s", e)
|
||||
self._manager = None
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._manager or not self._session_key:
|
||||
return ""
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active. AI-native cross-session user modeling.\n"
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
def _do_session_init(self, cfg, session_id: str, **kwargs) -> None:
|
||||
"""Shared session initialization logic for both eager and lazy paths."""
|
||||
from plugins.memory.honcho.client import get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# ----- B3: resolve_session_name -----
|
||||
session_title = kwargs.get("session_title")
|
||||
self._session_key = (
|
||||
cfg.resolve_session_name(session_title=session_title, session_id=session_id)
|
||||
or session_id
|
||||
or "hermes-default"
|
||||
)
|
||||
logger.debug("Honcho session key resolved: %s", self._session_key)
|
||||
|
||||
# Create session eagerly
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
self._session_initialized = True
|
||||
|
||||
# ----- B6: Memory file migration (one-time, for new sessions) -----
|
||||
try:
|
||||
if not session.messages:
|
||||
from hermes_constants import get_hermes_home
|
||||
mem_dir = str(get_hermes_home() / "memories")
|
||||
self._manager.migrate_memory_files(self._session_key, mem_dir)
|
||||
logger.debug("Honcho memory file migration attempted for new session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho memory file migration skipped: %s", e)
|
||||
|
||||
# ----- B7: Pre-warming context at init -----
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key)
|
||||
self._manager.prefetch_dialectic(self._session_key, "What should I know about this user?")
|
||||
logger.debug("Honcho pre-warm threads started for session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho pre-warm failed: %s", e)
|
||||
|
||||
def _ensure_session(self) -> bool:
|
||||
"""Lazily initialize the Honcho session (for tools-only mode).
|
||||
|
||||
Returns True if the manager is ready, False otherwise.
|
||||
"""
|
||||
if self._manager and self._session_initialized:
|
||||
return True
|
||||
if self._cron_skipped:
|
||||
return False
|
||||
if not self._config or not self._lazy_init_kwargs:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._do_session_init(
|
||||
self._config,
|
||||
self._lazy_init_session_id or "hermes-default",
|
||||
**self._lazy_init_kwargs,
|
||||
)
|
||||
# Clear lazy refs
|
||||
self._lazy_init_kwargs = None
|
||||
self._lazy_init_session_id = None
|
||||
return self._manager is not None
|
||||
except Exception as e:
|
||||
logger.warning("Honcho lazy session init failed: %s", e)
|
||||
return False
|
||||
|
||||
def _format_first_turn_context(self, ctx: dict) -> str:
|
||||
"""Format the prefetch context dict into a readable system prompt block."""
|
||||
parts = []
|
||||
|
||||
rep = ctx.get("representation", "")
|
||||
if rep:
|
||||
parts.append(f"## User Representation\n{rep}")
|
||||
|
||||
card = ctx.get("card", "")
|
||||
if card:
|
||||
parts.append(f"## User Peer Card\n{card}")
|
||||
|
||||
ai_rep = ctx.get("ai_representation", "")
|
||||
if ai_rep:
|
||||
parts.append(f"## AI Self-Representation\n{ai_rep}")
|
||||
|
||||
ai_card = ctx.get("ai_card", "")
|
||||
if ai_card:
|
||||
parts.append(f"## AI Identity Card\n{ai_card}")
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
"""Return system prompt text, adapted by recall_mode.
|
||||
|
||||
B4: On the FIRST call, fetch and bake the full Honcho context
|
||||
(user representation, peer card, AI representation, continuity synthesis).
|
||||
Subsequent calls return the cached block for prompt caching stability.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
if not self._manager or not self._session_key:
|
||||
# tools-only mode without session yet still returns a minimal block
|
||||
if self._recall_mode == "tools" and self._config:
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile, honcho_search, "
|
||||
"honcho_context, and honcho_conclude tools to access user memory."
|
||||
)
|
||||
return ""
|
||||
|
||||
# ----- B4: First-turn context baking -----
|
||||
first_turn_block = ""
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
with self._first_turn_lock:
|
||||
if self._first_turn_context is None:
|
||||
# First call — fetch and cache
|
||||
try:
|
||||
ctx = self._manager.get_prefetch_context(self._session_key)
|
||||
self._first_turn_context = self._format_first_turn_context(ctx) if ctx else ""
|
||||
except Exception as e:
|
||||
logger.debug("Honcho first-turn context fetch failed: %s", e)
|
||||
self._first_turn_context = ""
|
||||
first_turn_block = self._first_turn_context
|
||||
|
||||
# ----- B1: adapt text based on recall_mode -----
|
||||
if self._recall_mode == "context":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (context-injection mode). Relevant user context is automatically "
|
||||
"injected before each turn. No memory tools are available — context is "
|
||||
"managed automatically."
|
||||
)
|
||||
elif self._recall_mode == "tools":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user. "
|
||||
"No automatic context injection — you must use tools to access memory."
|
||||
)
|
||||
else: # hybrid
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (hybrid mode). Relevant context is auto-injected AND memory tools are available. "
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
)
|
||||
|
||||
if first_turn_block:
|
||||
return f"{header}\n\n{first_turn_block}"
|
||||
return header
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Return prefetched dialectic context from background thread."""
|
||||
"""Return prefetched dialectic context from background thread.
|
||||
|
||||
B1: Returns empty when recall_mode is "tools" (no injection).
|
||||
B5: Respects injection_frequency — "first-turn" returns cached/empty after turn 0.
|
||||
Port #3265: Truncates to context_tokens budget.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
|
||||
# B1: tools-only mode — no auto-injection
|
||||
if self._recall_mode == "tools":
|
||||
return ""
|
||||
|
||||
# B5: injection_frequency — if "first-turn" and past first turn, return empty
|
||||
if self._injection_frequency == "first-turn" and self._turn_count > 0:
|
||||
return ""
|
||||
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
@@ -210,13 +476,49 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
return ""
|
||||
|
||||
# ----- Port #3265: token budget enforcement -----
|
||||
result = self._truncate_to_budget(result)
|
||||
|
||||
return f"## Honcho Context\n{result}"
|
||||
|
||||
def _truncate_to_budget(self, text: str) -> str:
|
||||
"""Truncate text to fit within context_tokens budget if set."""
|
||||
if not self._config or not self._config.context_tokens:
|
||||
return text
|
||||
budget_chars = self._config.context_tokens * 4 # conservative char estimate
|
||||
if len(text) <= budget_chars:
|
||||
return text
|
||||
# Truncate at word boundary
|
||||
truncated = text[:budget_chars]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > budget_chars * 0.8:
|
||||
truncated = truncated[:last_space]
|
||||
return truncated + " …"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
"""Fire a background dialectic query for the upcoming turn."""
|
||||
"""Fire a background dialectic query for the upcoming turn.
|
||||
|
||||
B5: Checks cadence before firing background threads.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key or not query:
|
||||
return
|
||||
|
||||
# B1: tools-only mode — no prefetch
|
||||
if self._recall_mode == "tools":
|
||||
return
|
||||
|
||||
# B5: cadence check — skip if too soon since last dialectic call
|
||||
if self._dialectic_cadence > 1:
|
||||
if (self._turn_count - self._last_dialectic_turn) < self._dialectic_cadence:
|
||||
logger.debug("Honcho dialectic prefetch skipped: cadence %d, turns since last: %d",
|
||||
self._dialectic_cadence, self._turn_count - self._last_dialectic_turn)
|
||||
return
|
||||
|
||||
self._last_dialectic_turn = self._turn_count
|
||||
|
||||
def _run():
|
||||
try:
|
||||
result = self._manager.dialectic_query(
|
||||
@@ -233,14 +535,28 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
)
|
||||
self._prefetch_thread.start()
|
||||
|
||||
# Also fire context prefetch if cadence allows
|
||||
if self._context_cadence <= 1 or (self._turn_count - self._last_context_turn) >= self._context_cadence:
|
||||
self._last_context_turn = self._turn_count
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key, query)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho context prefetch failed: %s", e)
|
||||
|
||||
def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
|
||||
"""Track turn count for cadence and injection_frequency logic."""
|
||||
self._turn_count = turn_number
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Record the conversation turn in Honcho (non-blocking)."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
session = self._manager.get_or_create_session(self._session_key)
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
session.add_message("user", user_content[:4000])
|
||||
session.add_message("assistant", assistant_content[:4000])
|
||||
# Flush to Honcho API
|
||||
@@ -259,6 +575,8 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
"""Mirror built-in user profile writes as Honcho conclusions."""
|
||||
if action != "add" or target != "user" or not content:
|
||||
return
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
@@ -273,6 +591,8 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
|
||||
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Flush all pending messages to Honcho on session end."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager:
|
||||
return
|
||||
# Wait for pending sync
|
||||
@@ -284,9 +604,26 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.debug("Honcho session-end flush failed: %s", e)
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
"""Return tool schemas, respecting recall_mode.
|
||||
|
||||
B1: context-only mode hides all tools.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return []
|
||||
if self._recall_mode == "context":
|
||||
return []
|
||||
return list(ALL_TOOL_SCHEMAS)
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
"""Handle a Honcho tool call, with lazy session init for tools-only mode."""
|
||||
if self._cron_skipped:
|
||||
return json.dumps({"error": "Honcho is not active (cron context)."})
|
||||
|
||||
# Port #1957: ensure session is initialized for tools-only mode
|
||||
if not self._session_initialized:
|
||||
if not self._ensure_session():
|
||||
return json.dumps({"error": "Honcho session could not be initialized."})
|
||||
|
||||
if not self._manager or not self._session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
|
||||
|
||||
@@ -85,6 +85,16 @@ def _normalize_recall_mode(val: str) -> str:
|
||||
return val if val in _VALID_RECALL_MODES else "hybrid"
|
||||
|
||||
|
||||
_VALID_OBSERVATION_MODES = {"unified", "directional"}
|
||||
_OBSERVATION_MODE_ALIASES = {"shared": "unified", "separate": "directional", "cross": "directional"}
|
||||
|
||||
|
||||
def _normalize_observation_mode(val: str) -> str:
|
||||
"""Normalize observation mode values."""
|
||||
val = _OBSERVATION_MODE_ALIASES.get(val, val)
|
||||
return val if val in _VALID_OBSERVATION_MODES else "unified"
|
||||
|
||||
|
||||
def _resolve_memory_mode(
|
||||
global_val: str | dict,
|
||||
host_val: str | dict | None,
|
||||
@@ -154,6 +164,10 @@ class HonchoClientConfig:
|
||||
# "context" — auto-injected context only, Honcho tools removed
|
||||
# "tools" — Honcho tools only, no auto-injected context
|
||||
recall_mode: str = "hybrid"
|
||||
# Observation mode: how Honcho peers observe each other.
|
||||
# "unified" — user peer observes self; all agents share one observation pool
|
||||
# "directional" — AI peer observes user; each agent keeps its own view
|
||||
observation_mode: str = "unified"
|
||||
# Session resolution
|
||||
session_strategy: str = "per-directory"
|
||||
session_peer_prefix: bool = False
|
||||
@@ -313,6 +327,11 @@ class HonchoClientConfig:
|
||||
or raw.get("recallMode")
|
||||
or "hybrid"
|
||||
),
|
||||
observation_mode=_normalize_observation_mode(
|
||||
host_block.get("observationMode")
|
||||
or raw.get("observationMode")
|
||||
or "unified"
|
||||
),
|
||||
session_strategy=session_strategy,
|
||||
session_peer_prefix=session_peer_prefix,
|
||||
sessions=raw.get("sessions", {}),
|
||||
|
||||
@@ -110,6 +110,9 @@ class HonchoSessionManager:
|
||||
self._dialectic_max_chars: int = (
|
||||
config.dialectic_max_chars if config else 600
|
||||
)
|
||||
self._observation_mode: str = (
|
||||
config.observation_mode if config else "unified"
|
||||
)
|
||||
|
||||
# Async write queue — started lazily on first enqueue
|
||||
self._async_queue: queue.Queue | None = None
|
||||
@@ -159,13 +162,18 @@ class HonchoSessionManager:
|
||||
|
||||
session = self.honcho.session(session_id)
|
||||
|
||||
# Configure peer observation settings.
|
||||
# observe_me=True for AI peer so Honcho watches what the agent says
|
||||
# and builds its representation over time — enabling identity formation.
|
||||
# Configure peer observation settings based on observation_mode.
|
||||
# Unified: user peer observes self, AI peer passive — all agents share
|
||||
# one observation pool via user self-observations.
|
||||
# Directional: AI peer observes user — each agent keeps its own view.
|
||||
try:
|
||||
from honcho.session import SessionPeerConfig
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
ai_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
if self._observation_mode == "directional":
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=True)
|
||||
else: # unified (default)
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=False)
|
||||
|
||||
session.add_peers([(user_peer, user_config), (assistant_peer, ai_config)])
|
||||
except Exception as e:
|
||||
@@ -493,12 +501,27 @@ class HonchoSessionManager:
|
||||
if not session:
|
||||
return ""
|
||||
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
level = reasoning_level or self._dynamic_reasoning_level(query)
|
||||
|
||||
try:
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer queries about the user (cross-observation)
|
||||
if peer == "ai":
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(query, reasoning_level=level) or ""
|
||||
else:
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(
|
||||
query,
|
||||
target=session.user_peer_id,
|
||||
reasoning_level=level,
|
||||
) or ""
|
||||
else:
|
||||
# Unified: user peer queries self, or AI peer queries self
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
|
||||
# Apply Hermes-side char cap before caching
|
||||
if result and self._dialectic_max_chars and len(result) > self._dialectic_max_chars:
|
||||
result = result[:self._dialectic_max_chars].rsplit(" ", 1)[0] + " …"
|
||||
@@ -895,9 +918,16 @@ class HonchoSessionManager:
|
||||
logger.warning("No session cached for '%s', skipping conclusion", session_key)
|
||||
return False
|
||||
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
try:
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer creates conclusion about user (cross-observation)
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
else:
|
||||
# Unified: user peer creates self-conclusion
|
||||
user_peer = self._get_or_create_peer(session.user_peer_id)
|
||||
conclusions_scope = user_peer.conclusions_of(session.user_peer_id)
|
||||
|
||||
conclusions_scope.create([{
|
||||
"content": content.strip(),
|
||||
"session_id": session.honcho_session_id,
|
||||
|
||||
115
run_agent.py
115
run_agent.py
@@ -6656,10 +6656,21 @@ class AIAgent:
|
||||
if self.step_callback is not None:
|
||||
try:
|
||||
prev_tools = []
|
||||
for _m in reversed(messages):
|
||||
for _idx, _m in enumerate(reversed(messages)):
|
||||
if _m.get("role") == "assistant" and _m.get("tool_calls"):
|
||||
_fwd_start = len(messages) - _idx
|
||||
_results_by_id = {}
|
||||
for _tm in messages[_fwd_start:]:
|
||||
if _tm.get("role") != "tool":
|
||||
break
|
||||
_tcid = _tm.get("tool_call_id")
|
||||
if _tcid:
|
||||
_results_by_id[_tcid] = _tm.get("content", "")
|
||||
prev_tools = [
|
||||
tc["function"]["name"]
|
||||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
@@ -7369,6 +7380,61 @@ class AIAgent:
|
||||
# compress history and retry, not abort immediately.
|
||||
status_code = getattr(api_error, "status_code", None)
|
||||
|
||||
# ── Anthropic Sonnet long-context tier gate ───────────
|
||||
# Anthropic returns HTTP 429 "Extra usage is required for
|
||||
# long context requests" when a Claude Max (or similar)
|
||||
# subscription doesn't include the 1M-context tier. This
|
||||
# is NOT a transient rate limit — retrying or switching
|
||||
# credentials won't help. Reduce context to 200k (the
|
||||
# standard tier) and compress.
|
||||
# Only applies to Sonnet — Opus 1M is general access.
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in self.model.lower()
|
||||
)
|
||||
if _is_long_context_tier_error:
|
||||
_reduced_ctx = 200000
|
||||
compressor = self.context_compressor
|
||||
old_ctx = compressor.context_length
|
||||
if old_ctx > _reduced_ctx:
|
||||
compressor.context_length = _reduced_ctx
|
||||
compressor.threshold_tokens = int(
|
||||
_reduced_ctx * compressor.threshold_percent
|
||||
)
|
||||
compressor._context_probed = True
|
||||
# Don't persist — this is a subscription-tier
|
||||
# limitation, not a model capability. If the user
|
||||
# later enables extra usage the 1M limit should
|
||||
# come back automatically.
|
||||
compressor._context_probe_persistable = False
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Anthropic long-context tier "
|
||||
f"requires extra usage — reducing context: "
|
||||
f"{old_ctx:,} → {_reduced_ctx:,} tokens",
|
||||
force=True,
|
||||
)
|
||||
|
||||
compression_attempts += 1
|
||||
if compression_attempts <= max_compression_attempts:
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message,
|
||||
approx_tokens=approx_tokens,
|
||||
task_id=effective_task_id,
|
||||
)
|
||||
if len(messages) < original_len or old_ctx > _reduced_ctx:
|
||||
self._emit_status(
|
||||
f"🗜️ Context reduced to {_reduced_ctx:,} tokens "
|
||||
f"(was {old_ctx:,}), retrying..."
|
||||
)
|
||||
time.sleep(2)
|
||||
restart_with_compressed_messages = True
|
||||
break
|
||||
# Fall through to normal error handling if compression
|
||||
# is exhausted or didn't help.
|
||||
|
||||
# Eager fallback for rate-limit errors (429 or quota exhaustion).
|
||||
# When a fallback model is configured, switch immediately instead
|
||||
# of burning through retries with exponential backoff -- the
|
||||
@@ -7474,7 +7540,33 @@ class AIAgent:
|
||||
f"treating as probable context overflow.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
|
||||
# Server disconnects on large sessions are often caused by
|
||||
# the request exceeding the provider's context/payload limit
|
||||
# without a proper HTTP error response. Treat these as
|
||||
# context-length errors to trigger compression rather than
|
||||
# burning through retries that will all fail the same way.
|
||||
# This breaks the death spiral: disconnect → no token data
|
||||
# → no compression → bigger session → more disconnects.
|
||||
# (#2153)
|
||||
if not is_context_length_error and not status_code:
|
||||
_is_server_disconnect = (
|
||||
'server disconnected' in error_msg
|
||||
or 'peer closed connection' in error_msg
|
||||
or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError')
|
||||
)
|
||||
if _is_server_disconnect:
|
||||
ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000)
|
||||
_is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200
|
||||
if _is_large:
|
||||
is_context_length_error = True
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Server disconnected with large session "
|
||||
f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — "
|
||||
f"treating as context-length error, attempting compression.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if is_context_length_error:
|
||||
compressor = self.context_compressor
|
||||
old_ctx = compressor.context_length
|
||||
@@ -8109,11 +8201,20 @@ class AIAgent:
|
||||
# threshold (default 50%) leaves ample headroom; if tool
|
||||
# results push past it, the next API call will report the
|
||||
# real total and trigger compression then.
|
||||
#
|
||||
# If last_prompt_tokens is 0 (stale after API disconnect
|
||||
# or provider returned no usage data), fall back to rough
|
||||
# estimate to avoid missing compression. Without this,
|
||||
# a session can grow unbounded after disconnects because
|
||||
# should_compress(0) never fires. (#2153)
|
||||
_compressor = self.context_compressor
|
||||
_real_tokens = (
|
||||
_compressor.last_prompt_tokens
|
||||
+ _compressor.last_completion_tokens
|
||||
)
|
||||
if _compressor.last_prompt_tokens > 0:
|
||||
_real_tokens = (
|
||||
_compressor.last_prompt_tokens
|
||||
+ _compressor.last_completion_tokens
|
||||
)
|
||||
else:
|
||||
_real_tokens = estimate_messages_tokens_rough(messages)
|
||||
|
||||
# ── Context pressure warnings (user-facing only) ──────────
|
||||
# Notify the user (NOT the LLM) as context approaches the
|
||||
|
||||
@@ -62,6 +62,33 @@ function formatOutgoingMessage(message) {
|
||||
return REPLY_PREFIX ? `${REPLY_PREFIX}${message}` : message;
|
||||
}
|
||||
|
||||
function normalizeWhatsAppId(value) {
|
||||
if (!value) return '';
|
||||
return String(value).replace(':', '@');
|
||||
}
|
||||
|
||||
function getMessageContent(msg) {
|
||||
const content = msg?.message || {};
|
||||
if (content.ephemeralMessage?.message) return content.ephemeralMessage.message;
|
||||
if (content.viewOnceMessage?.message) return content.viewOnceMessage.message;
|
||||
if (content.viewOnceMessageV2?.message) return content.viewOnceMessageV2.message;
|
||||
if (content.documentWithCaptionMessage?.message) return content.documentWithCaptionMessage.message;
|
||||
if (content.templateMessage?.hydratedTemplate) return content.templateMessage.hydratedTemplate;
|
||||
if (content.buttonsMessage) return content.buttonsMessage;
|
||||
if (content.listMessage) return content.listMessage;
|
||||
return content;
|
||||
}
|
||||
|
||||
function getContextInfo(messageContent) {
|
||||
if (!messageContent || typeof messageContent !== 'object') return {};
|
||||
for (const value of Object.values(messageContent)) {
|
||||
if (value && typeof value === 'object' && value.contextInfo) {
|
||||
return value.contextInfo;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
mkdirSync(SESSION_DIR, { recursive: true });
|
||||
|
||||
// Build LID → phone reverse map from session files (lid-mapping-{phone}.json)
|
||||
@@ -157,6 +184,11 @@ async function startSocket() {
|
||||
// than 'notify'. Accept both and filter agent echo-backs below.
|
||||
if (type !== 'notify' && type !== 'append') return;
|
||||
|
||||
const botIds = Array.from(new Set([
|
||||
normalizeWhatsAppId(sock.user?.id),
|
||||
normalizeWhatsAppId(sock.user?.lid),
|
||||
].filter(Boolean)));
|
||||
|
||||
for (const msg of messages) {
|
||||
if (!msg.message) continue;
|
||||
|
||||
@@ -200,23 +232,28 @@ async function startSocket() {
|
||||
continue;
|
||||
}
|
||||
|
||||
const messageContent = getMessageContent(msg);
|
||||
const contextInfo = getContextInfo(messageContent);
|
||||
const mentionedIds = Array.from(new Set((contextInfo?.mentionedJid || []).map(normalizeWhatsAppId).filter(Boolean)));
|
||||
const quotedParticipant = normalizeWhatsAppId(contextInfo?.participant || contextInfo?.remoteJid || '');
|
||||
|
||||
// Extract message body
|
||||
let body = '';
|
||||
let hasMedia = false;
|
||||
let mediaType = '';
|
||||
const mediaUrls = [];
|
||||
|
||||
if (msg.message.conversation) {
|
||||
body = msg.message.conversation;
|
||||
} else if (msg.message.extendedTextMessage?.text) {
|
||||
body = msg.message.extendedTextMessage.text;
|
||||
} else if (msg.message.imageMessage) {
|
||||
body = msg.message.imageMessage.caption || '';
|
||||
if (messageContent.conversation) {
|
||||
body = messageContent.conversation;
|
||||
} else if (messageContent.extendedTextMessage?.text) {
|
||||
body = messageContent.extendedTextMessage.text;
|
||||
} else if (messageContent.imageMessage) {
|
||||
body = messageContent.imageMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'image';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.imageMessage.mimetype || 'image/jpeg';
|
||||
const mime = messageContent.imageMessage.mimetype || 'image/jpeg';
|
||||
const extMap = { 'image/jpeg': '.jpg', 'image/png': '.png', 'image/webp': '.webp', 'image/gif': '.gif' };
|
||||
const ext = extMap[mime] || '.jpg';
|
||||
mkdirSync(IMAGE_CACHE_DIR, { recursive: true });
|
||||
@@ -226,13 +263,13 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download image:', err.message);
|
||||
}
|
||||
} else if (msg.message.videoMessage) {
|
||||
body = msg.message.videoMessage.caption || '';
|
||||
} else if (messageContent.videoMessage) {
|
||||
body = messageContent.videoMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'video';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.videoMessage.mimetype || 'video/mp4';
|
||||
const mime = messageContent.videoMessage.mimetype || 'video/mp4';
|
||||
const ext = mime.includes('mp4') ? '.mp4' : '.mkv';
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
const filePath = path.join(DOCUMENT_CACHE_DIR, `vid_${randomBytes(6).toString('hex')}${ext}`);
|
||||
@@ -241,11 +278,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download video:', err.message);
|
||||
}
|
||||
} else if (msg.message.audioMessage || msg.message.pttMessage) {
|
||||
} else if (messageContent.audioMessage || messageContent.pttMessage) {
|
||||
hasMedia = true;
|
||||
mediaType = msg.message.pttMessage ? 'ptt' : 'audio';
|
||||
mediaType = messageContent.pttMessage ? 'ptt' : 'audio';
|
||||
try {
|
||||
const audioMsg = msg.message.pttMessage || msg.message.audioMessage;
|
||||
const audioMsg = messageContent.pttMessage || messageContent.audioMessage;
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = audioMsg.mimetype || 'audio/ogg';
|
||||
const ext = mime.includes('ogg') ? '.ogg' : mime.includes('mp4') ? '.m4a' : '.ogg';
|
||||
@@ -256,11 +293,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download audio:', err.message);
|
||||
}
|
||||
} else if (msg.message.documentMessage) {
|
||||
body = msg.message.documentMessage.caption || '';
|
||||
} else if (messageContent.documentMessage) {
|
||||
body = messageContent.documentMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'document';
|
||||
const fileName = msg.message.documentMessage.fileName || 'document';
|
||||
const fileName = messageContent.documentMessage.fileName || 'document';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
@@ -309,6 +346,9 @@ async function startSocket() {
|
||||
hasMedia,
|
||||
mediaType,
|
||||
mediaUrls,
|
||||
mentionedIds,
|
||||
quotedParticipant,
|
||||
botIds,
|
||||
timestamp: msg.messageTimestamp,
|
||||
};
|
||||
|
||||
|
||||
@@ -205,6 +205,47 @@ class TestStepCallback:
|
||||
assert "read_file" not in tool_call_ids
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_result_passed_to_build_tool_complete(self, mock_conn, event_loop_fixture):
|
||||
"""Tool result from prev_tools dict is forwarded to build_tool_complete."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
# Provide a result string in the tool info dict
|
||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
"""When result is None (e.g. first iteration), None is passed through."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
|
||||
349
tests/acp/test_mcp_e2e.py
Normal file
349
tests/acp/test_mcp_e2e.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""End-to-end tests for ACP MCP server registration and tool-result reporting.
|
||||
|
||||
Exercises the full flow through the ACP server layer:
|
||||
new_session(mcpServers) → MCP tools registered → prompt() →
|
||||
tool_progress_callback (ToolCallStart) →
|
||||
step_callback with results (ToolCallUpdate with rawOutput) →
|
||||
session_update events arrive at the mock client
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import (
|
||||
EnvVariable,
|
||||
HttpHeader,
|
||||
McpServerHttp,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
TextContentBlock,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
)
|
||||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_manager():
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def acp_agent(mock_manager):
|
||||
return HermesACPAgent(session_manager=mock_manager)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E: MCP registration → prompt → tool events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMcpRegistrationE2E:
|
||||
"""Full flow: session with MCP servers → prompt with tool calls → ACP events."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_with_mcp_servers_registers_tools(self, acp_agent, mock_manager):
|
||||
"""new_session with mcpServers converts them to Hermes config and registers."""
|
||||
servers = [
|
||||
McpServerStdio(
|
||||
name="test-fs",
|
||||
command="/usr/bin/mcp-fs",
|
||||
args=["--root", "/tmp"],
|
||||
env=[EnvVariable(name="DEBUG", value="1")],
|
||||
),
|
||||
McpServerHttp(
|
||||
name="test-api",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok123")],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search"]
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_test_fs_read"}},
|
||||
{"function": {"name": "mcp_test_fs_write"}},
|
||||
{"function": {"name": "mcp_test_api_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
assert isinstance(resp, NewSessionResponse)
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Verify stdio server was converted correctly
|
||||
assert "test-fs" in registered_configs
|
||||
fs_cfg = registered_configs["test-fs"]
|
||||
assert fs_cfg["command"] == "/usr/bin/mcp-fs"
|
||||
assert fs_cfg["args"] == ["--root", "/tmp"]
|
||||
assert fs_cfg["env"] == {"DEBUG": "1"}
|
||||
|
||||
# Verify HTTP server was converted correctly
|
||||
assert "test-api" in registered_configs
|
||||
api_cfg = registered_configs["test-api"]
|
||||
assert api_cfg["url"] == "https://api.example.com/mcp"
|
||||
assert api_cfg["headers"] == {"Authorization": "Bearer tok123"}
|
||||
|
||||
# Verify agent tool surface was refreshed
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {
|
||||
"mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search", "terminal"
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_with_tool_calls_emits_acp_events(self, acp_agent, mock_manager):
|
||||
"""Prompt → agent fires callbacks → ACP ToolCallStart + ToolCallUpdate events."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
# Wire up a mock ACP client connection
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run_conversation(user_message, conversation_history=None, task_id=None):
|
||||
"""Simulate an agent turn that calls terminal, gets a result, then responds."""
|
||||
agent = state.agent
|
||||
|
||||
# 1) Agent fires tool_progress_callback (ToolCallStart)
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback(
|
||||
"terminal", "$ echo hello", {"command": "echo hello"}
|
||||
)
|
||||
|
||||
# 2) Agent fires step_callback with tool results (ToolCallUpdate)
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "terminal", "result": '{"output": "hello\\n", "exit_code": 0}'}
|
||||
])
|
||||
|
||||
return {
|
||||
"final_response": "The command output 'hello'.",
|
||||
"messages": [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": "The command output 'hello'."},
|
||||
],
|
||||
}
|
||||
|
||||
state.agent.run_conversation = mock_run_conversation
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="run echo hello")]
|
||||
resp = await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
# Collect all session_update calls
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
# session_update(session_id, update) — grab the update
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
# Find tool_call (start) and tool_call_update (completion) events
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
# Should have at least one ToolCallStart for "terminal"
|
||||
assert len(starts) >= 1, f"Expected ToolCallStart, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
start_event = starts[0]
|
||||
assert isinstance(start_event, ToolCallStart)
|
||||
assert start_event.title.startswith("terminal:")
|
||||
|
||||
# Should have at least one ToolCallUpdate (completion) with rawOutput
|
||||
assert len(completions) >= 1, f"Expected ToolCallUpdate, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
complete_event = completions[0]
|
||||
assert isinstance(complete_event, ToolCallProgress)
|
||||
assert complete_event.status == "completed"
|
||||
# rawOutput should contain the tool result string
|
||||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run(user_message, conversation_history=None, task_id=None):
|
||||
agent = state.agent
|
||||
# Fire two tool calls
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback("read_file", "read: /etc/hosts", {"path": "/etc/hosts"})
|
||||
agent.tool_progress_callback("web_search", "web search: test", {"query": "test"})
|
||||
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "read_file", "result": '{"content": "127.0.0.1 localhost"}'},
|
||||
{"name": "web_search", "result": '{"data": {"web": []}}'},
|
||||
])
|
||||
|
||||
return {"final_response": "Done.", "messages": []}
|
||||
|
||||
state.agent.run_conversation = mock_run
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="test")]
|
||||
await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
assert len(starts) == 2, f"Expected 2 starts, got {len(starts)}"
|
||||
assert len(completions) == 2, f"Expected 2 completions, got {len(completions)}"
|
||||
|
||||
# Each completion's toolCallId must match a start's toolCallId
|
||||
start_ids = {s.tool_call_id for s in starts}
|
||||
completion_ids = {c.tool_call_id for c in completions}
|
||||
assert start_ids == completion_ids, (
|
||||
f"IDs must match: starts={start_ids}, completions={completion_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestMcpSanitizationE2E:
|
||||
"""Verify server names with special chars work end-to-end."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slashed_server_name_registers_cleanly(self, acp_agent, mock_manager):
|
||||
"""Server name 'ai.exa/exa' should not crash — tools get sanitized names."""
|
||||
servers = [
|
||||
McpServerHttp(
|
||||
name="ai.exa/exa",
|
||||
url="https://exa.ai/mcp",
|
||||
headers=[],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_ai_exa_exa_search"]
|
||||
|
||||
fake_tools = [{"function": {"name": "mcp_ai_exa_exa_search"}}]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Raw server name preserved as config key
|
||||
assert "ai.exa/exa" in registered_configs
|
||||
# Agent tools refreshed with sanitized name
|
||||
assert "mcp_ai_exa_exa_search" in state.agent.valid_tool_names
|
||||
|
||||
|
||||
class TestSessionLifecycleMcpE2E:
|
||||
"""Verify MCP servers are registered on all session lifecycle methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""load_session re-registers MCP servers (spec says agents may not retain them)."""
|
||||
# Create a session first
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv", command="/bin/test", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""resume_session re-registers MCP servers."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv2", command="/bin/test2", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv2" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""fork_session registers MCP servers on the new forked session."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerHttp(name="api", url="https://api.test/mcp", headers=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
# Need to set up the forked session's agent too
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
fork_resp = await acp_agent.fork_session(
|
||||
cwd="/tmp", session_id=sid, mcp_servers=servers
|
||||
)
|
||||
|
||||
assert fork_resp.session_id != ""
|
||||
assert "api" in registered
|
||||
@@ -505,3 +505,179 @@ class TestSlashCommands:
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _register_session_mcp_servers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterSessionMcpServers:
|
||||
"""Tests for ACP MCP server registration in session lifecycle."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_servers(self, agent, mock_manager):
|
||||
"""No-op when mcp_servers is None or empty."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, None)
|
||||
await agent._register_session_mcp_servers(state, [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_stdio_servers(self, agent, mock_manager):
|
||||
"""McpServerStdio servers are converted and passed to register_mcp_servers."""
|
||||
from acp.schema import McpServerStdio, EnvVariable
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Give the mock agent the attributes _register_session_mcp_servers reads
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerStdio(
|
||||
name="test-server",
|
||||
command="/usr/bin/test",
|
||||
args=["--flag"],
|
||||
env=[EnvVariable(name="KEY", value="val")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return ["mcp_test_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "test-server" in registered_config
|
||||
cfg = registered_config["test-server"]
|
||||
assert cfg["command"] == "/usr/bin/test"
|
||||
assert cfg["args"] == ["--flag"]
|
||||
assert cfg["env"] == {"KEY": "val"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_http_servers(self, agent, mock_manager):
|
||||
"""McpServerHttp servers are converted correctly."""
|
||||
from acp.schema import McpServerHttp, HttpHeader
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerHttp(
|
||||
name="http-server",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return []
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "http-server" in registered_config
|
||||
cfg = registered_config["http-server"]
|
||||
assert cfg["url"] == "https://api.example.com/mcp"
|
||||
assert cfg["headers"] == {"Authorization": "Bearer tok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refreshes_agent_tool_surface(self, agent, mock_manager):
|
||||
"""After MCP registration, agent.tools and valid_tool_names are refreshed."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
state.agent._cached_system_prompt = "old prompt"
|
||||
|
||||
server = McpServerStdio(
|
||||
name="srv",
|
||||
command="/bin/test",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_srv_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", return_value=["mcp_srv_search"]), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {"mcp_srv_search", "terminal"}
|
||||
# _invalidate_system_prompt should have been called
|
||||
state.agent._invalidate_system_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_failure_logs_warning(self, agent, mock_manager):
|
||||
"""If register_mcp_servers raises, warning is logged but no crash."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
server = McpServerStdio(
|
||||
name="bad",
|
||||
command="/nonexistent",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=RuntimeError("boom")):
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_calls_register(self, agent, mock_manager):
|
||||
"""new_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.new_session(cwd="/tmp", mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
# Second arg should be the mcp_servers list
|
||||
assert mock_reg.call_args[0][1] == ["fake"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_calls_register(self, agent, mock_manager):
|
||||
"""load_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
# Create a session first so load can find it
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_calls_register(self, agent, mock_manager):
|
||||
"""resume_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_calls_register(self, agent, mock_manager):
|
||||
"""fork_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.fork_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
133
tests/gateway/test_step_callback_compat.py
Normal file
133
tests/gateway/test_step_callback_compat.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for step_callback backward compatibility.
|
||||
|
||||
Verifies that the gateway's step_callback normalization keeps
|
||||
``tool_names`` as a list of strings for backward-compatible hooks,
|
||||
while also providing the enriched ``tools`` list with results.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStepCallbackNormalization:
|
||||
"""The gateway's _step_callback_sync normalizes prev_tools from run_agent."""
|
||||
|
||||
def _extract_step_callback(self):
|
||||
"""Build a minimal _step_callback_sync using the same logic as gateway/run.py.
|
||||
|
||||
We replicate the closure so we can test normalisation in isolation
|
||||
without spinning up the full gateway.
|
||||
"""
|
||||
captured_events = []
|
||||
|
||||
class FakeHooks:
|
||||
async def emit(self, event_type, data):
|
||||
captured_events.append((event_type, data))
|
||||
|
||||
hooks_ref = FakeHooks()
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
hooks_ref.emit("agent:step", {
|
||||
"iteration": iteration,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
}),
|
||||
loop,
|
||||
)
|
||||
|
||||
return _step_callback_sync, captured_events, loop
|
||||
|
||||
def test_dict_prev_tools_produce_string_tool_names(self):
|
||||
"""When prev_tools is list[dict], tool_names should be list[str]."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
# Simulate the enriched format from run_agent.py
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": '{"output": "hello"}'},
|
||||
{"name": "read_file", "result": '{"content": "..."}'},
|
||||
]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0)) # prime the loop
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
# tool_names must be strings for backward compat
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
assert all(isinstance(n, str) for n in data["tool_names"])
|
||||
# tools should be the enriched dicts
|
||||
assert data["tools"] == prev_tools
|
||||
|
||||
def test_string_prev_tools_still_work(self):
|
||||
"""When prev_tools is list[str] (legacy), tool_names should pass through."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
prev_tools = ["terminal", "read_file"]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(2, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
|
||||
def test_empty_prev_tools(self):
|
||||
"""Empty or None prev_tools should produce empty tool_names."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, []))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == []
|
||||
|
||||
def test_joinable_for_hook_example(self):
|
||||
"""The documented hook example: ', '.join(tool_names) should work."""
|
||||
# This is the exact pattern from the docs
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": "ok"},
|
||||
{"name": "web_search", "result": None},
|
||||
]
|
||||
|
||||
_names = []
|
||||
for _t in prev_tools:
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
|
||||
# This must not raise — documented hook pattern
|
||||
result = ", ".join(_names)
|
||||
assert result == "terminal, web_search"
|
||||
142
tests/gateway/test_whatsapp_group_gating.py
Normal file
142
tests/gateway/test_whatsapp_group_gating.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
|
||||
|
||||
def _make_adapter(require_mention=None, mention_patterns=None, free_response_chats=None):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
extra = {}
|
||||
if require_mention is not None:
|
||||
extra["require_mention"] = require_mention
|
||||
if mention_patterns is not None:
|
||||
extra["mention_patterns"] = mention_patterns
|
||||
if free_response_chats is not None:
|
||||
extra["free_response_chats"] = free_response_chats
|
||||
|
||||
adapter = object.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = PlatformConfig(enabled=True, extra=extra)
|
||||
adapter._message_handler = AsyncMock()
|
||||
adapter._mention_patterns = adapter._compile_mention_patterns()
|
||||
return adapter
|
||||
|
||||
|
||||
def _group_message(body="hello", **overrides):
|
||||
data = {
|
||||
"isGroup": True,
|
||||
"body": body,
|
||||
"chatId": "120363001234567890@g.us",
|
||||
"mentionedIds": [],
|
||||
"botIds": ["15551230000@s.whatsapp.net", "15551230000@lid"],
|
||||
"quotedParticipant": "",
|
||||
}
|
||||
data.update(overrides)
|
||||
return data
|
||||
|
||||
|
||||
def test_group_messages_can_be_opened_via_config():
|
||||
adapter = _make_adapter(require_mention=False)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_group_messages_can_require_direct_trigger_via_config():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"hi there",
|
||||
mentionedIds=["15551230000@s.whatsapp.net"],
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"replying",
|
||||
quotedParticipant="15551230000@lid",
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(_group_message("/status")) is True
|
||||
|
||||
|
||||
def test_regex_mention_patterns_allow_custom_wake_words():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message(" chompy help")) is True
|
||||
assert adapter._should_process_message(_group_message("hey chompy")) is False
|
||||
|
||||
|
||||
def test_invalid_regex_patterns_are_ignored():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"(", r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
|
||||
|
||||
def test_config_bridges_whatsapp_group_settings(monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"whatsapp:\n"
|
||||
" require_mention: true\n"
|
||||
" mention_patterns:\n"
|
||||
" - \"^\\\\s*chompy\\\\b\"\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("WHATSAPP_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("WHATSAPP_MENTION_PATTERNS", raising=False)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config is not None
|
||||
assert config.platforms[Platform.WHATSAPP].extra["require_mention"] is True
|
||||
assert config.platforms[Platform.WHATSAPP].extra["mention_patterns"] == [r"^\s*chompy\b"]
|
||||
assert __import__("os").environ["WHATSAPP_REQUIRE_MENTION"] == "true"
|
||||
assert json.loads(__import__("os").environ["WHATSAPP_MENTION_PATTERNS"]) == [r"^\s*chompy\b"]
|
||||
|
||||
|
||||
def test_free_response_chats_bypass_mention_gating():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
free_response_chats=["120363001234567890@g.us"],
|
||||
)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_free_response_chats_does_not_bypass_other_groups():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
free_response_chats=["999999999999@g.us"],
|
||||
)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
|
||||
|
||||
def test_dm_always_passes_even_with_require_mention():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
dm = {"isGroup": False, "body": "hello", "botIds": [], "mentionedIds": []}
|
||||
assert adapter._should_process_message(dm) is True
|
||||
|
||||
|
||||
def test_mention_stripping_removes_bot_phone_from_body():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
data = _group_message("@15551230000 what is the weather?")
|
||||
cleaned = adapter._clean_bot_mention_text(data["body"], data)
|
||||
assert "15551230000" not in cleaned
|
||||
assert "weather" in cleaned
|
||||
|
||||
|
||||
def test_mention_stripping_preserves_body_when_no_mention():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
data = _group_message("just a normal message")
|
||||
cleaned = adapter._clean_bot_mention_text(data["body"], data)
|
||||
assert cleaned == "just a normal message"
|
||||
@@ -466,6 +466,51 @@ class TestGeneratedUnitIncludesLocalBin:
|
||||
assert "/.local/bin" in unit
|
||||
|
||||
|
||||
class TestSystemServiceIdentityRootHandling:
|
||||
"""Root user handling in _system_service_identity()."""
|
||||
|
||||
def test_auto_detected_root_is_rejected(self, monkeypatch):
|
||||
"""When root is auto-detected (not explicitly requested), raise."""
|
||||
import pwd
|
||||
import grp
|
||||
|
||||
monkeypatch.delenv("SUDO_USER", raising=False)
|
||||
monkeypatch.setenv("USER", "root")
|
||||
monkeypatch.setenv("LOGNAME", "root")
|
||||
|
||||
import pytest
|
||||
with pytest.raises(ValueError, match="pass --run-as-user root to override"):
|
||||
gateway_cli._system_service_identity(run_as_user=None)
|
||||
|
||||
def test_explicit_root_is_allowed(self, monkeypatch):
|
||||
"""When root is explicitly passed via --run-as-user root, allow it."""
|
||||
import pwd
|
||||
import grp
|
||||
|
||||
root_info = pwd.getpwnam("root")
|
||||
root_group = grp.getgrgid(root_info.pw_gid).gr_name
|
||||
|
||||
username, group, home = gateway_cli._system_service_identity(run_as_user="root")
|
||||
assert username == "root"
|
||||
assert home == root_info.pw_dir
|
||||
|
||||
def test_non_root_user_passes_through(self, monkeypatch):
|
||||
"""Normal non-root user works as before."""
|
||||
import pwd
|
||||
import grp
|
||||
|
||||
monkeypatch.delenv("SUDO_USER", raising=False)
|
||||
monkeypatch.setenv("USER", "nobody")
|
||||
monkeypatch.setenv("LOGNAME", "nobody")
|
||||
|
||||
try:
|
||||
username, group, home = gateway_cli._system_service_identity(run_as_user=None)
|
||||
assert username == "nobody"
|
||||
except ValueError as e:
|
||||
# "nobody" might not exist on all systems
|
||||
assert "Unknown user" in str(e)
|
||||
|
||||
|
||||
class TestEnsureUserSystemdEnv:
|
||||
"""Tests for _ensure_user_systemd_env() D-Bus session bus auto-detection."""
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ def test_stash_local_changes_if_needed_returns_specific_stash_commit(monkeypatch
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[-2:] == ["status", "--porcelain"]:
|
||||
return SimpleNamespace(stdout=" M hermes_cli/main.py\n?? notes.txt\n", returncode=0)
|
||||
if cmd[-2:] == ["ls-files", "--unmerged"]:
|
||||
return SimpleNamespace(stdout="", returncode=0)
|
||||
if cmd[1:4] == ["stash", "push", "--include-untracked"]:
|
||||
return SimpleNamespace(stdout="Saved working directory\n", returncode=0)
|
||||
if cmd[-3:] == ["rev-parse", "--verify", "refs/stash"]:
|
||||
@@ -43,8 +45,9 @@ def test_stash_local_changes_if_needed_returns_specific_stash_commit(monkeypatch
|
||||
stash_ref = hermes_main._stash_local_changes_if_needed(["git"], tmp_path)
|
||||
|
||||
assert stash_ref == "abc123"
|
||||
assert calls[1][0][1:4] == ["stash", "push", "--include-untracked"]
|
||||
assert calls[2][0][-3:] == ["rev-parse", "--verify", "refs/stash"]
|
||||
assert calls[1][0][-2:] == ["ls-files", "--unmerged"]
|
||||
assert calls[2][0][1:4] == ["stash", "push", "--include-untracked"]
|
||||
assert calls[3][0][-3:] == ["rev-parse", "--verify", "refs/stash"]
|
||||
|
||||
|
||||
def test_resolve_stash_selector_returns_matching_entry(monkeypatch, tmp_path):
|
||||
@@ -296,6 +299,8 @@ def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch
|
||||
def fake_run(cmd, **kwargs):
|
||||
if cmd[-2:] == ["status", "--porcelain"]:
|
||||
return SimpleNamespace(stdout=" M hermes_cli/main.py\n", returncode=0)
|
||||
if cmd[-2:] == ["ls-files", "--unmerged"]:
|
||||
return SimpleNamespace(stdout="", returncode=0)
|
||||
if cmd[1:4] == ["stash", "push", "--include-untracked"]:
|
||||
return SimpleNamespace(stdout="Saved working directory\n", returncode=0)
|
||||
if cmd[-3:] == ["rev-parse", "--verify", "refs/stash"]:
|
||||
|
||||
@@ -307,21 +307,14 @@ class TestCmdUpdateLaunchdRestart:
|
||||
|
||||
# Mock get_running_pid to return a PID
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"):
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch.object(gateway_cli, "launchd_restart") as mock_launchd_restart:
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Gateway restarted via launchd" in captured
|
||||
assert "Restarting gateway service" in captured
|
||||
assert "Restart it with: hermes gateway run" not in captured
|
||||
# Verify launchctl stop + start were called (not manual SIGTERM)
|
||||
launchctl_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if len(c.args[0]) > 0 and c.args[0][0] == "launchctl"
|
||||
]
|
||||
stop_calls = [c for c in launchctl_calls if "stop" in c.args[0]]
|
||||
start_calls = [c for c in launchctl_calls if "start" in c.args[0]]
|
||||
assert len(stop_calls) >= 1
|
||||
assert len(start_calls) >= 1
|
||||
mock_launchd_restart.assert_called_once_with()
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
|
||||
@@ -191,6 +191,60 @@ class TestHistoryDisplay:
|
||||
assert "A" * 250 in output
|
||||
assert "A" * 250 + "..." not in output
|
||||
|
||||
def test_history_shows_recent_sessions_when_current_chat_is_empty(self, capsys):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current"
|
||||
cli._session_db = MagicMock()
|
||||
cli._session_db.list_sessions_rich.return_value = [
|
||||
{
|
||||
"id": "current",
|
||||
"title": "Current",
|
||||
"preview": "Current preview",
|
||||
"last_active": 0,
|
||||
},
|
||||
{
|
||||
"id": "20260401_201329_d85961",
|
||||
"title": "Checking Running Hermes Agent",
|
||||
"preview": "check running gateways for hermes agent",
|
||||
"last_active": 0,
|
||||
},
|
||||
]
|
||||
|
||||
cli.show_history()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "No messages in the current chat yet" in output
|
||||
assert "Checking Running Hermes Agent" in output
|
||||
assert "20260401_201329_d85961" in output
|
||||
assert "/resume" in output
|
||||
assert "Current preview" not in output
|
||||
|
||||
def test_resume_without_target_lists_recent_sessions(self, capsys):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current"
|
||||
cli._session_db = MagicMock()
|
||||
cli._session_db.list_sessions_rich.return_value = [
|
||||
{
|
||||
"id": "current",
|
||||
"title": "Current",
|
||||
"preview": "Current preview",
|
||||
"last_active": 0,
|
||||
},
|
||||
{
|
||||
"id": "20260401_201329_d85961",
|
||||
"title": "Checking Running Hermes Agent",
|
||||
"preview": "check running gateways for hermes agent",
|
||||
"last_active": 0,
|
||||
},
|
||||
]
|
||||
|
||||
cli._handle_resume_command("/resume")
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Recent sessions" in output
|
||||
assert "Checking Running Hermes Agent" in output
|
||||
assert "Use /resume <session id or title> to continue" in output
|
||||
|
||||
|
||||
class TestRootLevelProviderOverride:
|
||||
"""Root-level provider/base_url in config.yaml must NOT override model.provider."""
|
||||
|
||||
209
tests/test_long_context_tier_429.py
Normal file
209
tests/test_long_context_tier_429.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Tests for Anthropic Sonnet long-context tier 429 handling.
|
||||
|
||||
When Claude Max users without "extra usage" hit the 1M context tier
|
||||
on Sonnet, Anthropic returns HTTP 429 "Extra usage is required for long
|
||||
context requests." This is NOT a transient rate limit — the agent should
|
||||
reduce context_length to 200k and compress instead of retrying.
|
||||
|
||||
Only Sonnet is affected — Opus 1M is general access.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLongContextTierDetection:
|
||||
"""Verify the detection heuristic matches the Anthropic error."""
|
||||
|
||||
@staticmethod
|
||||
def _is_long_context_tier_error(status_code, error_msg, model="claude-sonnet-4.6"):
|
||||
error_msg = error_msg.lower()
|
||||
return (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in model.lower()
|
||||
)
|
||||
|
||||
def test_matches_anthropic_error(self):
|
||||
assert self._is_long_context_tier_error(
|
||||
429,
|
||||
"Extra usage is required for long context requests.",
|
||||
)
|
||||
|
||||
def test_matches_lowercase(self):
|
||||
assert self._is_long_context_tier_error(
|
||||
429,
|
||||
"extra usage is required for long context requests.",
|
||||
)
|
||||
|
||||
def test_matches_openrouter_model_id(self):
|
||||
assert self._is_long_context_tier_error(
|
||||
429,
|
||||
"Extra usage is required for long context requests.",
|
||||
model="anthropic/claude-sonnet-4.6",
|
||||
)
|
||||
|
||||
def test_matches_nous_model_id(self):
|
||||
assert self._is_long_context_tier_error(
|
||||
429,
|
||||
"Extra usage is required for long context requests.",
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
|
||||
def test_rejects_opus(self):
|
||||
"""Opus 1M is general access — should NOT trigger reduction."""
|
||||
assert not self._is_long_context_tier_error(
|
||||
429,
|
||||
"Extra usage is required for long context requests.",
|
||||
model="claude-opus-4.6",
|
||||
)
|
||||
|
||||
def test_rejects_opus_openrouter(self):
|
||||
assert not self._is_long_context_tier_error(
|
||||
429,
|
||||
"Extra usage is required for long context requests.",
|
||||
model="anthropic/claude-opus-4.6",
|
||||
)
|
||||
|
||||
def test_rejects_normal_429(self):
|
||||
assert not self._is_long_context_tier_error(
|
||||
429,
|
||||
"Rate limit exceeded. Please retry after 30 seconds.",
|
||||
)
|
||||
|
||||
def test_rejects_wrong_status(self):
|
||||
assert not self._is_long_context_tier_error(
|
||||
400,
|
||||
"Extra usage is required for long context requests.",
|
||||
)
|
||||
|
||||
def test_rejects_partial_match(self):
|
||||
"""Both 'extra usage' AND 'long context' must be present."""
|
||||
assert not self._is_long_context_tier_error(
|
||||
429, "extra usage required"
|
||||
)
|
||||
assert not self._is_long_context_tier_error(
|
||||
429, "long context requests not supported"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context reduction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextReduction:
|
||||
"""When the long-context tier error fires, context_length should
|
||||
drop to 200k and the reduced flag should be set correctly."""
|
||||
|
||||
def _make_compressor(self, context_length=1_000_000, threshold_percent=0.5):
|
||||
c = SimpleNamespace(
|
||||
context_length=context_length,
|
||||
threshold_percent=threshold_percent,
|
||||
threshold_tokens=int(context_length * threshold_percent),
|
||||
_context_probed=False,
|
||||
_context_probe_persistable=False,
|
||||
)
|
||||
return c
|
||||
|
||||
def test_reduces_1m_to_200k(self):
|
||||
comp = self._make_compressor(1_000_000)
|
||||
reduced_ctx = 200_000
|
||||
|
||||
if comp.context_length > reduced_ctx:
|
||||
comp.context_length = reduced_ctx
|
||||
comp.threshold_tokens = int(reduced_ctx * comp.threshold_percent)
|
||||
comp._context_probed = True
|
||||
comp._context_probe_persistable = False
|
||||
|
||||
assert comp.context_length == 200_000
|
||||
assert comp.threshold_tokens == 100_000
|
||||
assert comp._context_probed is True
|
||||
# Must NOT persist — subscription tier, not model capability
|
||||
assert comp._context_probe_persistable is False
|
||||
|
||||
def test_no_reduction_when_already_200k(self):
|
||||
comp = self._make_compressor(200_000)
|
||||
reduced_ctx = 200_000
|
||||
|
||||
original = comp.context_length
|
||||
if comp.context_length > reduced_ctx:
|
||||
comp.context_length = reduced_ctx
|
||||
|
||||
assert comp.context_length == original # unchanged
|
||||
|
||||
def test_no_reduction_when_below_200k(self):
|
||||
comp = self._make_compressor(128_000)
|
||||
reduced_ctx = 200_000
|
||||
|
||||
original = comp.context_length
|
||||
if comp.context_length > reduced_ctx:
|
||||
comp.context_length = reduced_ctx
|
||||
|
||||
assert comp.context_length == original # unchanged
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: agent error handler path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAgentErrorPath:
|
||||
"""Verify the long-context 429 doesn't hit the generic rate-limit
|
||||
or client-error handlers."""
|
||||
|
||||
def test_long_context_429_not_treated_as_rate_limit(self):
|
||||
"""The error should be intercepted before the generic
|
||||
is_rate_limited check fires a fallback switch."""
|
||||
error_msg = "extra usage is required for long context requests."
|
||||
status_code = 429
|
||||
model = "claude-sonnet-4.6"
|
||||
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in model.lower()
|
||||
)
|
||||
assert _is_long_context_tier_error
|
||||
|
||||
def test_opus_429_falls_through_to_rate_limit(self):
|
||||
"""Opus should NOT match — falls through to generic rate-limit."""
|
||||
error_msg = "extra usage is required for long context requests."
|
||||
status_code = 429
|
||||
model = "claude-opus-4.6"
|
||||
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in model.lower()
|
||||
)
|
||||
assert not _is_long_context_tier_error
|
||||
|
||||
def test_normal_429_still_treated_as_rate_limit(self):
|
||||
"""A normal 429 should NOT match the long-context check."""
|
||||
error_msg = "rate limit exceeded"
|
||||
status_code = 429
|
||||
model = "claude-sonnet-4.6"
|
||||
|
||||
_is_long_context_tier_error = (
|
||||
status_code == 429
|
||||
and "extra usage" in error_msg
|
||||
and "long context" in error_msg
|
||||
and "sonnet" in model.lower()
|
||||
)
|
||||
assert not _is_long_context_tier_error
|
||||
|
||||
is_rate_limited = (
|
||||
status_code == 429
|
||||
or "rate limit" in error_msg
|
||||
)
|
||||
assert is_rate_limited
|
||||
@@ -9,10 +9,13 @@ import pytest
|
||||
|
||||
from tools.mcp_oauth import (
|
||||
HermesTokenStorage,
|
||||
OAuthNonInteractiveError,
|
||||
build_oauth_auth,
|
||||
remove_oauth_tokens,
|
||||
_find_free_port,
|
||||
_can_open_browser,
|
||||
_is_interactive,
|
||||
_wait_for_callback,
|
||||
)
|
||||
|
||||
|
||||
@@ -236,3 +239,99 @@ class TestRemoveOAuthTokens:
|
||||
def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
remove_oauth_tokens("nonexistent") # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-interactive / startup-safety tests (issue #4462)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsInteractive:
|
||||
"""_is_interactive() detects headless/daemon/container environments."""
|
||||
|
||||
def test_false_when_stdin_not_tty(self, monkeypatch):
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is False
|
||||
|
||||
def test_true_when_stdin_is_tty(self, monkeypatch):
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is True
|
||||
|
||||
def test_false_when_stdin_has_no_isatty(self, monkeypatch):
|
||||
"""Some environments replace stdin with an object without isatty()."""
|
||||
mock_stdin = object() # no isatty attribute
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
assert _is_interactive() is False
|
||||
|
||||
|
||||
class TestWaitForCallbackNoBlocking:
|
||||
"""_wait_for_callback() must never call input() — it raises instead."""
|
||||
|
||||
def test_raises_on_timeout_instead_of_input(self):
|
||||
"""When no auth code arrives, raises OAuthNonInteractiveError."""
|
||||
import tools.mcp_oauth as mod
|
||||
import asyncio
|
||||
|
||||
mod._oauth_port = _find_free_port()
|
||||
|
||||
async def instant_sleep(_seconds):
|
||||
pass
|
||||
|
||||
with patch.object(mod.asyncio, "sleep", instant_sleep):
|
||||
with patch("builtins.input", side_effect=AssertionError("input() must not be called")):
|
||||
with pytest.raises(OAuthNonInteractiveError, match="callback timed out"):
|
||||
asyncio.run(_wait_for_callback())
|
||||
|
||||
|
||||
class TestBuildOAuthAuthNonInteractive:
|
||||
"""build_oauth_auth() in non-interactive mode."""
|
||||
|
||||
def test_noninteractive_without_cached_tokens_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Without cached tokens, non-interactive mode logs a clear warning."""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
|
||||
import logging
|
||||
with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"):
|
||||
auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp")
|
||||
|
||||
assert auth is not None
|
||||
assert "no cached tokens found" in caplog.text.lower()
|
||||
assert "non-interactive" in caplog.text.lower()
|
||||
|
||||
def test_noninteractive_with_cached_tokens_no_warning(self, tmp_path, monkeypatch, caplog):
|
||||
"""With cached tokens, non-interactive mode logs no 'no cached tokens' warning."""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = False
|
||||
monkeypatch.setattr("tools.mcp_oauth.sys.stdin", mock_stdin)
|
||||
|
||||
# Pre-populate cached tokens
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "atlassian.json").write_text(json.dumps({
|
||||
"access_token": "cached",
|
||||
"token_type": "Bearer",
|
||||
}))
|
||||
|
||||
import logging
|
||||
with caplog.at_level(logging.WARNING, logger="tools.mcp_oauth"):
|
||||
auth = build_oauth_auth("atlassian", "https://mcp.atlassian.com/v1/mcp")
|
||||
|
||||
assert auth is not None
|
||||
assert "no cached tokens found" not in caplog.text.lower()
|
||||
|
||||
143
tests/tools/test_mcp_stability.py
Normal file
143
tests/tools/test_mcp_stability.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for MCP stability fixes — event loop handler, PID tracking, shutdown robustness."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 1: MCP event loop exception handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPLoopExceptionHandler:
|
||||
"""_mcp_loop_exception_handler suppresses benign 'Event loop is closed'."""
|
||||
|
||||
def test_suppresses_event_loop_closed(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": RuntimeError("Event loop is closed")}
|
||||
# Should NOT call default handler
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_not_called()
|
||||
|
||||
def test_forwards_other_runtime_errors(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": RuntimeError("some other error")}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_forwards_non_runtime_errors(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"exception": ValueError("bad value")}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_forwards_contexts_without_exception(self):
|
||||
from tools.mcp_tool import _mcp_loop_exception_handler
|
||||
loop = MagicMock()
|
||||
context = {"message": "just a message"}
|
||||
_mcp_loop_exception_handler(loop, context)
|
||||
loop.default_exception_handler.assert_called_once_with(context)
|
||||
|
||||
def test_handler_installed_on_mcp_loop(self):
|
||||
"""_ensure_mcp_loop installs the exception handler on the new loop."""
|
||||
import tools.mcp_tool as mcp_mod
|
||||
try:
|
||||
mcp_mod._ensure_mcp_loop()
|
||||
with mcp_mod._lock:
|
||||
loop = mcp_mod._mcp_loop
|
||||
assert loop is not None
|
||||
assert loop.get_exception_handler() is mcp_mod._mcp_loop_exception_handler
|
||||
finally:
|
||||
mcp_mod._stop_mcp_loop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 2: stdio PID tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStdioPidTracking:
|
||||
"""_snapshot_child_pids and _stdio_pids track subprocess PIDs."""
|
||||
|
||||
def test_snapshot_returns_set(self):
|
||||
from tools.mcp_tool import _snapshot_child_pids
|
||||
result = _snapshot_child_pids()
|
||||
assert isinstance(result, set)
|
||||
# All elements should be ints
|
||||
for pid in result:
|
||||
assert isinstance(pid, int)
|
||||
|
||||
def test_stdio_pids_starts_empty(self):
|
||||
from tools.mcp_tool import _stdio_pids, _lock
|
||||
with _lock:
|
||||
# Might have residual state from other tests, just check type
|
||||
assert isinstance(_stdio_pids, set)
|
||||
|
||||
def test_kill_orphaned_noop_when_empty(self):
|
||||
"""_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
|
||||
# Should not raise
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
def test_kill_orphaned_handles_dead_pids(self):
|
||||
"""_kill_orphaned_mcp_children gracefully handles already-dead PIDs."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
# Use a PID that definitely doesn't exist
|
||||
fake_pid = 999999999
|
||||
with _lock:
|
||||
_stdio_pids.add(fake_pid)
|
||||
|
||||
# Should not raise (ProcessLookupError is caught)
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 3: MCP reload timeout (cli.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPReloadTimeout:
|
||||
"""_check_config_mcp_changes uses a timeout on _reload_mcp."""
|
||||
|
||||
def test_reload_timeout_does_not_block_forever(self, tmp_path, monkeypatch):
|
||||
"""If _reload_mcp hangs, the config watcher times out and returns."""
|
||||
import time
|
||||
|
||||
# Create a mock HermesCLI-like object with the needed attributes
|
||||
class FakeCLI:
|
||||
_config_mtime = 0.0
|
||||
_config_mcp_servers = {}
|
||||
_last_config_check = 0.0
|
||||
_command_running = False
|
||||
config = {}
|
||||
agent = None
|
||||
|
||||
def _reload_mcp(self):
|
||||
# Simulate a hang — sleep longer than the timeout
|
||||
time.sleep(60)
|
||||
|
||||
def _slow_command_status(self, cmd):
|
||||
return cmd
|
||||
|
||||
# This test verifies the timeout mechanism exists in the code
|
||||
# by checking that _check_config_mcp_changes doesn't call
|
||||
# _reload_mcp directly (it uses a thread now)
|
||||
import inspect
|
||||
from cli import HermesCLI
|
||||
source = inspect.getsource(HermesCLI._check_config_mcp_changes)
|
||||
# The fix adds threading.Thread for _reload_mcp
|
||||
assert "Thread" in source or "thread" in source.lower(), \
|
||||
"_check_config_mcp_changes should use a thread for _reload_mcp"
|
||||
@@ -2900,3 +2900,164 @@ class TestMCPBuiltinCollisionGuard:
|
||||
assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv"
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sanitize_mcp_name_component
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeMcpNameComponent:
|
||||
"""Verify sanitize_mcp_name_component handles all edge cases."""
|
||||
|
||||
def test_hyphens_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my-server") == "my_server"
|
||||
|
||||
def test_dots_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa") == "ai_exa"
|
||||
|
||||
def test_slashes_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa"
|
||||
|
||||
def test_mixed_special_characters(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2"
|
||||
|
||||
def test_alphanumeric_and_underscores_preserved(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my_server_123") == "my_server_123"
|
||||
|
||||
def test_empty_string(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("") == ""
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component(None) == ""
|
||||
|
||||
def test_slash_in_convert_mcp_schema(self):
|
||||
"""Server names with slashes produce valid tool names via _convert_mcp_schema."""
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(name="search")
|
||||
schema = _convert_mcp_schema("ai.exa/exa", mcp_tool)
|
||||
assert schema["name"] == "mcp_ai_exa_exa_search"
|
||||
# Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$
|
||||
import re
|
||||
assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"])
|
||||
|
||||
def test_slash_in_build_utility_schemas(self):
|
||||
"""Server names with slashes produce valid utility tool names."""
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("ai.exa/exa")
|
||||
for s in schemas:
|
||||
name = s["schema"]["name"]
|
||||
assert "/" not in name
|
||||
assert "." not in name
|
||||
|
||||
def test_slash_in_sync_mcp_toolsets(self):
|
||||
"""_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema."""
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
|
||||
# Verify the prefix generation matches what _convert_mcp_schema produces
|
||||
server_name = "ai.exa/exa"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
assert safe_prefix == "mcp_ai_exa_exa_"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_mcp_servers public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterMcpServers:
|
||||
"""Verify the new register_mcp_servers() public API."""
|
||||
|
||||
def test_empty_servers_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True):
|
||||
result = register_mcp_servers({})
|
||||
assert result == []
|
||||
|
||||
def test_mcp_not_available_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
|
||||
result = register_mcp_servers({"srv": {"command": "test"}})
|
||||
assert result == []
|
||||
|
||||
def test_skips_already_connected_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
mock_server = _make_mock_server("existing")
|
||||
_servers["existing"] = mock_server
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]):
|
||||
result = register_mcp_servers({"existing": {"command": "test"}})
|
||||
assert result == ["mcp_existing_tool"]
|
||||
finally:
|
||||
_servers.pop("existing", None)
|
||||
|
||||
def test_skips_disabled_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
result = register_mcp_servers({"srv": {"command": "test", "enabled": False}})
|
||||
assert result == []
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_connects_new_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"my_server": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_my_server_tool1"]
|
||||
_servers[name] = server
|
||||
return ["mcp_my_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]):
|
||||
_ensure_mcp_loop()
|
||||
result = register_mcp_servers(fake_config)
|
||||
|
||||
assert "mcp_my_server_tool1" in result
|
||||
_servers.pop("my_server", None)
|
||||
|
||||
def test_logs_summary_on_success(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"srv": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
_servers[name] = server
|
||||
return ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]):
|
||||
_ensure_mcp_loop()
|
||||
|
||||
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||
register_mcp_servers(fake_config)
|
||||
|
||||
info_calls = [str(c) for c in mock_logger.info.call_args_list]
|
||||
assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), (
|
||||
f"Summary should report 2 tools from 1 server, got: {info_calls}"
|
||||
)
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
@@ -5,6 +5,12 @@ Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
||||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Startup safety:
|
||||
The callback handler never calls blocking ``input()`` on the event loop.
|
||||
In non-interactive environments (no TTY, SSH, headless), the OAuth flow
|
||||
raises ``OAuthNonInteractiveError`` instead of blocking, so that the
|
||||
server degrades gracefully and other MCP servers are not affected.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
@@ -19,6 +25,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
@@ -28,6 +35,11 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthNonInteractiveError(RuntimeError):
|
||||
"""Raised when OAuth requires user interaction but the environment is non-interactive."""
|
||||
pass
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
|
||||
|
||||
@@ -164,7 +176,13 @@ async def _redirect_to_browser(auth_url: str) -> None:
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect.
|
||||
|
||||
If the callback times out, raises ``OAuthNonInteractiveError`` instead of
|
||||
calling blocking ``input()`` — the old ``input()`` call would block the
|
||||
entire MCP asyncio event loop, preventing all other MCP servers from
|
||||
connecting and potentially hanging Hermes startup indefinitely.
|
||||
"""
|
||||
global _oauth_port
|
||||
port = _oauth_port or _find_free_port()
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
@@ -186,8 +204,10 @@ async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
code = result["auth_code"] or ""
|
||||
state = result["state"]
|
||||
if not code:
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
raise OAuthNonInteractiveError(
|
||||
"OAuth browser callback timed out after 120 seconds. "
|
||||
"Run 'hermes mcp auth <server-name>' to authorize interactively."
|
||||
)
|
||||
return code, state
|
||||
|
||||
|
||||
@@ -199,6 +219,17 @@ def _can_open_browser() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _is_interactive() -> bool:
|
||||
"""Check if the current environment can support interactive OAuth flows.
|
||||
|
||||
Returns False in headless/daemon/container environments where no user
|
||||
can interact with a browser or paste an auth code.
|
||||
"""
|
||||
if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -209,6 +240,11 @@ def build_oauth_auth(server_name: str, server_url: str):
|
||||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
In non-interactive environments (no TTY), this still returns a provider
|
||||
so that **cached tokens and refresh flows work**. Only the interactive
|
||||
authorization-code grant will fail fast with a clear error instead of
|
||||
blocking the event loop.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
@@ -219,6 +255,25 @@ def build_oauth_auth(server_name: str, server_url: str):
|
||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
interactive = _is_interactive()
|
||||
|
||||
if not interactive:
|
||||
# Check whether cached tokens exist. If they do, the SDK can still
|
||||
# use them (and refresh them) without any user interaction. If not,
|
||||
# we still build the provider — the callback_handler will raise
|
||||
# OAuthNonInteractiveError if a fresh authorization is actually
|
||||
# needed, which surfaces as a clean connection failure for this
|
||||
# server only (other MCP servers are unaffected).
|
||||
has_cached = storage._read_json(storage._tokens_path()) is not None
|
||||
if not has_cached:
|
||||
logger.warning(
|
||||
"MCP server '%s' requires OAuth but no cached tokens found "
|
||||
"and environment is non-interactive. The server will fail to "
|
||||
"connect. Run 'hermes mcp auth %s' to authorize interactively.",
|
||||
server_name, server_name,
|
||||
)
|
||||
|
||||
global _oauth_port
|
||||
_oauth_port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||
@@ -232,14 +287,36 @@ def build_oauth_auth(server_name: str, server_url: str):
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
# In non-interactive mode, the redirect handler logs the URL and the
|
||||
# callback handler raises immediately — no blocking, no input().
|
||||
redirect_handler = _redirect_to_browser
|
||||
callback_handler = _wait_for_callback
|
||||
|
||||
if not interactive:
|
||||
async def _noninteractive_redirect(auth_url: str) -> None:
|
||||
logger.warning(
|
||||
"MCP server '%s' needs OAuth authorization (non-interactive, "
|
||||
"cannot open browser). URL: %s",
|
||||
server_name, auth_url,
|
||||
)
|
||||
|
||||
async def _noninteractive_callback() -> tuple[str, str | None]:
|
||||
raise OAuthNonInteractiveError(
|
||||
f"MCP server '{server_name}' requires interactive OAuth "
|
||||
f"authorization but the environment is non-interactive "
|
||||
f"(no TTY). Run 'hermes mcp auth {server_name}' to "
|
||||
f"authorize, then restart."
|
||||
)
|
||||
|
||||
redirect_handler = _noninteractive_redirect
|
||||
callback_handler = _noninteractive_callback
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_to_browser,
|
||||
callback_handler=_wait_for_callback,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
@@ -842,13 +842,25 @@ class MCPServerTask:
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
sampling_kwargs["message_handler"] = self._make_message_handler()
|
||||
|
||||
# Snapshot child PIDs before spawning so we can track the new one.
|
||||
pids_before = _snapshot_child_pids()
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
# Capture the newly spawned subprocess PID for force-kill cleanup.
|
||||
new_pids = _snapshot_child_pids() - pids_before
|
||||
if new_pids:
|
||||
with _lock:
|
||||
_stdio_pids.update(new_pids)
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
self._ready.set()
|
||||
await self._shutdown_event.wait()
|
||||
# Context exited cleanly — subprocess was terminated by the SDK.
|
||||
if new_pids:
|
||||
with _lock:
|
||||
_stdio_pids.difference_update(new_pids)
|
||||
|
||||
async def _run_http(self, config: dict):
|
||||
"""Run the server using HTTP/StreamableHTTP transport."""
|
||||
@@ -863,7 +875,10 @@ class MCPServerTask:
|
||||
headers = dict(config.get("headers") or {})
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK
|
||||
# OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK.
|
||||
# If OAuth setup fails (e.g. non-interactive environment without
|
||||
# cached tokens), re-raise so this server is reported as failed
|
||||
# without blocking other MCP servers from connecting.
|
||||
_oauth_auth = None
|
||||
if self._auth_type == "oauth":
|
||||
try:
|
||||
@@ -871,6 +886,7 @@ class MCPServerTask:
|
||||
_oauth_auth = build_oauth_auth(self.name, url)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
raise
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED:
|
||||
@@ -1044,9 +1060,56 @@ _servers: Dict[str, MCPServerTask] = {}
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
# Protects _mcp_loop, _mcp_thread, _servers, and _stdio_pids.
|
||||
_lock = threading.Lock()
|
||||
|
||||
# PIDs of stdio MCP server subprocesses. Tracked so we can force-kill
|
||||
# them on shutdown if the graceful cleanup (SDK context-manager teardown)
|
||||
# fails or times out. PIDs are added after connection and removed on
|
||||
# normal server shutdown.
|
||||
_stdio_pids: set = set()
|
||||
|
||||
|
||||
def _snapshot_child_pids() -> set:
|
||||
"""Return a set of current child process PIDs.
|
||||
|
||||
Uses /proc on Linux, falls back to psutil, then empty set.
|
||||
Used by _run_stdio to identify the subprocess spawned by stdio_client.
|
||||
"""
|
||||
my_pid = os.getpid()
|
||||
|
||||
# Linux: read from /proc
|
||||
try:
|
||||
children_path = f"/proc/{my_pid}/task/{my_pid}/children"
|
||||
with open(children_path) as f:
|
||||
return {int(p) for p in f.read().split() if p.strip()}
|
||||
except (FileNotFoundError, OSError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: psutil
|
||||
try:
|
||||
import psutil
|
||||
return {c.pid for c in psutil.Process(my_pid).children()}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
def _mcp_loop_exception_handler(loop, context):
|
||||
"""Suppress benign 'Event loop is closed' noise during shutdown.
|
||||
|
||||
When the MCP event loop is stopped and closed, httpx/httpcore async
|
||||
transports may fire __del__ finalizers that call call_soon() on the
|
||||
dead loop. asyncio catches that RuntimeError and routes it here.
|
||||
We silence it because the connection is being torn down anyway; all
|
||||
other exceptions are forwarded to the default handler.
|
||||
"""
|
||||
exc = context.get("exception")
|
||||
if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc):
|
||||
return # benign shutdown race — suppress
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
|
||||
def _ensure_mcp_loop():
|
||||
"""Start the background event loop thread if not already running."""
|
||||
@@ -1055,6 +1118,7 @@ def _ensure_mcp_loop():
|
||||
if _mcp_loop is not None and _mcp_loop.is_running():
|
||||
return
|
||||
_mcp_loop = asyncio.new_event_loop()
|
||||
_mcp_loop.set_exception_handler(_mcp_loop_exception_handler)
|
||||
_mcp_thread = threading.Thread(
|
||||
target=_mcp_loop.run_forever,
|
||||
name="mcp-event-loop",
|
||||
@@ -1406,6 +1470,17 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
return schema
|
||||
|
||||
|
||||
def sanitize_mcp_name_component(value: str) -> str:
|
||||
"""Return an MCP name component safe for tool and prefix generation.
|
||||
|
||||
Preserves Hermes's historical behavior of converting hyphens to
|
||||
underscores, and also replaces any other character outside
|
||||
``[A-Za-z0-9_]`` with ``_`` so generated tool names are compatible with
|
||||
provider validation rules.
|
||||
"""
|
||||
return re.sub(r"[^A-Za-z0-9_]", "_", str(value or ""))
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1417,9 +1492,8 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
Returns:
|
||||
A dict suitable for ``registry.register(schema=...)``.
|
||||
"""
|
||||
# Sanitize: replace hyphens and dots with underscores for LLM API compatibility
|
||||
safe_tool_name = mcp_tool.name.replace("-", "_").replace(".", "_")
|
||||
safe_server_name = server_name.replace("-", "_").replace(".", "_")
|
||||
safe_tool_name = sanitize_mcp_name_component(mcp_tool.name)
|
||||
safe_server_name = sanitize_mcp_name_component(server_name)
|
||||
prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}"
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
@@ -1449,7 +1523,7 @@ def _sync_mcp_toolsets(server_names: Optional[List[str]] = None) -> None:
|
||||
all_mcp_tools: List[str] = []
|
||||
|
||||
for server_name in server_names:
|
||||
safe_prefix = f"mcp_{server_name.replace('-', '_').replace('.', '_')}_"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
server_tools = sorted(
|
||||
t for t in existing if t.startswith(safe_prefix)
|
||||
)
|
||||
@@ -1485,7 +1559,7 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
with keys: schema, handler_key.
|
||||
"""
|
||||
safe_name = server_name.replace("-", "_").replace(".", "_")
|
||||
safe_name = sanitize_mcp_name_component(server_name)
|
||||
return [
|
||||
{
|
||||
"schema": {
|
||||
@@ -1772,6 +1846,86 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
"""Connect to explicit MCP servers and register their tools.
|
||||
|
||||
Idempotent for already-connected server names. Servers with
|
||||
``enabled: false`` are skipped without disconnecting existing sessions.
|
||||
|
||||
Args:
|
||||
servers: Mapping of ``{server_name: server_config}``.
|
||||
|
||||
Returns:
|
||||
List of all currently registered MCP tool names.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
logger.debug("MCP SDK not available -- skipping explicit MCP registration")
|
||||
return []
|
||||
|
||||
if not servers:
|
||||
logger.debug("No explicit MCP servers provided")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Log a summary so ACP callers get visibility into what was registered.
|
||||
with _lock:
|
||||
connected = [n for n in new_servers if n in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[n], "_registered_tool_names", []))
|
||||
for n in connected
|
||||
)
|
||||
failed = len(new_servers) - len(connected)
|
||||
if new_tool_count or failed:
|
||||
summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)"
|
||||
if failed:
|
||||
summary += f" ({failed} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
@@ -1793,69 +1947,32 @@ def discover_mcp_tools() -> List[str]:
|
||||
logger.debug("No MCP servers configured")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
new_server_names = [
|
||||
name
|
||||
for name, cfg in servers.items()
|
||||
if name not in _servers and _parse_boolish(cfg.get("enabled", True), default=True)
|
||||
]
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
tool_names = register_mcp_servers(servers)
|
||||
if not new_server_names:
|
||||
return tool_names
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
nonlocal failed_count
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
with _lock:
|
||||
connected_server_names = [name for name in new_server_names if name in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[name], "_registered_tool_names", []))
|
||||
for name in connected_server_names
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
elif isinstance(result, list):
|
||||
all_tools.extend(result)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Print summary
|
||||
total_servers = len(new_servers)
|
||||
ok_servers = total_servers - failed_count
|
||||
if all_tools or failed_count:
|
||||
summary = f" MCP: {len(all_tools)} tool(s) from {ok_servers} server(s)"
|
||||
failed_count = len(new_server_names) - len(connected_server_names)
|
||||
if new_tool_count or failed_count:
|
||||
summary = f" MCP: {new_tool_count} tool(s) from {len(connected_server_names)} server(s)"
|
||||
if failed_count:
|
||||
summary += f" ({failed_count} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
# Return ALL registered tools (existing + newly discovered)
|
||||
return _existing_tool_names()
|
||||
return tool_names
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
@@ -2004,6 +2121,29 @@ def shutdown_mcp_servers():
|
||||
_stop_mcp_loop()
|
||||
|
||||
|
||||
def _kill_orphaned_mcp_children() -> None:
|
||||
"""Best-effort kill of MCP stdio subprocesses that survived loop shutdown.
|
||||
|
||||
After the MCP event loop is stopped, stdio server subprocesses *should*
|
||||
have been terminated by the SDK's context-manager cleanup. If the loop
|
||||
was stuck or the shutdown timed out, orphaned children may remain.
|
||||
|
||||
Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
|
||||
"""
|
||||
import signal as _signal
|
||||
|
||||
with _lock:
|
||||
pids = list(_stdio_pids)
|
||||
_stdio_pids.clear()
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGKILL)
|
||||
logger.debug("Force-killed orphaned MCP stdio process %d", pid)
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass # Already exited or inaccessible
|
||||
|
||||
|
||||
def _stop_mcp_loop():
|
||||
"""Stop the background event loop and join its thread."""
|
||||
global _mcp_loop, _mcp_thread
|
||||
@@ -2016,4 +2156,10 @@ def _stop_mcp_loop():
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
if thread is not None:
|
||||
thread.join(timeout=5)
|
||||
loop.close()
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
# After closing the loop, any stdio subprocesses that survived the
|
||||
# graceful shutdown are now orphaned. Force-kill them.
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
@@ -127,8 +127,12 @@ def is_stt_enabled(stt_config: Optional[dict] = None) -> bool:
|
||||
|
||||
|
||||
def _has_openai_audio_backend() -> bool:
|
||||
"""Return True when OpenAI audio can use direct credentials or the managed gateway."""
|
||||
return bool(resolve_openai_audio_api_key() or resolve_managed_tool_gateway("openai-audio"))
|
||||
"""Return True when OpenAI audio can use config credentials, env credentials, or the managed gateway."""
|
||||
try:
|
||||
_resolve_openai_audio_client_config()
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _find_binary(binary_name: str) -> Optional[str]:
|
||||
@@ -577,13 +581,20 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
||||
|
||||
def _resolve_openai_audio_client_config() -> tuple[str, str]:
|
||||
"""Return direct OpenAI audio config or a managed gateway fallback."""
|
||||
stt_config = _load_stt_config()
|
||||
openai_cfg = stt_config.get("openai", {})
|
||||
cfg_api_key = openai_cfg.get("api_key", "")
|
||||
cfg_base_url = openai_cfg.get("base_url", "")
|
||||
if cfg_api_key:
|
||||
return cfg_api_key, (cfg_base_url or OPENAI_BASE_URL)
|
||||
|
||||
direct_api_key = resolve_openai_audio_api_key()
|
||||
if direct_api_key:
|
||||
return direct_api_key, OPENAI_BASE_URL
|
||||
|
||||
managed_gateway = resolve_managed_tool_gateway("openai-audio")
|
||||
if managed_gateway is None:
|
||||
message = "Neither VOICE_TOOLS_OPENAI_KEY nor OPENAI_API_KEY is set"
|
||||
message = "Neither stt.openai.api_key in config nor VOICE_TOOLS_OPENAI_KEY/OPENAI_API_KEY is set"
|
||||
if managed_nous_tools_enabled():
|
||||
message += ", and the managed OpenAI audio gateway is unavailable"
|
||||
raise ValueError(message)
|
||||
|
||||
@@ -527,6 +527,187 @@ There is no hard limit. Each profile is just a directory under `~/.hermes/profil
|
||||
|
||||
---
|
||||
|
||||
## Workflows & Patterns
|
||||
|
||||
### Using different models for different tasks (multi-model workflows)
|
||||
|
||||
**Scenario:** You use GPT-5.4 as your daily driver, but Gemini or Grok writes better social media content. Manually switching models every time is tedious.
|
||||
|
||||
**Solution: Delegation config.** Hermes can route subagents to a different model automatically. Set this in `~/.hermes/config.yaml`:
|
||||
|
||||
```yaml
|
||||
delegation:
|
||||
model: "google/gemini-3-flash-preview" # subagents use this model
|
||||
provider: "openrouter" # provider for subagents
|
||||
```
|
||||
|
||||
Now when you tell Hermes "write me a Twitter thread about X" and it spawns a `delegate_task` subagent, that subagent runs on Gemini instead of your main model. Your primary conversation stays on GPT-5.4.
|
||||
|
||||
You can also be explicit in your prompt: *"Delegate a task to write social media posts about our product launch. Use your subagent for the actual writing."* The agent will use `delegate_task`, which automatically picks up the delegation config.
|
||||
|
||||
For one-off model switches without delegation, use `/model` in the CLI:
|
||||
|
||||
```bash
|
||||
/model google/gemini-3-flash-preview # switch for this session
|
||||
# ... write your content ...
|
||||
/model openai/gpt-5.4 # switch back
|
||||
```
|
||||
|
||||
See [Subagent Delegation](../user-guide/features/delegation.md) for more on how delegation works.
|
||||
|
||||
### Running multiple agents on one WhatsApp number (per-chat binding)
|
||||
|
||||
**Scenario:** In OpenClaw, you had multiple independent agents bound to specific WhatsApp chats — one for a family shopping list group, another for your private chat. Can Hermes do this?
|
||||
|
||||
**Current limitation:** Hermes profiles each require their own WhatsApp number/session. You cannot bind multiple profiles to different chats on the same WhatsApp number — the WhatsApp bridge (Baileys) uses one authenticated session per number.
|
||||
|
||||
**Workarounds:**
|
||||
|
||||
1. **Use a single profile with personality switching.** Create different `AGENTS.md` context files or use the `/personality` command to change behavior per chat. The agent sees which chat it's in and can adapt.
|
||||
|
||||
2. **Use cron jobs for specialized tasks.** For a shopping list tracker, set up a cron job that monitors a specific chat and manages the list — no separate agent needed.
|
||||
|
||||
3. **Use separate numbers.** If you need truly independent agents, pair each profile with its own WhatsApp number. Virtual numbers from services like Google Voice work for this.
|
||||
|
||||
4. **Use Telegram or Discord instead.** These platforms support per-chat binding more naturally — each Telegram group or Discord channel gets its own session, and you can run multiple bot tokens (one per profile) on the same account.
|
||||
|
||||
See [Profiles](../user-guide/profiles.md) and [WhatsApp setup](../user-guide/messaging/whatsapp.md) for more details.
|
||||
|
||||
### Controlling what shows up in Telegram (hiding logs and reasoning)
|
||||
|
||||
**Scenario:** You see gateway exec logs, Hermes reasoning, and tool call details in Telegram instead of just the final output.
|
||||
|
||||
**Solution:** The `display.tool_progress` setting in `config.yaml` controls how much tool activity is shown:
|
||||
|
||||
```yaml
|
||||
display:
|
||||
tool_progress: "off" # options: off, new, all, verbose
|
||||
```
|
||||
|
||||
- **`off`** — Only the final response. No tool calls, no reasoning, no logs.
|
||||
- **`new`** — Shows new tool calls as they happen (brief one-liners).
|
||||
- **`all`** — Shows all tool activity including results.
|
||||
- **`verbose`** — Full detail including tool arguments and outputs.
|
||||
|
||||
For messaging platforms, `off` or `new` is usually what you want. After editing `config.yaml`, restart the gateway for changes to take effect.
|
||||
|
||||
You can also toggle this per-session with the `/verbose` command (if enabled):
|
||||
|
||||
```yaml
|
||||
display:
|
||||
tool_progress_command: true # enables /verbose in the gateway
|
||||
```
|
||||
|
||||
### Managing skills on Telegram (slash command limit)
|
||||
|
||||
**Scenario:** Telegram has a 100 slash command limit, and your skills are pushing past it. You want to disable skills you don't need on Telegram, but `hermes skills config` settings don't seem to take effect.
|
||||
|
||||
**Solution:** Use `hermes skills config` to disable skills per-platform. This writes to `config.yaml`:
|
||||
|
||||
```yaml
|
||||
skills:
|
||||
disabled: [] # globally disabled skills
|
||||
platform_disabled:
|
||||
telegram: [skill-a, skill-b] # disabled only on telegram
|
||||
```
|
||||
|
||||
After changing this, **restart the gateway** (`hermes gateway restart` or kill and relaunch). The Telegram bot command menu rebuilds on startup.
|
||||
|
||||
:::tip
|
||||
Skills with very long descriptions are truncated to 40 characters in the Telegram menu to stay within payload size limits. If skills aren't appearing, it may be a total payload size issue rather than the 100 command count limit — disabling unused skills helps with both.
|
||||
:::
|
||||
|
||||
### Shared thread sessions (multiple users, one conversation)
|
||||
|
||||
**Scenario:** You have a Telegram or Discord thread where multiple people mention the bot. You want all mentions in that thread to be part of one shared conversation, not separate per-user sessions.
|
||||
|
||||
**Current behavior:** Hermes creates sessions keyed by user ID on most platforms, so each person gets their own conversation context. This is by design for privacy and context isolation.
|
||||
|
||||
**Workarounds:**
|
||||
|
||||
1. **Use Slack.** Slack sessions are keyed by thread, not by user. Multiple users in the same thread share one conversation — exactly the behavior you're describing. This is the most natural fit.
|
||||
|
||||
2. **Use a group chat with a single user.** If one person is the designated "operator" who relays questions, the session stays unified. Others can read along.
|
||||
|
||||
3. **Use a Discord channel.** Discord sessions are keyed by channel, so all users in the same channel share context. Use a dedicated channel for the shared conversation.
|
||||
|
||||
### Exporting Hermes to another machine
|
||||
|
||||
**Scenario:** You've built up skills, cron jobs, and memories on one machine and want to move everything to a new dedicated Linux box.
|
||||
|
||||
**Solution:**
|
||||
|
||||
1. Install Hermes Agent on the new machine:
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
2. Copy your entire `~/.hermes/` directory **except** the `hermes-agent` subdirectory (that's the code repo — the new install has its own):
|
||||
```bash
|
||||
# On the source machine
|
||||
rsync -av --exclude='hermes-agent' ~/.hermes/ newmachine:~/.hermes/
|
||||
```
|
||||
|
||||
Or use profile export/import:
|
||||
```bash
|
||||
# On source machine
|
||||
hermes profile export default ./hermes-backup.tar.gz
|
||||
|
||||
# On target machine
|
||||
hermes profile import ./hermes-backup.tar.gz default
|
||||
```
|
||||
|
||||
3. On the new machine, run `hermes setup` to verify API keys and provider config are working. Re-authenticate any messaging platforms (especially WhatsApp, which uses QR pairing).
|
||||
|
||||
The `~/.hermes/` directory contains everything: `config.yaml`, `.env`, `SOUL.md`, `memories/`, `skills/`, `state.db` (sessions), `cron/`, and any custom plugins. The code itself lives in `~/.hermes/hermes-agent/` and is installed fresh.
|
||||
|
||||
### Permission denied when reloading shell after install
|
||||
|
||||
**Scenario:** After running the Hermes installer, `source ~/.zshrc` gives a permission denied error.
|
||||
|
||||
**Cause:** This usually happens when `~/.zshrc` (or `~/.bashrc`) has incorrect file permissions, or when the installer couldn't write to it cleanly. It's not a Hermes-specific issue — it's a shell config permissions problem.
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Check permissions
|
||||
ls -la ~/.zshrc
|
||||
|
||||
# Fix if needed (should be -rw-r--r-- or 644)
|
||||
chmod 644 ~/.zshrc
|
||||
|
||||
# Then reload
|
||||
source ~/.zshrc
|
||||
|
||||
# Or just open a new terminal window — it picks up PATH changes automatically
|
||||
```
|
||||
|
||||
If the installer added the PATH line but permissions are wrong, you can add it manually:
|
||||
```bash
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.zshrc
|
||||
```
|
||||
|
||||
### Error 400 on first agent run
|
||||
|
||||
**Scenario:** Setup completes fine, but the first chat attempt fails with HTTP 400.
|
||||
|
||||
**Cause:** Usually a model name mismatch — the configured model doesn't exist on your provider, or the API key doesn't have access to it.
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Check what model and provider are configured
|
||||
hermes config show | head -20
|
||||
|
||||
# Re-run model selection
|
||||
hermes model
|
||||
|
||||
# Or test with a known-good model
|
||||
hermes chat -q "hello" --model anthropic/claude-sonnet-4.6
|
||||
```
|
||||
|
||||
If using OpenRouter, make sure your API key has credits. A 400 from OpenRouter often means the model requires a paid plan or the model ID has a typo.
|
||||
|
||||
---
|
||||
|
||||
## Still Stuck?
|
||||
|
||||
If your issue isn't covered here:
|
||||
|
||||
@@ -88,14 +88,13 @@ Example settings snippet:
|
||||
|
||||
```json
|
||||
{
|
||||
"acp": {
|
||||
"agents": [
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"registry_dir": "/path/to/hermes-agent/acp_registry"
|
||||
}
|
||||
]
|
||||
}
|
||||
"agent_servers": {
|
||||
"hermes-agent": {
|
||||
"type": "custom",
|
||||
"command": "hermes",
|
||||
"args": ["acp"],
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user