mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-01 00:11:39 +08:00
Compare commits
9 Commits
opencode-p
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4080f6352b | ||
|
|
9e26b8f024 | ||
|
|
3659e1f0c2 | ||
|
|
21c2d32471 | ||
|
|
f66b3fe76b | ||
|
|
9aa82d4807 | ||
|
|
9b2fb1cc2e | ||
|
|
29c98e8f83 | ||
|
|
9e0fc62650 |
@@ -22,6 +22,9 @@ from acp.schema import (
|
||||
InitializeResponse,
|
||||
ListSessionsResponse,
|
||||
LoadSessionResponse,
|
||||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
@@ -93,6 +96,71 @@ class HermesACPAgent(acp.Agent):
|
||||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
mcp_servers: list[McpServerStdio | McpServerHttp | McpServerSse] | None,
|
||||
) -> None:
|
||||
"""Register ACP-provided MCP servers and refresh the agent tool surface."""
|
||||
if not mcp_servers:
|
||||
return
|
||||
|
||||
try:
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
config_map: dict[str, dict] = {}
|
||||
for server in mcp_servers:
|
||||
name = server.name
|
||||
if isinstance(server, McpServerStdio):
|
||||
config = {
|
||||
"command": server.command,
|
||||
"args": list(server.args),
|
||||
"env": {item.name: item.value for item in server.env},
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"url": server.url,
|
||||
"headers": {item.name: item.value for item in server.headers},
|
||||
}
|
||||
config_map[name] = config
|
||||
|
||||
await asyncio.to_thread(register_mcp_servers, config_map)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to register ACP MCP servers",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from model_tools import get_tool_definitions
|
||||
|
||||
enabled_toolsets = getattr(state.agent, "enabled_toolsets", None) or ["hermes-acp"]
|
||||
disabled_toolsets = getattr(state.agent, "disabled_toolsets", None)
|
||||
state.agent.tools = get_tool_definitions(
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
quiet_mode=True,
|
||||
)
|
||||
state.agent.valid_tool_names = {
|
||||
tool["function"]["name"] for tool in state.agent.tools or []
|
||||
}
|
||||
invalidate = getattr(state.agent, "_invalidate_system_prompt", None)
|
||||
if callable(invalidate):
|
||||
invalidate()
|
||||
logger.info(
|
||||
"Session %s: refreshed tool surface after ACP MCP registration (%d tools)",
|
||||
state.session_id,
|
||||
len(state.agent.tools or []),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to refresh tool surface after ACP MCP registration",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ---- ACP lifecycle ------------------------------------------------------
|
||||
|
||||
async def initialize(
|
||||
@@ -149,6 +217,7 @@ class HermesACPAgent(acp.Agent):
|
||||
**kwargs: Any,
|
||||
) -> NewSessionResponse:
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
|
||||
@@ -163,6 +232,7 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("load_session: session %s not found", session_id)
|
||||
return None
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
return LoadSessionResponse()
|
||||
|
||||
@@ -177,6 +247,7 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("resume_session: session %s not found, creating new", session_id)
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
|
||||
@@ -200,6 +271,8 @@ class HermesACPAgent(acp.Agent):
|
||||
) -> ForkSessionResponse:
|
||||
state = self.session_manager.fork_session(session_id, cwd=cwd)
|
||||
new_id = state.session_id if state else ""
|
||||
if state is not None:
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Forked session %s -> %s", session_id, new_id)
|
||||
return ForkSessionResponse(session_id=new_id)
|
||||
|
||||
|
||||
@@ -563,6 +563,18 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
|
||||
whatsapp_cfg = yaml_cfg.get("whatsapp", {})
|
||||
if isinstance(whatsapp_cfg, dict):
|
||||
if "require_mention" in whatsapp_cfg and not os.getenv("WHATSAPP_REQUIRE_MENTION"):
|
||||
os.environ["WHATSAPP_REQUIRE_MENTION"] = str(whatsapp_cfg["require_mention"]).lower()
|
||||
if "mention_patterns" in whatsapp_cfg and not os.getenv("WHATSAPP_MENTION_PATTERNS"):
|
||||
os.environ["WHATSAPP_MENTION_PATTERNS"] = json.dumps(whatsapp_cfg["mention_patterns"])
|
||||
frc = whatsapp_cfg.get("free_response_chats")
|
||||
if frc is not None and not os.getenv("WHATSAPP_FREE_RESPONSE_CHATS"):
|
||||
if isinstance(frc, list):
|
||||
frc = ",".join(str(v) for v in frc)
|
||||
os.environ["WHATSAPP_FREE_RESPONSE_CHATS"] = str(frc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
|
||||
@@ -16,9 +16,11 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
@@ -138,12 +140,137 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
get_hermes_dir("platforms/whatsapp/session", "whatsapp/session")
|
||||
))
|
||||
self._reply_prefix: Optional[str] = config.extra.get("reply_prefix")
|
||||
self._mention_patterns = self._compile_mention_patterns()
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() in ("true", "1", "yes", "on")
|
||||
return bool(configured)
|
||||
return os.getenv("WHATSAPP_REQUIRE_MENTION", "false").lower() in ("true", "1", "yes", "on")
|
||||
|
||||
def _whatsapp_free_response_chats(self) -> set[str]:
|
||||
raw = self.config.extra.get("free_response_chats")
|
||||
if raw is None:
|
||||
raw = os.getenv("WHATSAPP_FREE_RESPONSE_CHATS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
return {part.strip() for part in str(raw).split(",") if part.strip()}
|
||||
|
||||
def _compile_mention_patterns(self):
|
||||
patterns = self.config.extra.get("mention_patterns")
|
||||
if patterns is None:
|
||||
raw = os.getenv("WHATSAPP_MENTION_PATTERNS", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
patterns = json.loads(raw)
|
||||
except Exception:
|
||||
patterns = [part.strip() for part in raw.splitlines() if part.strip()]
|
||||
if not patterns:
|
||||
patterns = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
if patterns is None:
|
||||
return []
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if not isinstance(patterns, list):
|
||||
logger.warning("[%s] whatsapp mention_patterns must be a list or string; got %s", self.name, type(patterns).__name__)
|
||||
return []
|
||||
|
||||
compiled = []
|
||||
for pattern in patterns:
|
||||
if not isinstance(pattern, str) or not pattern.strip():
|
||||
continue
|
||||
try:
|
||||
compiled.append(re.compile(pattern, re.IGNORECASE))
|
||||
except re.error as exc:
|
||||
logger.warning("[%s] Invalid WhatsApp mention pattern %r: %s", self.name, pattern, exc)
|
||||
if compiled:
|
||||
logger.info("[%s] Loaded %d WhatsApp mention pattern(s)", self.name, len(compiled))
|
||||
return compiled
|
||||
|
||||
@staticmethod
|
||||
def _normalize_whatsapp_id(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
normalized = str(value).strip()
|
||||
if ":" in normalized and "@" in normalized:
|
||||
normalized = normalized.replace(":", "@", 1)
|
||||
return normalized
|
||||
|
||||
def _bot_ids_from_message(self, data: Dict[str, Any]) -> set[str]:
|
||||
bot_ids = set()
|
||||
for candidate in data.get("botIds") or []:
|
||||
normalized = self._normalize_whatsapp_id(candidate)
|
||||
if normalized:
|
||||
bot_ids.add(normalized)
|
||||
return bot_ids
|
||||
|
||||
def _message_is_reply_to_bot(self, data: Dict[str, Any]) -> bool:
|
||||
quoted_participant = self._normalize_whatsapp_id(data.get("quotedParticipant"))
|
||||
if not quoted_participant:
|
||||
return False
|
||||
return quoted_participant in self._bot_ids_from_message(data)
|
||||
|
||||
def _message_mentions_bot(self, data: Dict[str, Any]) -> bool:
|
||||
bot_ids = self._bot_ids_from_message(data)
|
||||
if not bot_ids:
|
||||
return False
|
||||
mentioned_ids = {
|
||||
nid
|
||||
for candidate in (data.get("mentionedIds") or [])
|
||||
if (nid := self._normalize_whatsapp_id(candidate))
|
||||
}
|
||||
if mentioned_ids & bot_ids:
|
||||
return True
|
||||
|
||||
body = str(data.get("body") or "")
|
||||
lower_body = body.lower()
|
||||
for bot_id in bot_ids:
|
||||
bare_id = bot_id.split("@", 1)[0].lower()
|
||||
if bare_id and (f"@{bare_id}" in lower_body or bare_id in lower_body):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _message_matches_mention_patterns(self, data: Dict[str, Any]) -> bool:
|
||||
if not self._mention_patterns:
|
||||
return False
|
||||
body = str(data.get("body") or "")
|
||||
return any(pattern.search(body) for pattern in self._mention_patterns)
|
||||
|
||||
def _clean_bot_mention_text(self, text: str, data: Dict[str, Any]) -> str:
|
||||
if not text:
|
||||
return text
|
||||
bot_ids = self._bot_ids_from_message(data)
|
||||
cleaned = text
|
||||
for bot_id in bot_ids:
|
||||
bare_id = bot_id.split("@", 1)[0]
|
||||
if bare_id:
|
||||
cleaned = re.sub(rf"@{re.escape(bare_id)}\b[,:\-]*\s*", "", cleaned)
|
||||
return cleaned.strip() or text
|
||||
|
||||
def _should_process_message(self, data: Dict[str, Any]) -> bool:
|
||||
if not data.get("isGroup"):
|
||||
return True
|
||||
chat_id = str(data.get("chatId") or "")
|
||||
if chat_id in self._whatsapp_free_response_chats():
|
||||
return True
|
||||
if not self._whatsapp_require_mention():
|
||||
return True
|
||||
body = str(data.get("body") or "").strip()
|
||||
if body.startswith("/"):
|
||||
return True
|
||||
if self._message_is_reply_to_bot(data):
|
||||
return True
|
||||
if self._message_mentions_bot(data):
|
||||
return True
|
||||
return self._message_matches_mention_patterns(data)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
@@ -687,6 +814,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
if not self._should_process_message(data):
|
||||
return None
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if data.get("hasMedia"):
|
||||
@@ -768,6 +898,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# the message text so the agent can read it inline.
|
||||
# Cap at 100KB to match Telegram/Discord/Slack behaviour.
|
||||
body = data.get("body", "")
|
||||
if data.get("isGroup"):
|
||||
body = self._clean_bot_mention_text(body, data)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if msg_type == MessageType.DOCUMENT and cached_urls:
|
||||
for doc_path in cached_urls:
|
||||
|
||||
@@ -5468,15 +5468,25 @@ class GatewayRunner:
|
||||
_loop_for_step = asyncio.get_event_loop()
|
||||
_hooks_ref = self.hooks
|
||||
|
||||
def _step_callback_sync(iteration: int, tool_names: list) -> None:
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
try:
|
||||
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
||||
# keys. Normalise to keep "tool_names" backward-compatible for
|
||||
# user-authored hooks that do ', '.join(tool_names)'.
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_hooks_ref.emit("agent:step", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
"user_id": source.user_id,
|
||||
"session_id": session_id,
|
||||
"iteration": iteration,
|
||||
"tool_names": tool_names,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
}),
|
||||
_loop_for_step,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
@@ -108,6 +109,9 @@ CONCLUDE_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
ALL_TOOL_SCHEMAS = [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -124,6 +128,34 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
self._sync_thread: Optional[threading.Thread] = None
|
||||
|
||||
# B1: recall_mode — set during initialize from config
|
||||
self._recall_mode = "hybrid" # "context", "tools", or "hybrid"
|
||||
|
||||
# B4: First-turn context baking
|
||||
self._first_turn_context: Optional[str] = None
|
||||
self._first_turn_lock = threading.Lock()
|
||||
|
||||
# B5: Cost-awareness turn counting and cadence
|
||||
self._turn_count = 0
|
||||
self._injection_frequency = "every-turn" # or "first-turn"
|
||||
self._context_cadence = 1 # minimum turns between context API calls
|
||||
self._dialectic_cadence = 1 # minimum turns between dialectic API calls
|
||||
self._reasoning_level_cap: Optional[str] = None # "minimal", "low", "mid", "high"
|
||||
self._last_context_turn = -999
|
||||
self._last_dialectic_turn = -999
|
||||
|
||||
# B2: peer_memory_mode gating (stub)
|
||||
self._suppress_memory = False
|
||||
self._suppress_user_profile = False
|
||||
|
||||
# Port #1957: lazy session init for tools-only mode
|
||||
self._session_initialized = False
|
||||
self._lazy_init_kwargs: Optional[dict] = None
|
||||
self._lazy_init_session_id: Optional[str] = None
|
||||
|
||||
# Port #4053: cron guard — when True, plugin is fully inactive
|
||||
self._cron_skipped = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "honcho"
|
||||
@@ -133,6 +165,7 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
try:
|
||||
from plugins.memory.honcho.client import HonchoClientConfig
|
||||
cfg = HonchoClientConfig.from_global_config()
|
||||
# Port #2645: baseUrl-only verification — api_key OR base_url suffices
|
||||
return cfg.enabled and bool(cfg.api_key or cfg.base_url)
|
||||
except Exception:
|
||||
return False
|
||||
@@ -158,8 +191,22 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
"""Initialize Honcho session manager."""
|
||||
"""Initialize Honcho session manager.
|
||||
|
||||
Handles: cron guard, recall_mode, session name resolution,
|
||||
peer memory mode, SOUL.md ai_peer sync, memory file migration,
|
||||
and pre-warming context at init.
|
||||
"""
|
||||
try:
|
||||
# ----- Port #4053: cron guard -----
|
||||
agent_context = kwargs.get("agent_context", "")
|
||||
platform = kwargs.get("platform", "cli")
|
||||
if agent_context in ("cron", "flush") or platform == "cron":
|
||||
logger.debug("Honcho skipped: cron/flush context (agent_context=%s, platform=%s)",
|
||||
agent_context, platform)
|
||||
self._cron_skipped = True
|
||||
return
|
||||
|
||||
from plugins.memory.honcho.client import HonchoClientConfig, get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
@@ -169,20 +216,78 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
return
|
||||
|
||||
self._config = cfg
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# Build session key from kwargs or session_id
|
||||
platform = kwargs.get("platform", "cli")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
if user_id:
|
||||
self._session_key = f"{platform}:{user_id}"
|
||||
else:
|
||||
self._session_key = session_id
|
||||
# ----- B1: recall_mode from config -----
|
||||
self._recall_mode = cfg.recall_mode # "context", "tools", or "hybrid"
|
||||
logger.debug("Honcho recall_mode: %s", self._recall_mode)
|
||||
|
||||
# ----- B5: cost-awareness config -----
|
||||
try:
|
||||
raw = cfg.raw or {}
|
||||
self._injection_frequency = raw.get("injectionFrequency", "every-turn")
|
||||
self._context_cadence = int(raw.get("contextCadence", 1))
|
||||
self._dialectic_cadence = int(raw.get("dialecticCadence", 1))
|
||||
cap = raw.get("reasoningLevelCap")
|
||||
if cap and cap in ("minimal", "low", "mid", "high"):
|
||||
self._reasoning_level_cap = cap
|
||||
except Exception as e:
|
||||
logger.debug("Honcho cost-awareness config parse error: %s", e)
|
||||
|
||||
# ----- Port #1969: aiPeer sync from SOUL.md -----
|
||||
try:
|
||||
hermes_home = kwargs.get("hermes_home", "")
|
||||
if hermes_home and not cfg.raw.get("aiPeer"):
|
||||
soul_path = Path(hermes_home) / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
soul_text = soul_path.read_text(encoding="utf-8").strip()
|
||||
if soul_text:
|
||||
# Try YAML frontmatter: "name: Foo"
|
||||
first_line = soul_text.split("\n")[0].strip()
|
||||
if first_line.startswith("---"):
|
||||
# Look for name: in frontmatter
|
||||
for line in soul_text.split("\n")[1:]:
|
||||
line = line.strip()
|
||||
if line == "---":
|
||||
break
|
||||
if line.lower().startswith("name:"):
|
||||
name_val = line.split(":", 1)[1].strip().strip("\"'")
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md: %s", name_val)
|
||||
break
|
||||
elif first_line.startswith("# "):
|
||||
# Markdown heading: "# AgentName"
|
||||
name_val = first_line[2:].strip()
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md heading: %s", name_val)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho SOUL.md ai_peer sync failed: %s", e)
|
||||
|
||||
# ----- B2: peer_memory_mode gating (stub) -----
|
||||
try:
|
||||
ai_mode = cfg.peer_memory_mode(cfg.ai_peer)
|
||||
user_mode = cfg.peer_memory_mode(cfg.peer_name or "user")
|
||||
# "honcho" means Honcho owns memory; suppress built-in
|
||||
self._suppress_memory = (ai_mode == "honcho")
|
||||
self._suppress_user_profile = (user_mode == "honcho")
|
||||
logger.debug("Honcho peer_memory_mode: ai=%s (suppress_memory=%s), user=%s (suppress_user_profile=%s)",
|
||||
ai_mode, self._suppress_memory, user_mode, self._suppress_user_profile)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho peer_memory_mode check failed: %s", e)
|
||||
|
||||
# ----- Port #1957: lazy session init for tools-only mode -----
|
||||
if self._recall_mode == "tools":
|
||||
# Defer actual session creation until first tool call
|
||||
self._lazy_init_kwargs = kwargs
|
||||
self._lazy_init_session_id = session_id
|
||||
# Still need a client reference for _ensure_session
|
||||
self._config = cfg
|
||||
logger.debug("Honcho tools-only mode — deferring session init until first tool call")
|
||||
return
|
||||
|
||||
# ----- Eager init (context or hybrid mode) -----
|
||||
self._do_session_init(cfg, session_id, **kwargs)
|
||||
|
||||
except ImportError:
|
||||
logger.debug("honcho-ai package not installed — plugin inactive")
|
||||
@@ -190,19 +295,180 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.warning("Honcho init failed: %s", e)
|
||||
self._manager = None
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._manager or not self._session_key:
|
||||
return ""
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active. AI-native cross-session user modeling.\n"
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
def _do_session_init(self, cfg, session_id: str, **kwargs) -> None:
|
||||
"""Shared session initialization logic for both eager and lazy paths."""
|
||||
from plugins.memory.honcho.client import get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# ----- B3: resolve_session_name -----
|
||||
session_title = kwargs.get("session_title")
|
||||
self._session_key = (
|
||||
cfg.resolve_session_name(session_title=session_title, session_id=session_id)
|
||||
or session_id
|
||||
or "hermes-default"
|
||||
)
|
||||
logger.debug("Honcho session key resolved: %s", self._session_key)
|
||||
|
||||
# Create session eagerly
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
self._session_initialized = True
|
||||
|
||||
# ----- B6: Memory file migration (one-time, for new sessions) -----
|
||||
try:
|
||||
if not session.messages:
|
||||
from hermes_constants import get_hermes_home
|
||||
mem_dir = str(get_hermes_home() / "memories")
|
||||
self._manager.migrate_memory_files(self._session_key, mem_dir)
|
||||
logger.debug("Honcho memory file migration attempted for new session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho memory file migration skipped: %s", e)
|
||||
|
||||
# ----- B7: Pre-warming context at init -----
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key)
|
||||
self._manager.prefetch_dialectic(self._session_key, "What should I know about this user?")
|
||||
logger.debug("Honcho pre-warm threads started for session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho pre-warm failed: %s", e)
|
||||
|
||||
def _ensure_session(self) -> bool:
|
||||
"""Lazily initialize the Honcho session (for tools-only mode).
|
||||
|
||||
Returns True if the manager is ready, False otherwise.
|
||||
"""
|
||||
if self._manager and self._session_initialized:
|
||||
return True
|
||||
if self._cron_skipped:
|
||||
return False
|
||||
if not self._config or not self._lazy_init_kwargs:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._do_session_init(
|
||||
self._config,
|
||||
self._lazy_init_session_id or "hermes-default",
|
||||
**self._lazy_init_kwargs,
|
||||
)
|
||||
# Clear lazy refs
|
||||
self._lazy_init_kwargs = None
|
||||
self._lazy_init_session_id = None
|
||||
return self._manager is not None
|
||||
except Exception as e:
|
||||
logger.warning("Honcho lazy session init failed: %s", e)
|
||||
return False
|
||||
|
||||
def _format_first_turn_context(self, ctx: dict) -> str:
|
||||
"""Format the prefetch context dict into a readable system prompt block."""
|
||||
parts = []
|
||||
|
||||
rep = ctx.get("representation", "")
|
||||
if rep:
|
||||
parts.append(f"## User Representation\n{rep}")
|
||||
|
||||
card = ctx.get("card", "")
|
||||
if card:
|
||||
parts.append(f"## User Peer Card\n{card}")
|
||||
|
||||
ai_rep = ctx.get("ai_representation", "")
|
||||
if ai_rep:
|
||||
parts.append(f"## AI Self-Representation\n{ai_rep}")
|
||||
|
||||
ai_card = ctx.get("ai_card", "")
|
||||
if ai_card:
|
||||
parts.append(f"## AI Identity Card\n{ai_card}")
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
"""Return system prompt text, adapted by recall_mode.
|
||||
|
||||
B4: On the FIRST call, fetch and bake the full Honcho context
|
||||
(user representation, peer card, AI representation, continuity synthesis).
|
||||
Subsequent calls return the cached block for prompt caching stability.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
if not self._manager or not self._session_key:
|
||||
# tools-only mode without session yet still returns a minimal block
|
||||
if self._recall_mode == "tools" and self._config:
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile, honcho_search, "
|
||||
"honcho_context, and honcho_conclude tools to access user memory."
|
||||
)
|
||||
return ""
|
||||
|
||||
# ----- B4: First-turn context baking -----
|
||||
first_turn_block = ""
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
with self._first_turn_lock:
|
||||
if self._first_turn_context is None:
|
||||
# First call — fetch and cache
|
||||
try:
|
||||
ctx = self._manager.get_prefetch_context(self._session_key)
|
||||
self._first_turn_context = self._format_first_turn_context(ctx) if ctx else ""
|
||||
except Exception as e:
|
||||
logger.debug("Honcho first-turn context fetch failed: %s", e)
|
||||
self._first_turn_context = ""
|
||||
first_turn_block = self._first_turn_context
|
||||
|
||||
# ----- B1: adapt text based on recall_mode -----
|
||||
if self._recall_mode == "context":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (context-injection mode). Relevant user context is automatically "
|
||||
"injected before each turn. No memory tools are available — context is "
|
||||
"managed automatically."
|
||||
)
|
||||
elif self._recall_mode == "tools":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user. "
|
||||
"No automatic context injection — you must use tools to access memory."
|
||||
)
|
||||
else: # hybrid
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (hybrid mode). Relevant context is auto-injected AND memory tools are available. "
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
)
|
||||
|
||||
if first_turn_block:
|
||||
return f"{header}\n\n{first_turn_block}"
|
||||
return header
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Return prefetched dialectic context from background thread."""
|
||||
"""Return prefetched dialectic context from background thread.
|
||||
|
||||
B1: Returns empty when recall_mode is "tools" (no injection).
|
||||
B5: Respects injection_frequency — "first-turn" returns cached/empty after turn 0.
|
||||
Port #3265: Truncates to context_tokens budget.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
|
||||
# B1: tools-only mode — no auto-injection
|
||||
if self._recall_mode == "tools":
|
||||
return ""
|
||||
|
||||
# B5: injection_frequency — if "first-turn" and past first turn, return empty
|
||||
if self._injection_frequency == "first-turn" and self._turn_count > 0:
|
||||
return ""
|
||||
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
@@ -210,13 +476,49 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
return ""
|
||||
|
||||
# ----- Port #3265: token budget enforcement -----
|
||||
result = self._truncate_to_budget(result)
|
||||
|
||||
return f"## Honcho Context\n{result}"
|
||||
|
||||
def _truncate_to_budget(self, text: str) -> str:
|
||||
"""Truncate text to fit within context_tokens budget if set."""
|
||||
if not self._config or not self._config.context_tokens:
|
||||
return text
|
||||
budget_chars = self._config.context_tokens * 4 # conservative char estimate
|
||||
if len(text) <= budget_chars:
|
||||
return text
|
||||
# Truncate at word boundary
|
||||
truncated = text[:budget_chars]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > budget_chars * 0.8:
|
||||
truncated = truncated[:last_space]
|
||||
return truncated + " …"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
"""Fire a background dialectic query for the upcoming turn."""
|
||||
"""Fire a background dialectic query for the upcoming turn.
|
||||
|
||||
B5: Checks cadence before firing background threads.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key or not query:
|
||||
return
|
||||
|
||||
# B1: tools-only mode — no prefetch
|
||||
if self._recall_mode == "tools":
|
||||
return
|
||||
|
||||
# B5: cadence check — skip if too soon since last dialectic call
|
||||
if self._dialectic_cadence > 1:
|
||||
if (self._turn_count - self._last_dialectic_turn) < self._dialectic_cadence:
|
||||
logger.debug("Honcho dialectic prefetch skipped: cadence %d, turns since last: %d",
|
||||
self._dialectic_cadence, self._turn_count - self._last_dialectic_turn)
|
||||
return
|
||||
|
||||
self._last_dialectic_turn = self._turn_count
|
||||
|
||||
def _run():
|
||||
try:
|
||||
result = self._manager.dialectic_query(
|
||||
@@ -233,14 +535,28 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
)
|
||||
self._prefetch_thread.start()
|
||||
|
||||
# Also fire context prefetch if cadence allows
|
||||
if self._context_cadence <= 1 or (self._turn_count - self._last_context_turn) >= self._context_cadence:
|
||||
self._last_context_turn = self._turn_count
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key, query)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho context prefetch failed: %s", e)
|
||||
|
||||
def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
|
||||
"""Track turn count for cadence and injection_frequency logic."""
|
||||
self._turn_count = turn_number
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Record the conversation turn in Honcho (non-blocking)."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
session = self._manager.get_or_create_session(self._session_key)
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
session.add_message("user", user_content[:4000])
|
||||
session.add_message("assistant", assistant_content[:4000])
|
||||
# Flush to Honcho API
|
||||
@@ -259,6 +575,8 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
"""Mirror built-in user profile writes as Honcho conclusions."""
|
||||
if action != "add" or target != "user" or not content:
|
||||
return
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
@@ -273,6 +591,8 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
|
||||
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Flush all pending messages to Honcho on session end."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager:
|
||||
return
|
||||
# Wait for pending sync
|
||||
@@ -284,9 +604,26 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.debug("Honcho session-end flush failed: %s", e)
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
"""Return tool schemas, respecting recall_mode.
|
||||
|
||||
B1: context-only mode hides all tools.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return []
|
||||
if self._recall_mode == "context":
|
||||
return []
|
||||
return list(ALL_TOOL_SCHEMAS)
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
"""Handle a Honcho tool call, with lazy session init for tools-only mode."""
|
||||
if self._cron_skipped:
|
||||
return json.dumps({"error": "Honcho is not active (cron context)."})
|
||||
|
||||
# Port #1957: ensure session is initialized for tools-only mode
|
||||
if not self._session_initialized:
|
||||
if not self._ensure_session():
|
||||
return json.dumps({"error": "Honcho session could not be initialized."})
|
||||
|
||||
if not self._manager or not self._session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
|
||||
|
||||
@@ -85,6 +85,16 @@ def _normalize_recall_mode(val: str) -> str:
|
||||
return val if val in _VALID_RECALL_MODES else "hybrid"
|
||||
|
||||
|
||||
_VALID_OBSERVATION_MODES = {"unified", "directional"}
|
||||
_OBSERVATION_MODE_ALIASES = {"shared": "unified", "separate": "directional", "cross": "directional"}
|
||||
|
||||
|
||||
def _normalize_observation_mode(val: str) -> str:
|
||||
"""Normalize observation mode values."""
|
||||
val = _OBSERVATION_MODE_ALIASES.get(val, val)
|
||||
return val if val in _VALID_OBSERVATION_MODES else "unified"
|
||||
|
||||
|
||||
def _resolve_memory_mode(
|
||||
global_val: str | dict,
|
||||
host_val: str | dict | None,
|
||||
@@ -154,6 +164,10 @@ class HonchoClientConfig:
|
||||
# "context" — auto-injected context only, Honcho tools removed
|
||||
# "tools" — Honcho tools only, no auto-injected context
|
||||
recall_mode: str = "hybrid"
|
||||
# Observation mode: how Honcho peers observe each other.
|
||||
# "unified" — user peer observes self; all agents share one observation pool
|
||||
# "directional" — AI peer observes user; each agent keeps its own view
|
||||
observation_mode: str = "unified"
|
||||
# Session resolution
|
||||
session_strategy: str = "per-directory"
|
||||
session_peer_prefix: bool = False
|
||||
@@ -313,6 +327,11 @@ class HonchoClientConfig:
|
||||
or raw.get("recallMode")
|
||||
or "hybrid"
|
||||
),
|
||||
observation_mode=_normalize_observation_mode(
|
||||
host_block.get("observationMode")
|
||||
or raw.get("observationMode")
|
||||
or "unified"
|
||||
),
|
||||
session_strategy=session_strategy,
|
||||
session_peer_prefix=session_peer_prefix,
|
||||
sessions=raw.get("sessions", {}),
|
||||
|
||||
@@ -110,6 +110,9 @@ class HonchoSessionManager:
|
||||
self._dialectic_max_chars: int = (
|
||||
config.dialectic_max_chars if config else 600
|
||||
)
|
||||
self._observation_mode: str = (
|
||||
config.observation_mode if config else "unified"
|
||||
)
|
||||
|
||||
# Async write queue — started lazily on first enqueue
|
||||
self._async_queue: queue.Queue | None = None
|
||||
@@ -159,13 +162,18 @@ class HonchoSessionManager:
|
||||
|
||||
session = self.honcho.session(session_id)
|
||||
|
||||
# Configure peer observation settings.
|
||||
# observe_me=True for AI peer so Honcho watches what the agent says
|
||||
# and builds its representation over time — enabling identity formation.
|
||||
# Configure peer observation settings based on observation_mode.
|
||||
# Unified: user peer observes self, AI peer passive — all agents share
|
||||
# one observation pool via user self-observations.
|
||||
# Directional: AI peer observes user — each agent keeps its own view.
|
||||
try:
|
||||
from honcho.session import SessionPeerConfig
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
ai_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
if self._observation_mode == "directional":
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=True)
|
||||
else: # unified (default)
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=False)
|
||||
|
||||
session.add_peers([(user_peer, user_config), (assistant_peer, ai_config)])
|
||||
except Exception as e:
|
||||
@@ -493,12 +501,27 @@ class HonchoSessionManager:
|
||||
if not session:
|
||||
return ""
|
||||
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
level = reasoning_level or self._dynamic_reasoning_level(query)
|
||||
|
||||
try:
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer queries about the user (cross-observation)
|
||||
if peer == "ai":
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(query, reasoning_level=level) or ""
|
||||
else:
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(
|
||||
query,
|
||||
target=session.user_peer_id,
|
||||
reasoning_level=level,
|
||||
) or ""
|
||||
else:
|
||||
# Unified: user peer queries self, or AI peer queries self
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
|
||||
# Apply Hermes-side char cap before caching
|
||||
if result and self._dialectic_max_chars and len(result) > self._dialectic_max_chars:
|
||||
result = result[:self._dialectic_max_chars].rsplit(" ", 1)[0] + " …"
|
||||
@@ -895,9 +918,16 @@ class HonchoSessionManager:
|
||||
logger.warning("No session cached for '%s', skipping conclusion", session_key)
|
||||
return False
|
||||
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
try:
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer creates conclusion about user (cross-observation)
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
else:
|
||||
# Unified: user peer creates self-conclusion
|
||||
user_peer = self._get_or_create_peer(session.user_peer_id)
|
||||
conclusions_scope = user_peer.conclusions_of(session.user_peer_id)
|
||||
|
||||
conclusions_scope.create([{
|
||||
"content": content.strip(),
|
||||
"session_id": session.honcho_session_id,
|
||||
|
||||
15
run_agent.py
15
run_agent.py
@@ -6656,10 +6656,21 @@ class AIAgent:
|
||||
if self.step_callback is not None:
|
||||
try:
|
||||
prev_tools = []
|
||||
for _m in reversed(messages):
|
||||
for _idx, _m in enumerate(reversed(messages)):
|
||||
if _m.get("role") == "assistant" and _m.get("tool_calls"):
|
||||
_fwd_start = len(messages) - _idx
|
||||
_results_by_id = {}
|
||||
for _tm in messages[_fwd_start:]:
|
||||
if _tm.get("role") != "tool":
|
||||
break
|
||||
_tcid = _tm.get("tool_call_id")
|
||||
if _tcid:
|
||||
_results_by_id[_tcid] = _tm.get("content", "")
|
||||
prev_tools = [
|
||||
tc["function"]["name"]
|
||||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
|
||||
@@ -62,6 +62,33 @@ function formatOutgoingMessage(message) {
|
||||
return REPLY_PREFIX ? `${REPLY_PREFIX}${message}` : message;
|
||||
}
|
||||
|
||||
function normalizeWhatsAppId(value) {
|
||||
if (!value) return '';
|
||||
return String(value).replace(':', '@');
|
||||
}
|
||||
|
||||
function getMessageContent(msg) {
|
||||
const content = msg?.message || {};
|
||||
if (content.ephemeralMessage?.message) return content.ephemeralMessage.message;
|
||||
if (content.viewOnceMessage?.message) return content.viewOnceMessage.message;
|
||||
if (content.viewOnceMessageV2?.message) return content.viewOnceMessageV2.message;
|
||||
if (content.documentWithCaptionMessage?.message) return content.documentWithCaptionMessage.message;
|
||||
if (content.templateMessage?.hydratedTemplate) return content.templateMessage.hydratedTemplate;
|
||||
if (content.buttonsMessage) return content.buttonsMessage;
|
||||
if (content.listMessage) return content.listMessage;
|
||||
return content;
|
||||
}
|
||||
|
||||
function getContextInfo(messageContent) {
|
||||
if (!messageContent || typeof messageContent !== 'object') return {};
|
||||
for (const value of Object.values(messageContent)) {
|
||||
if (value && typeof value === 'object' && value.contextInfo) {
|
||||
return value.contextInfo;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
mkdirSync(SESSION_DIR, { recursive: true });
|
||||
|
||||
// Build LID → phone reverse map from session files (lid-mapping-{phone}.json)
|
||||
@@ -157,6 +184,11 @@ async function startSocket() {
|
||||
// than 'notify'. Accept both and filter agent echo-backs below.
|
||||
if (type !== 'notify' && type !== 'append') return;
|
||||
|
||||
const botIds = Array.from(new Set([
|
||||
normalizeWhatsAppId(sock.user?.id),
|
||||
normalizeWhatsAppId(sock.user?.lid),
|
||||
].filter(Boolean)));
|
||||
|
||||
for (const msg of messages) {
|
||||
if (!msg.message) continue;
|
||||
|
||||
@@ -200,23 +232,28 @@ async function startSocket() {
|
||||
continue;
|
||||
}
|
||||
|
||||
const messageContent = getMessageContent(msg);
|
||||
const contextInfo = getContextInfo(messageContent);
|
||||
const mentionedIds = Array.from(new Set((contextInfo?.mentionedJid || []).map(normalizeWhatsAppId).filter(Boolean)));
|
||||
const quotedParticipant = normalizeWhatsAppId(contextInfo?.participant || contextInfo?.remoteJid || '');
|
||||
|
||||
// Extract message body
|
||||
let body = '';
|
||||
let hasMedia = false;
|
||||
let mediaType = '';
|
||||
const mediaUrls = [];
|
||||
|
||||
if (msg.message.conversation) {
|
||||
body = msg.message.conversation;
|
||||
} else if (msg.message.extendedTextMessage?.text) {
|
||||
body = msg.message.extendedTextMessage.text;
|
||||
} else if (msg.message.imageMessage) {
|
||||
body = msg.message.imageMessage.caption || '';
|
||||
if (messageContent.conversation) {
|
||||
body = messageContent.conversation;
|
||||
} else if (messageContent.extendedTextMessage?.text) {
|
||||
body = messageContent.extendedTextMessage.text;
|
||||
} else if (messageContent.imageMessage) {
|
||||
body = messageContent.imageMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'image';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.imageMessage.mimetype || 'image/jpeg';
|
||||
const mime = messageContent.imageMessage.mimetype || 'image/jpeg';
|
||||
const extMap = { 'image/jpeg': '.jpg', 'image/png': '.png', 'image/webp': '.webp', 'image/gif': '.gif' };
|
||||
const ext = extMap[mime] || '.jpg';
|
||||
mkdirSync(IMAGE_CACHE_DIR, { recursive: true });
|
||||
@@ -226,13 +263,13 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download image:', err.message);
|
||||
}
|
||||
} else if (msg.message.videoMessage) {
|
||||
body = msg.message.videoMessage.caption || '';
|
||||
} else if (messageContent.videoMessage) {
|
||||
body = messageContent.videoMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'video';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = msg.message.videoMessage.mimetype || 'video/mp4';
|
||||
const mime = messageContent.videoMessage.mimetype || 'video/mp4';
|
||||
const ext = mime.includes('mp4') ? '.mp4' : '.mkv';
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
const filePath = path.join(DOCUMENT_CACHE_DIR, `vid_${randomBytes(6).toString('hex')}${ext}`);
|
||||
@@ -241,11 +278,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download video:', err.message);
|
||||
}
|
||||
} else if (msg.message.audioMessage || msg.message.pttMessage) {
|
||||
} else if (messageContent.audioMessage || messageContent.pttMessage) {
|
||||
hasMedia = true;
|
||||
mediaType = msg.message.pttMessage ? 'ptt' : 'audio';
|
||||
mediaType = messageContent.pttMessage ? 'ptt' : 'audio';
|
||||
try {
|
||||
const audioMsg = msg.message.pttMessage || msg.message.audioMessage;
|
||||
const audioMsg = messageContent.pttMessage || messageContent.audioMessage;
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
const mime = audioMsg.mimetype || 'audio/ogg';
|
||||
const ext = mime.includes('ogg') ? '.ogg' : mime.includes('mp4') ? '.m4a' : '.ogg';
|
||||
@@ -256,11 +293,11 @@ async function startSocket() {
|
||||
} catch (err) {
|
||||
console.error('[bridge] Failed to download audio:', err.message);
|
||||
}
|
||||
} else if (msg.message.documentMessage) {
|
||||
body = msg.message.documentMessage.caption || '';
|
||||
} else if (messageContent.documentMessage) {
|
||||
body = messageContent.documentMessage.caption || '';
|
||||
hasMedia = true;
|
||||
mediaType = 'document';
|
||||
const fileName = msg.message.documentMessage.fileName || 'document';
|
||||
const fileName = messageContent.documentMessage.fileName || 'document';
|
||||
try {
|
||||
const buf = await downloadMediaMessage(msg, 'buffer', {}, { logger, reuploadRequest: sock.updateMediaMessage });
|
||||
mkdirSync(DOCUMENT_CACHE_DIR, { recursive: true });
|
||||
@@ -309,6 +346,9 @@ async function startSocket() {
|
||||
hasMedia,
|
||||
mediaType,
|
||||
mediaUrls,
|
||||
mentionedIds,
|
||||
quotedParticipant,
|
||||
botIds,
|
||||
timestamp: msg.messageTimestamp,
|
||||
};
|
||||
|
||||
|
||||
@@ -205,6 +205,47 @@ class TestStepCallback:
|
||||
assert "read_file" not in tool_call_ids
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_result_passed_to_build_tool_complete(self, mock_conn, event_loop_fixture):
|
||||
"""Tool result from prev_tools dict is forwarded to build_tool_complete."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
# Provide a result string in the tool info dict
|
||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
"""When result is None (e.g. first iteration), None is passed through."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
|
||||
349
tests/acp/test_mcp_e2e.py
Normal file
349
tests/acp/test_mcp_e2e.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""End-to-end tests for ACP MCP server registration and tool-result reporting.
|
||||
|
||||
Exercises the full flow through the ACP server layer:
|
||||
new_session(mcpServers) → MCP tools registered → prompt() →
|
||||
tool_progress_callback (ToolCallStart) →
|
||||
step_callback with results (ToolCallUpdate with rawOutput) →
|
||||
session_update events arrive at the mock client
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import (
|
||||
EnvVariable,
|
||||
HttpHeader,
|
||||
McpServerHttp,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
TextContentBlock,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
)
|
||||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_manager():
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def acp_agent(mock_manager):
|
||||
return HermesACPAgent(session_manager=mock_manager)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E: MCP registration → prompt → tool events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMcpRegistrationE2E:
|
||||
"""Full flow: session with MCP servers → prompt with tool calls → ACP events."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_with_mcp_servers_registers_tools(self, acp_agent, mock_manager):
|
||||
"""new_session with mcpServers converts them to Hermes config and registers."""
|
||||
servers = [
|
||||
McpServerStdio(
|
||||
name="test-fs",
|
||||
command="/usr/bin/mcp-fs",
|
||||
args=["--root", "/tmp"],
|
||||
env=[EnvVariable(name="DEBUG", value="1")],
|
||||
),
|
||||
McpServerHttp(
|
||||
name="test-api",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok123")],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search"]
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_test_fs_read"}},
|
||||
{"function": {"name": "mcp_test_fs_write"}},
|
||||
{"function": {"name": "mcp_test_api_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
assert isinstance(resp, NewSessionResponse)
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Verify stdio server was converted correctly
|
||||
assert "test-fs" in registered_configs
|
||||
fs_cfg = registered_configs["test-fs"]
|
||||
assert fs_cfg["command"] == "/usr/bin/mcp-fs"
|
||||
assert fs_cfg["args"] == ["--root", "/tmp"]
|
||||
assert fs_cfg["env"] == {"DEBUG": "1"}
|
||||
|
||||
# Verify HTTP server was converted correctly
|
||||
assert "test-api" in registered_configs
|
||||
api_cfg = registered_configs["test-api"]
|
||||
assert api_cfg["url"] == "https://api.example.com/mcp"
|
||||
assert api_cfg["headers"] == {"Authorization": "Bearer tok123"}
|
||||
|
||||
# Verify agent tool surface was refreshed
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {
|
||||
"mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search", "terminal"
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_with_tool_calls_emits_acp_events(self, acp_agent, mock_manager):
|
||||
"""Prompt → agent fires callbacks → ACP ToolCallStart + ToolCallUpdate events."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
# Wire up a mock ACP client connection
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run_conversation(user_message, conversation_history=None, task_id=None):
|
||||
"""Simulate an agent turn that calls terminal, gets a result, then responds."""
|
||||
agent = state.agent
|
||||
|
||||
# 1) Agent fires tool_progress_callback (ToolCallStart)
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback(
|
||||
"terminal", "$ echo hello", {"command": "echo hello"}
|
||||
)
|
||||
|
||||
# 2) Agent fires step_callback with tool results (ToolCallUpdate)
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "terminal", "result": '{"output": "hello\\n", "exit_code": 0}'}
|
||||
])
|
||||
|
||||
return {
|
||||
"final_response": "The command output 'hello'.",
|
||||
"messages": [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": "The command output 'hello'."},
|
||||
],
|
||||
}
|
||||
|
||||
state.agent.run_conversation = mock_run_conversation
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="run echo hello")]
|
||||
resp = await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
# Collect all session_update calls
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
# session_update(session_id, update) — grab the update
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
# Find tool_call (start) and tool_call_update (completion) events
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
# Should have at least one ToolCallStart for "terminal"
|
||||
assert len(starts) >= 1, f"Expected ToolCallStart, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
start_event = starts[0]
|
||||
assert isinstance(start_event, ToolCallStart)
|
||||
assert start_event.title.startswith("terminal:")
|
||||
|
||||
# Should have at least one ToolCallUpdate (completion) with rawOutput
|
||||
assert len(completions) >= 1, f"Expected ToolCallUpdate, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
complete_event = completions[0]
|
||||
assert isinstance(complete_event, ToolCallProgress)
|
||||
assert complete_event.status == "completed"
|
||||
# rawOutput should contain the tool result string
|
||||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run(user_message, conversation_history=None, task_id=None):
|
||||
agent = state.agent
|
||||
# Fire two tool calls
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback("read_file", "read: /etc/hosts", {"path": "/etc/hosts"})
|
||||
agent.tool_progress_callback("web_search", "web search: test", {"query": "test"})
|
||||
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "read_file", "result": '{"content": "127.0.0.1 localhost"}'},
|
||||
{"name": "web_search", "result": '{"data": {"web": []}}'},
|
||||
])
|
||||
|
||||
return {"final_response": "Done.", "messages": []}
|
||||
|
||||
state.agent.run_conversation = mock_run
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="test")]
|
||||
await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
assert len(starts) == 2, f"Expected 2 starts, got {len(starts)}"
|
||||
assert len(completions) == 2, f"Expected 2 completions, got {len(completions)}"
|
||||
|
||||
# Each completion's toolCallId must match a start's toolCallId
|
||||
start_ids = {s.tool_call_id for s in starts}
|
||||
completion_ids = {c.tool_call_id for c in completions}
|
||||
assert start_ids == completion_ids, (
|
||||
f"IDs must match: starts={start_ids}, completions={completion_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestMcpSanitizationE2E:
|
||||
"""Verify server names with special chars work end-to-end."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slashed_server_name_registers_cleanly(self, acp_agent, mock_manager):
|
||||
"""Server name 'ai.exa/exa' should not crash — tools get sanitized names."""
|
||||
servers = [
|
||||
McpServerHttp(
|
||||
name="ai.exa/exa",
|
||||
url="https://exa.ai/mcp",
|
||||
headers=[],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_ai_exa_exa_search"]
|
||||
|
||||
fake_tools = [{"function": {"name": "mcp_ai_exa_exa_search"}}]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Raw server name preserved as config key
|
||||
assert "ai.exa/exa" in registered_configs
|
||||
# Agent tools refreshed with sanitized name
|
||||
assert "mcp_ai_exa_exa_search" in state.agent.valid_tool_names
|
||||
|
||||
|
||||
class TestSessionLifecycleMcpE2E:
|
||||
"""Verify MCP servers are registered on all session lifecycle methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""load_session re-registers MCP servers (spec says agents may not retain them)."""
|
||||
# Create a session first
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv", command="/bin/test", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""resume_session re-registers MCP servers."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv2", command="/bin/test2", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv2" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""fork_session registers MCP servers on the new forked session."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerHttp(name="api", url="https://api.test/mcp", headers=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
# Need to set up the forked session's agent too
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
fork_resp = await acp_agent.fork_session(
|
||||
cwd="/tmp", session_id=sid, mcp_servers=servers
|
||||
)
|
||||
|
||||
assert fork_resp.session_id != ""
|
||||
assert "api" in registered
|
||||
@@ -505,3 +505,179 @@ class TestSlashCommands:
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _register_session_mcp_servers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterSessionMcpServers:
|
||||
"""Tests for ACP MCP server registration in session lifecycle."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_servers(self, agent, mock_manager):
|
||||
"""No-op when mcp_servers is None or empty."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, None)
|
||||
await agent._register_session_mcp_servers(state, [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_stdio_servers(self, agent, mock_manager):
|
||||
"""McpServerStdio servers are converted and passed to register_mcp_servers."""
|
||||
from acp.schema import McpServerStdio, EnvVariable
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Give the mock agent the attributes _register_session_mcp_servers reads
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerStdio(
|
||||
name="test-server",
|
||||
command="/usr/bin/test",
|
||||
args=["--flag"],
|
||||
env=[EnvVariable(name="KEY", value="val")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return ["mcp_test_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "test-server" in registered_config
|
||||
cfg = registered_config["test-server"]
|
||||
assert cfg["command"] == "/usr/bin/test"
|
||||
assert cfg["args"] == ["--flag"]
|
||||
assert cfg["env"] == {"KEY": "val"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_http_servers(self, agent, mock_manager):
|
||||
"""McpServerHttp servers are converted correctly."""
|
||||
from acp.schema import McpServerHttp, HttpHeader
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerHttp(
|
||||
name="http-server",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return []
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "http-server" in registered_config
|
||||
cfg = registered_config["http-server"]
|
||||
assert cfg["url"] == "https://api.example.com/mcp"
|
||||
assert cfg["headers"] == {"Authorization": "Bearer tok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refreshes_agent_tool_surface(self, agent, mock_manager):
|
||||
"""After MCP registration, agent.tools and valid_tool_names are refreshed."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
state.agent._cached_system_prompt = "old prompt"
|
||||
|
||||
server = McpServerStdio(
|
||||
name="srv",
|
||||
command="/bin/test",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_srv_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", return_value=["mcp_srv_search"]), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {"mcp_srv_search", "terminal"}
|
||||
# _invalidate_system_prompt should have been called
|
||||
state.agent._invalidate_system_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_failure_logs_warning(self, agent, mock_manager):
|
||||
"""If register_mcp_servers raises, warning is logged but no crash."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
server = McpServerStdio(
|
||||
name="bad",
|
||||
command="/nonexistent",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=RuntimeError("boom")):
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_calls_register(self, agent, mock_manager):
|
||||
"""new_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.new_session(cwd="/tmp", mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
# Second arg should be the mcp_servers list
|
||||
assert mock_reg.call_args[0][1] == ["fake"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_calls_register(self, agent, mock_manager):
|
||||
"""load_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
# Create a session first so load can find it
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_calls_register(self, agent, mock_manager):
|
||||
"""resume_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_calls_register(self, agent, mock_manager):
|
||||
"""fork_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.fork_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
133
tests/gateway/test_step_callback_compat.py
Normal file
133
tests/gateway/test_step_callback_compat.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for step_callback backward compatibility.
|
||||
|
||||
Verifies that the gateway's step_callback normalization keeps
|
||||
``tool_names`` as a list of strings for backward-compatible hooks,
|
||||
while also providing the enriched ``tools`` list with results.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStepCallbackNormalization:
|
||||
"""The gateway's _step_callback_sync normalizes prev_tools from run_agent."""
|
||||
|
||||
def _extract_step_callback(self):
|
||||
"""Build a minimal _step_callback_sync using the same logic as gateway/run.py.
|
||||
|
||||
We replicate the closure so we can test normalisation in isolation
|
||||
without spinning up the full gateway.
|
||||
"""
|
||||
captured_events = []
|
||||
|
||||
class FakeHooks:
|
||||
async def emit(self, event_type, data):
|
||||
captured_events.append((event_type, data))
|
||||
|
||||
hooks_ref = FakeHooks()
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
hooks_ref.emit("agent:step", {
|
||||
"iteration": iteration,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
}),
|
||||
loop,
|
||||
)
|
||||
|
||||
return _step_callback_sync, captured_events, loop
|
||||
|
||||
def test_dict_prev_tools_produce_string_tool_names(self):
|
||||
"""When prev_tools is list[dict], tool_names should be list[str]."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
# Simulate the enriched format from run_agent.py
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": '{"output": "hello"}'},
|
||||
{"name": "read_file", "result": '{"content": "..."}'},
|
||||
]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0)) # prime the loop
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
# tool_names must be strings for backward compat
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
assert all(isinstance(n, str) for n in data["tool_names"])
|
||||
# tools should be the enriched dicts
|
||||
assert data["tools"] == prev_tools
|
||||
|
||||
def test_string_prev_tools_still_work(self):
|
||||
"""When prev_tools is list[str] (legacy), tool_names should pass through."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
prev_tools = ["terminal", "read_file"]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(2, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
|
||||
def test_empty_prev_tools(self):
|
||||
"""Empty or None prev_tools should produce empty tool_names."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, []))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == []
|
||||
|
||||
def test_joinable_for_hook_example(self):
|
||||
"""The documented hook example: ', '.join(tool_names) should work."""
|
||||
# This is the exact pattern from the docs
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": "ok"},
|
||||
{"name": "web_search", "result": None},
|
||||
]
|
||||
|
||||
_names = []
|
||||
for _t in prev_tools:
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
|
||||
# This must not raise — documented hook pattern
|
||||
result = ", ".join(_names)
|
||||
assert result == "terminal, web_search"
|
||||
142
tests/gateway/test_whatsapp_group_gating.py
Normal file
142
tests/gateway/test_whatsapp_group_gating.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
|
||||
|
||||
def _make_adapter(require_mention=None, mention_patterns=None, free_response_chats=None):
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter
|
||||
|
||||
extra = {}
|
||||
if require_mention is not None:
|
||||
extra["require_mention"] = require_mention
|
||||
if mention_patterns is not None:
|
||||
extra["mention_patterns"] = mention_patterns
|
||||
if free_response_chats is not None:
|
||||
extra["free_response_chats"] = free_response_chats
|
||||
|
||||
adapter = object.__new__(WhatsAppAdapter)
|
||||
adapter.platform = Platform.WHATSAPP
|
||||
adapter.config = PlatformConfig(enabled=True, extra=extra)
|
||||
adapter._message_handler = AsyncMock()
|
||||
adapter._mention_patterns = adapter._compile_mention_patterns()
|
||||
return adapter
|
||||
|
||||
|
||||
def _group_message(body="hello", **overrides):
|
||||
data = {
|
||||
"isGroup": True,
|
||||
"body": body,
|
||||
"chatId": "120363001234567890@g.us",
|
||||
"mentionedIds": [],
|
||||
"botIds": ["15551230000@s.whatsapp.net", "15551230000@lid"],
|
||||
"quotedParticipant": "",
|
||||
}
|
||||
data.update(overrides)
|
||||
return data
|
||||
|
||||
|
||||
def test_group_messages_can_be_opened_via_config():
|
||||
adapter = _make_adapter(require_mention=False)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_group_messages_can_require_direct_trigger_via_config():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"hi there",
|
||||
mentionedIds=["15551230000@s.whatsapp.net"],
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(
|
||||
_group_message(
|
||||
"replying",
|
||||
quotedParticipant="15551230000@lid",
|
||||
)
|
||||
) is True
|
||||
assert adapter._should_process_message(_group_message("/status")) is True
|
||||
|
||||
|
||||
def test_regex_mention_patterns_allow_custom_wake_words():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message(" chompy help")) is True
|
||||
assert adapter._should_process_message(_group_message("hey chompy")) is False
|
||||
|
||||
|
||||
def test_invalid_regex_patterns_are_ignored():
|
||||
adapter = _make_adapter(require_mention=True, mention_patterns=[r"(", r"^\s*chompy\b"])
|
||||
|
||||
assert adapter._should_process_message(_group_message("chompy status")) is True
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
|
||||
|
||||
def test_config_bridges_whatsapp_group_settings(monkeypatch, tmp_path):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(
|
||||
"whatsapp:\n"
|
||||
" require_mention: true\n"
|
||||
" mention_patterns:\n"
|
||||
" - \"^\\\\s*chompy\\\\b\"\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("WHATSAPP_REQUIRE_MENTION", raising=False)
|
||||
monkeypatch.delenv("WHATSAPP_MENTION_PATTERNS", raising=False)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config is not None
|
||||
assert config.platforms[Platform.WHATSAPP].extra["require_mention"] is True
|
||||
assert config.platforms[Platform.WHATSAPP].extra["mention_patterns"] == [r"^\s*chompy\b"]
|
||||
assert __import__("os").environ["WHATSAPP_REQUIRE_MENTION"] == "true"
|
||||
assert json.loads(__import__("os").environ["WHATSAPP_MENTION_PATTERNS"]) == [r"^\s*chompy\b"]
|
||||
|
||||
|
||||
def test_free_response_chats_bypass_mention_gating():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
free_response_chats=["120363001234567890@g.us"],
|
||||
)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_free_response_chats_does_not_bypass_other_groups():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
free_response_chats=["999999999999@g.us"],
|
||||
)
|
||||
|
||||
assert adapter._should_process_message(_group_message("hello everyone")) is False
|
||||
|
||||
|
||||
def test_dm_always_passes_even_with_require_mention():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
dm = {"isGroup": False, "body": "hello", "botIds": [], "mentionedIds": []}
|
||||
assert adapter._should_process_message(dm) is True
|
||||
|
||||
|
||||
def test_mention_stripping_removes_bot_phone_from_body():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
data = _group_message("@15551230000 what is the weather?")
|
||||
cleaned = adapter._clean_bot_mention_text(data["body"], data)
|
||||
assert "15551230000" not in cleaned
|
||||
assert "weather" in cleaned
|
||||
|
||||
|
||||
def test_mention_stripping_preserves_body_when_no_mention():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
data = _group_message("just a normal message")
|
||||
cleaned = adapter._clean_bot_mention_text(data["body"], data)
|
||||
assert cleaned == "just a normal message"
|
||||
@@ -2900,3 +2900,164 @@ class TestMCPBuiltinCollisionGuard:
|
||||
assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv"
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sanitize_mcp_name_component
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeMcpNameComponent:
|
||||
"""Verify sanitize_mcp_name_component handles all edge cases."""
|
||||
|
||||
def test_hyphens_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my-server") == "my_server"
|
||||
|
||||
def test_dots_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa") == "ai_exa"
|
||||
|
||||
def test_slashes_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa"
|
||||
|
||||
def test_mixed_special_characters(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2"
|
||||
|
||||
def test_alphanumeric_and_underscores_preserved(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my_server_123") == "my_server_123"
|
||||
|
||||
def test_empty_string(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("") == ""
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component(None) == ""
|
||||
|
||||
def test_slash_in_convert_mcp_schema(self):
|
||||
"""Server names with slashes produce valid tool names via _convert_mcp_schema."""
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(name="search")
|
||||
schema = _convert_mcp_schema("ai.exa/exa", mcp_tool)
|
||||
assert schema["name"] == "mcp_ai_exa_exa_search"
|
||||
# Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$
|
||||
import re
|
||||
assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"])
|
||||
|
||||
def test_slash_in_build_utility_schemas(self):
|
||||
"""Server names with slashes produce valid utility tool names."""
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("ai.exa/exa")
|
||||
for s in schemas:
|
||||
name = s["schema"]["name"]
|
||||
assert "/" not in name
|
||||
assert "." not in name
|
||||
|
||||
def test_slash_in_sync_mcp_toolsets(self):
|
||||
"""_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema."""
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
|
||||
# Verify the prefix generation matches what _convert_mcp_schema produces
|
||||
server_name = "ai.exa/exa"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
assert safe_prefix == "mcp_ai_exa_exa_"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_mcp_servers public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterMcpServers:
|
||||
"""Verify the new register_mcp_servers() public API."""
|
||||
|
||||
def test_empty_servers_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True):
|
||||
result = register_mcp_servers({})
|
||||
assert result == []
|
||||
|
||||
def test_mcp_not_available_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
|
||||
result = register_mcp_servers({"srv": {"command": "test"}})
|
||||
assert result == []
|
||||
|
||||
def test_skips_already_connected_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
mock_server = _make_mock_server("existing")
|
||||
_servers["existing"] = mock_server
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]):
|
||||
result = register_mcp_servers({"existing": {"command": "test"}})
|
||||
assert result == ["mcp_existing_tool"]
|
||||
finally:
|
||||
_servers.pop("existing", None)
|
||||
|
||||
def test_skips_disabled_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
result = register_mcp_servers({"srv": {"command": "test", "enabled": False}})
|
||||
assert result == []
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_connects_new_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"my_server": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_my_server_tool1"]
|
||||
_servers[name] = server
|
||||
return ["mcp_my_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]):
|
||||
_ensure_mcp_loop()
|
||||
result = register_mcp_servers(fake_config)
|
||||
|
||||
assert "mcp_my_server_tool1" in result
|
||||
_servers.pop("my_server", None)
|
||||
|
||||
def test_logs_summary_on_success(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"srv": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
_servers[name] = server
|
||||
return ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]):
|
||||
_ensure_mcp_loop()
|
||||
|
||||
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||
register_mcp_servers(fake_config)
|
||||
|
||||
info_calls = [str(c) for c in mock_logger.info.call_args_list]
|
||||
assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), (
|
||||
f"Summary should report 2 tools from 1 server, got: {info_calls}"
|
||||
)
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
@@ -1406,6 +1406,17 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
return schema
|
||||
|
||||
|
||||
def sanitize_mcp_name_component(value: str) -> str:
|
||||
"""Return an MCP name component safe for tool and prefix generation.
|
||||
|
||||
Preserves Hermes's historical behavior of converting hyphens to
|
||||
underscores, and also replaces any other character outside
|
||||
``[A-Za-z0-9_]`` with ``_`` so generated tool names are compatible with
|
||||
provider validation rules.
|
||||
"""
|
||||
return re.sub(r"[^A-Za-z0-9_]", "_", str(value or ""))
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1417,9 +1428,8 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
Returns:
|
||||
A dict suitable for ``registry.register(schema=...)``.
|
||||
"""
|
||||
# Sanitize: replace hyphens and dots with underscores for LLM API compatibility
|
||||
safe_tool_name = mcp_tool.name.replace("-", "_").replace(".", "_")
|
||||
safe_server_name = server_name.replace("-", "_").replace(".", "_")
|
||||
safe_tool_name = sanitize_mcp_name_component(mcp_tool.name)
|
||||
safe_server_name = sanitize_mcp_name_component(server_name)
|
||||
prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}"
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
@@ -1449,7 +1459,7 @@ def _sync_mcp_toolsets(server_names: Optional[List[str]] = None) -> None:
|
||||
all_mcp_tools: List[str] = []
|
||||
|
||||
for server_name in server_names:
|
||||
safe_prefix = f"mcp_{server_name.replace('-', '_').replace('.', '_')}_"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
server_tools = sorted(
|
||||
t for t in existing if t.startswith(safe_prefix)
|
||||
)
|
||||
@@ -1485,7 +1495,7 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
with keys: schema, handler_key.
|
||||
"""
|
||||
safe_name = server_name.replace("-", "_").replace(".", "_")
|
||||
safe_name = sanitize_mcp_name_component(server_name)
|
||||
return [
|
||||
{
|
||||
"schema": {
|
||||
@@ -1772,6 +1782,86 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
"""Connect to explicit MCP servers and register their tools.
|
||||
|
||||
Idempotent for already-connected server names. Servers with
|
||||
``enabled: false`` are skipped without disconnecting existing sessions.
|
||||
|
||||
Args:
|
||||
servers: Mapping of ``{server_name: server_config}``.
|
||||
|
||||
Returns:
|
||||
List of all currently registered MCP tool names.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
logger.debug("MCP SDK not available -- skipping explicit MCP registration")
|
||||
return []
|
||||
|
||||
if not servers:
|
||||
logger.debug("No explicit MCP servers provided")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Log a summary so ACP callers get visibility into what was registered.
|
||||
with _lock:
|
||||
connected = [n for n in new_servers if n in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[n], "_registered_tool_names", []))
|
||||
for n in connected
|
||||
)
|
||||
failed = len(new_servers) - len(connected)
|
||||
if new_tool_count or failed:
|
||||
summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)"
|
||||
if failed:
|
||||
summary += f" ({failed} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
@@ -1793,69 +1883,32 @@ def discover_mcp_tools() -> List[str]:
|
||||
logger.debug("No MCP servers configured")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
new_server_names = [
|
||||
name
|
||||
for name, cfg in servers.items()
|
||||
if name not in _servers and _parse_boolish(cfg.get("enabled", True), default=True)
|
||||
]
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
tool_names = register_mcp_servers(servers)
|
||||
if not new_server_names:
|
||||
return tool_names
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
nonlocal failed_count
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
with _lock:
|
||||
connected_server_names = [name for name in new_server_names if name in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[name], "_registered_tool_names", []))
|
||||
for name in connected_server_names
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
elif isinstance(result, list):
|
||||
all_tools.extend(result)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Print summary
|
||||
total_servers = len(new_servers)
|
||||
ok_servers = total_servers - failed_count
|
||||
if all_tools or failed_count:
|
||||
summary = f" MCP: {len(all_tools)} tool(s) from {ok_servers} server(s)"
|
||||
failed_count = len(new_server_names) - len(connected_server_names)
|
||||
if new_tool_count or failed_count:
|
||||
summary = f" MCP: {new_tool_count} tool(s) from {len(connected_server_names)} server(s)"
|
||||
if failed_count:
|
||||
summary += f" ({failed_count} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
# Return ALL registered tools (existing + newly discovered)
|
||||
return _existing_tool_names()
|
||||
return tool_names
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
|
||||
Reference in New Issue
Block a user