mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-21 09:31:31 +08:00
Compare commits
1 Commits
hermes/her
...
opencode-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e52ddb6318 |
@@ -22,9 +22,6 @@ from acp.schema import (
|
||||
InitializeResponse,
|
||||
ListSessionsResponse,
|
||||
LoadSessionResponse,
|
||||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
@@ -96,71 +93,6 @@ class HermesACPAgent(acp.Agent):
|
||||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
mcp_servers: list[McpServerStdio | McpServerHttp | McpServerSse] | None,
|
||||
) -> None:
|
||||
"""Register ACP-provided MCP servers and refresh the agent tool surface."""
|
||||
if not mcp_servers:
|
||||
return
|
||||
|
||||
try:
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
config_map: dict[str, dict] = {}
|
||||
for server in mcp_servers:
|
||||
name = server.name
|
||||
if isinstance(server, McpServerStdio):
|
||||
config = {
|
||||
"command": server.command,
|
||||
"args": list(server.args),
|
||||
"env": {item.name: item.value for item in server.env},
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"url": server.url,
|
||||
"headers": {item.name: item.value for item in server.headers},
|
||||
}
|
||||
config_map[name] = config
|
||||
|
||||
await asyncio.to_thread(register_mcp_servers, config_map)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to register ACP MCP servers",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from model_tools import get_tool_definitions
|
||||
|
||||
enabled_toolsets = getattr(state.agent, "enabled_toolsets", None) or ["hermes-acp"]
|
||||
disabled_toolsets = getattr(state.agent, "disabled_toolsets", None)
|
||||
state.agent.tools = get_tool_definitions(
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
quiet_mode=True,
|
||||
)
|
||||
state.agent.valid_tool_names = {
|
||||
tool["function"]["name"] for tool in state.agent.tools or []
|
||||
}
|
||||
invalidate = getattr(state.agent, "_invalidate_system_prompt", None)
|
||||
if callable(invalidate):
|
||||
invalidate()
|
||||
logger.info(
|
||||
"Session %s: refreshed tool surface after ACP MCP registration (%d tools)",
|
||||
state.session_id,
|
||||
len(state.agent.tools or []),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Session %s: failed to refresh tool surface after ACP MCP registration",
|
||||
state.session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ---- ACP lifecycle ------------------------------------------------------
|
||||
|
||||
async def initialize(
|
||||
@@ -217,7 +149,6 @@ class HermesACPAgent(acp.Agent):
|
||||
**kwargs: Any,
|
||||
) -> NewSessionResponse:
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
|
||||
@@ -232,7 +163,6 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("load_session: session %s not found", session_id)
|
||||
return None
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
return LoadSessionResponse()
|
||||
|
||||
@@ -247,7 +177,6 @@ class HermesACPAgent(acp.Agent):
|
||||
if state is None:
|
||||
logger.warning("resume_session: session %s not found, creating new", session_id)
|
||||
state = self.session_manager.create_session(cwd=cwd)
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
|
||||
@@ -271,8 +200,6 @@ class HermesACPAgent(acp.Agent):
|
||||
) -> ForkSessionResponse:
|
||||
state = self.session_manager.fork_session(session_id, cwd=cwd)
|
||||
new_id = state.session_id if state else ""
|
||||
if state is not None:
|
||||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Forked session %s -> %s", session_id, new_id)
|
||||
return ForkSessionResponse(session_id=new_id)
|
||||
|
||||
|
||||
@@ -301,6 +301,8 @@ Update the summary using this exact structure. PRESERVE all existing information
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
|
||||
|
||||
Write the summary in the same language the user was using in the conversation.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
else:
|
||||
# First compaction: summarize from scratch
|
||||
@@ -339,6 +341,8 @@ Use this exact structure:
|
||||
|
||||
Target ~{summary_budget} tokens. Be specific — include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
|
||||
|
||||
Write the summary in the same language the user was using in the conversation.
|
||||
|
||||
Write only the summary body. Do not include any preamble or prefix."""
|
||||
|
||||
try:
|
||||
|
||||
48
cli.py
48
cli.py
@@ -3052,54 +3052,10 @@ class HermesCLI:
|
||||
print(f" Config File: {config_path} {config_status}")
|
||||
print()
|
||||
|
||||
def _list_recent_sessions(self, limit: int = 10) -> list[dict[str, Any]]:
|
||||
"""Return recent CLI sessions for in-chat browsing/resume affordances."""
|
||||
if not self._session_db:
|
||||
return []
|
||||
try:
|
||||
sessions = self._session_db.list_sessions_rich(
|
||||
source="cli",
|
||||
exclude_sources=["tool"],
|
||||
limit=limit,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
return [s for s in sessions if s.get("id") != self.session_id]
|
||||
|
||||
def _show_recent_sessions(self, *, reason: str = "history", limit: int = 10) -> bool:
|
||||
"""Render recent sessions inline from the active chat TUI.
|
||||
|
||||
Returns True when something was shown, False if no session list was available.
|
||||
"""
|
||||
sessions = self._list_recent_sessions(limit=limit)
|
||||
if not sessions:
|
||||
return False
|
||||
|
||||
from hermes_cli.main import _relative_time
|
||||
|
||||
print()
|
||||
if reason == "history":
|
||||
print("(._.) No messages in the current chat yet — here are recent sessions you can resume:")
|
||||
else:
|
||||
print(" Recent sessions:")
|
||||
print()
|
||||
print(f" {'Title':<32} {'Preview':<40} {'Last Active':<13} {'ID'}")
|
||||
print(f" {'─' * 32} {'─' * 40} {'─' * 13} {'─' * 24}")
|
||||
for session in sessions:
|
||||
title = (session.get("title") or "—")[:30]
|
||||
preview = (session.get("preview") or "")[:38]
|
||||
last_active = _relative_time(session.get("last_active"))
|
||||
print(f" {title:<32} {preview:<40} {last_active:<13} {session['id']}")
|
||||
print()
|
||||
print(" Use /resume <session id or title> to continue where you left off.")
|
||||
print()
|
||||
return True
|
||||
|
||||
def show_history(self):
|
||||
"""Display conversation history."""
|
||||
if not self.conversation_history:
|
||||
if not self._show_recent_sessions(reason="history"):
|
||||
print("(._.) No conversation history yet.")
|
||||
print("(._.) No conversation history yet.")
|
||||
return
|
||||
|
||||
preview_limit = 400
|
||||
@@ -3224,8 +3180,6 @@ class HermesCLI:
|
||||
|
||||
if not target:
|
||||
_cprint(" Usage: /resume <session_id_or_title>")
|
||||
if self._show_recent_sessions(reason="resume"):
|
||||
return
|
||||
_cprint(" Tip: Use /history or `hermes sessions list` to find sessions.")
|
||||
return
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ runs at a time if multiple processes overlap.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -444,30 +443,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
session_db=_session_db,
|
||||
)
|
||||
|
||||
# Run the agent with a timeout so a hung API call or tool doesn't
|
||||
# block the cron ticker thread indefinitely. Default 10 minutes;
|
||||
# override via env var. Uses a separate thread because
|
||||
# run_conversation is synchronous.
|
||||
_cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600))
|
||||
_cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
_cron_future = _cron_pool.submit(agent.run_conversation, prompt)
|
||||
try:
|
||||
result = _cron_future.result(timeout=_cron_timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.error(
|
||||
"Job '%s' timed out after %.0fs — interrupting agent",
|
||||
job_name, _cron_timeout,
|
||||
)
|
||||
if hasattr(agent, "interrupt"):
|
||||
agent.interrupt("Cron job timed out")
|
||||
_cron_pool.shutdown(wait=False, cancel_futures=True)
|
||||
raise TimeoutError(
|
||||
f"Cron job '{job_name}' timed out after "
|
||||
f"{int(_cron_timeout // 60)} minutes"
|
||||
)
|
||||
finally:
|
||||
_cron_pool.shutdown(wait=False)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
final_response = result.get("final_response", "") or ""
|
||||
# Use a separate variable for log display; keep final_response clean
|
||||
# for delivery logic (empty response = no delivery).
|
||||
|
||||
@@ -1046,13 +1046,6 @@ class BasePlatformAdapter(ABC):
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Mark session as active BEFORE spawning background task to close
|
||||
# the race window where a second message arriving before the task
|
||||
# starts would also pass the _active_sessions check and spawn a
|
||||
# duplicate task. (grammY sequentialize / aiogram EventIsolation
|
||||
# pattern — set the guard synchronously, not inside the task.)
|
||||
self._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Spawn background task to process this message
|
||||
task = asyncio.create_task(self._process_message_background(event, session_key))
|
||||
try:
|
||||
@@ -1099,10 +1092,8 @@ class BasePlatformAdapter(ABC):
|
||||
if getattr(result, "success", False):
|
||||
delivery_succeeded = True
|
||||
|
||||
# Reuse the interrupt event set by handle_message() (which marks
|
||||
# the session active before spawning this task to prevent races).
|
||||
# Fall back to a new Event only if the entry was removed externally.
|
||||
interrupt_event = self._active_sessions.get(session_key) or asyncio.Event()
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
@@ -1115,12 +1106,9 @@ class BasePlatformAdapter(ABC):
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
# Send response if any. A None/empty response is normal when
|
||||
# streaming already delivered the text (already_sent=True) or
|
||||
# when the message was queued behind an active agent. Log at
|
||||
# DEBUG to avoid noisy warnings for expected behavior.
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.debug("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
@@ -900,9 +900,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception:
|
||||
pass # best-effort truncation
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
# Flood control / RetryAfter — short waits are retried inline,
|
||||
# long waits return a failure immediately so streaming can fall back
|
||||
# to a normal final send instead of leaving a truncated partial.
|
||||
# Flood control / RetryAfter — back off and retry once
|
||||
retry_after = getattr(e, "retry_after", None)
|
||||
if retry_after is not None or "retry after" in err_str:
|
||||
wait = retry_after if retry_after else 1.0
|
||||
@@ -910,8 +908,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"[%s] Telegram flood control, waiting %.1fs",
|
||||
self.name, wait,
|
||||
)
|
||||
if wait > 5.0:
|
||||
return SendResult(success=False, error=f"flood_control:{wait}")
|
||||
await asyncio.sleep(wait)
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
|
||||
153
gateway/run.py
153
gateway/run.py
@@ -303,43 +303,6 @@ def _resolve_runtime_agent_kwargs() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _build_media_placeholder(event) -> str:
|
||||
"""Build a text placeholder for media-only events so they aren't dropped.
|
||||
|
||||
When a photo/document is queued during active processing and later
|
||||
dequeued, only .text is extracted. If the event has no caption,
|
||||
the media would be silently lost. This builds a placeholder that
|
||||
the vision enrichment pipeline will replace with a real description.
|
||||
"""
|
||||
parts = []
|
||||
media_urls = getattr(event, "media_urls", None) or []
|
||||
media_types = getattr(event, "media_types", None) or []
|
||||
for i, url in enumerate(media_urls):
|
||||
mtype = media_types[i] if i < len(media_types) else ""
|
||||
if mtype.startswith("image/") or getattr(event, "message_type", None) == MessageType.PHOTO:
|
||||
parts.append(f"[User sent an image: {url}]")
|
||||
elif mtype.startswith("audio/"):
|
||||
parts.append(f"[User sent audio: {url}]")
|
||||
else:
|
||||
parts.append(f"[User sent a file: {url}]")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
|
||||
"""Consume and return the text of a pending queued message.
|
||||
|
||||
Preserves media context for captionless photo/document events by
|
||||
building a placeholder so the message isn't silently dropped.
|
||||
"""
|
||||
event = adapter.get_pending_message(session_key)
|
||||
if not event:
|
||||
return None
|
||||
text = event.text
|
||||
if not text and getattr(event, "media_urls", None):
|
||||
text = _build_media_placeholder(event)
|
||||
return text
|
||||
|
||||
|
||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||
"""Check if a command matches a known-but-inactive skill.
|
||||
|
||||
@@ -448,14 +411,10 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
class GatewayRunner:
|
||||
"""
|
||||
Main gateway controller.
|
||||
|
||||
|
||||
Manages the lifecycle of all platform adapters and routes
|
||||
messages to/from the agent.
|
||||
"""
|
||||
|
||||
# Class-level defaults so partial construction in tests doesn't
|
||||
# blow up on attribute access.
|
||||
_running_agents_ts: Dict[str, float] = {}
|
||||
|
||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||
self.config = config or load_gateway_config()
|
||||
@@ -487,7 +446,6 @@ class GatewayRunner:
|
||||
# Track running agents per session for interrupt support
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
@@ -1740,20 +1698,6 @@ class GatewayRunner:
|
||||
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
|
||||
# let the adapter-level batching/queueing logic absorb them.
|
||||
_quick_key = self._session_key_for_source(source)
|
||||
|
||||
# Staleness eviction: if an entry has been in _running_agents for
|
||||
# longer than the agent timeout, it's a leaked lock from a hung or
|
||||
# crashed handler. Evict it so the session isn't permanently stuck.
|
||||
_STALE_TTL = float(os.getenv("HERMES_AGENT_TIMEOUT", 600)) + 60 # timeout + 1 min grace
|
||||
_stale_ts = self._running_agents_ts.get(_quick_key, 0)
|
||||
if _quick_key in self._running_agents and _stale_ts and (time.time() - _stale_ts) > _STALE_TTL:
|
||||
logger.warning(
|
||||
"Evicting stale _running_agents entry for %s (age: %.0fs)",
|
||||
_quick_key[:30], time.time() - _stale_ts,
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
@@ -2079,7 +2023,6 @@ class GatewayRunner:
|
||||
# "already running" guard and spin up a duplicate agent for the
|
||||
# same session — corrupting the transcript.
|
||||
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
||||
self._running_agents_ts[_quick_key] = time.time()
|
||||
|
||||
try:
|
||||
return await self._handle_message_with_agent(event, source, _quick_key)
|
||||
@@ -2090,7 +2033,6 @@ class GatewayRunner:
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
@@ -5442,13 +5384,11 @@ class GatewayRunner:
|
||||
progress_lines = [] # Accumulated tool lines
|
||||
progress_msg_id = None # ID of the progress message to edit
|
||||
can_edit = True # False once an edit fails (platform doesn't support it)
|
||||
_last_edit_ts = 0.0 # Throttle edits to avoid Telegram flood control
|
||||
_PROGRESS_EDIT_INTERVAL = 1.5 # Minimum seconds between edits
|
||||
|
||||
while True:
|
||||
try:
|
||||
raw = progress_queue.get_nowait()
|
||||
|
||||
|
||||
# Handle dedup messages: update last line with repeat counter
|
||||
if isinstance(raw, tuple) and len(raw) == 3 and raw[0] == "__dedup__":
|
||||
_, base_msg, count = raw
|
||||
@@ -5459,19 +5399,6 @@ class GatewayRunner:
|
||||
msg = raw
|
||||
progress_lines.append(msg)
|
||||
|
||||
# Throttle edits: batch rapid tool updates into fewer
|
||||
# API calls to avoid hitting Telegram flood control.
|
||||
# (grammY auto-retry pattern: proactively rate-limit
|
||||
# instead of reacting to 429s.)
|
||||
_now = time.monotonic()
|
||||
_remaining = _PROGRESS_EDIT_INTERVAL - (_now - _last_edit_ts)
|
||||
if _remaining > 0:
|
||||
# Wait out the throttle interval, then loop back to
|
||||
# drain any additional queued messages before sending
|
||||
# a single batched edit.
|
||||
await asyncio.sleep(_remaining)
|
||||
continue
|
||||
|
||||
if can_edit and progress_msg_id is not None:
|
||||
# Try to edit the existing progress message
|
||||
full_text = "\n".join(progress_lines)
|
||||
@@ -5481,15 +5408,8 @@ class GatewayRunner:
|
||||
content=full_text,
|
||||
)
|
||||
if not result.success:
|
||||
_err = (getattr(result, "error", "") or "").lower()
|
||||
if "flood" in _err or "retry after" in _err:
|
||||
# Flood control hit — disable further edits,
|
||||
# switch to sending new messages only for
|
||||
# important updates. Don't block 23s.
|
||||
logger.info(
|
||||
"[%s] Progress edits disabled due to flood control",
|
||||
adapter.name,
|
||||
)
|
||||
# Platform doesn't support editing — stop trying,
|
||||
# send just this new line as a separate message
|
||||
can_edit = False
|
||||
await adapter.send(chat_id=source.chat_id, content=msg, metadata=_progress_metadata)
|
||||
else:
|
||||
@@ -5503,8 +5423,6 @@ class GatewayRunner:
|
||||
if result.success and result.message_id:
|
||||
progress_msg_id = result.message_id
|
||||
|
||||
_last_edit_ts = time.monotonic()
|
||||
|
||||
# Restore typing indicator
|
||||
await asyncio.sleep(0.3)
|
||||
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
||||
@@ -5550,25 +5468,15 @@ class GatewayRunner:
|
||||
_loop_for_step = asyncio.get_event_loop()
|
||||
_hooks_ref = self.hooks
|
||||
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
def _step_callback_sync(iteration: int, tool_names: list) -> None:
|
||||
try:
|
||||
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
||||
# keys. Normalise to keep "tool_names" backward-compatible for
|
||||
# user-authored hooks that do ', '.join(tool_names)'.
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_hooks_ref.emit("agent:step", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
"user_id": source.user_id,
|
||||
"session_id": session_id,
|
||||
"iteration": iteration,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
"tool_names": tool_names,
|
||||
}),
|
||||
_loop_for_step,
|
||||
)
|
||||
@@ -6025,38 +5933,9 @@ class GatewayRunner:
|
||||
interrupt_monitor = asyncio.create_task(monitor_for_interrupt())
|
||||
|
||||
try:
|
||||
# Run in thread pool to not block. Cap total execution time
|
||||
# so a hung API call or runaway tool doesn't permanently lock
|
||||
# the session. Default 10 minutes; override with env var.
|
||||
_agent_timeout = float(os.getenv("HERMES_AGENT_TIMEOUT", 600))
|
||||
# Run in thread pool to not block
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, run_sync),
|
||||
timeout=_agent_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
"Agent execution timed out after %.0fs for session %s",
|
||||
_agent_timeout, session_key,
|
||||
)
|
||||
# Interrupt the agent if it's still running so the thread
|
||||
# pool worker is freed.
|
||||
_timed_out_agent = agent_holder[0]
|
||||
if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"):
|
||||
_timed_out_agent.interrupt("Execution timed out")
|
||||
response = {
|
||||
"final_response": (
|
||||
f"⏱️ Request timed out after {int(_agent_timeout // 60)} minutes. "
|
||||
"The agent may have been stuck on a tool or API call.\n"
|
||||
"Try again, or use /reset to start fresh."
|
||||
),
|
||||
"messages": result_holder[0].get("messages", []) if result_holder[0] else [],
|
||||
"api_calls": 0,
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": 0,
|
||||
"failed": True,
|
||||
}
|
||||
response = await loop.run_in_executor(None, run_sync)
|
||||
|
||||
# Track fallback model state: if the agent switched to a
|
||||
# fallback model during this run, persist it so /model shows
|
||||
@@ -6084,12 +5963,18 @@ class GatewayRunner:
|
||||
pending = None
|
||||
if result and adapter and session_key:
|
||||
if result.get("interrupted"):
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if not pending and result.get("interrupt_message"):
|
||||
# Interrupted — consume the interrupt message
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
else:
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if pending:
|
||||
# Normal completion — check for /queue'd messages that were
|
||||
# stored without triggering an interrupt.
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
if pending:
|
||||
@@ -6165,8 +6050,6 @@ class GatewayRunner:
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||
|
||||
@@ -174,12 +174,12 @@ class GatewayStreamConsumer:
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
else:
|
||||
# If an edit fails mid-stream (especially Telegram flood control),
|
||||
# stop progressive edits and let the normal final send path deliver
|
||||
# the complete answer instead of leaving the user with a partial.
|
||||
# Edit not supported by this adapter — stop streaming,
|
||||
# let the normal send path handle the final response.
|
||||
# Without this guard, adapters like Signal/Email would
|
||||
# flood the chat with a new message every edit_interval.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._edit_supported = False
|
||||
self._already_sent = False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
|
||||
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
@@ -109,9 +108,6 @@ CONCLUDE_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
ALL_TOOL_SCHEMAS = [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -128,34 +124,6 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
self._sync_thread: Optional[threading.Thread] = None
|
||||
|
||||
# B1: recall_mode — set during initialize from config
|
||||
self._recall_mode = "hybrid" # "context", "tools", or "hybrid"
|
||||
|
||||
# B4: First-turn context baking
|
||||
self._first_turn_context: Optional[str] = None
|
||||
self._first_turn_lock = threading.Lock()
|
||||
|
||||
# B5: Cost-awareness turn counting and cadence
|
||||
self._turn_count = 0
|
||||
self._injection_frequency = "every-turn" # or "first-turn"
|
||||
self._context_cadence = 1 # minimum turns between context API calls
|
||||
self._dialectic_cadence = 1 # minimum turns between dialectic API calls
|
||||
self._reasoning_level_cap: Optional[str] = None # "minimal", "low", "mid", "high"
|
||||
self._last_context_turn = -999
|
||||
self._last_dialectic_turn = -999
|
||||
|
||||
# B2: peer_memory_mode gating (stub)
|
||||
self._suppress_memory = False
|
||||
self._suppress_user_profile = False
|
||||
|
||||
# Port #1957: lazy session init for tools-only mode
|
||||
self._session_initialized = False
|
||||
self._lazy_init_kwargs: Optional[dict] = None
|
||||
self._lazy_init_session_id: Optional[str] = None
|
||||
|
||||
# Port #4053: cron guard — when True, plugin is fully inactive
|
||||
self._cron_skipped = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "honcho"
|
||||
@@ -165,7 +133,6 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
try:
|
||||
from plugins.memory.honcho.client import HonchoClientConfig
|
||||
cfg = HonchoClientConfig.from_global_config()
|
||||
# Port #2645: baseUrl-only verification — api_key OR base_url suffices
|
||||
return cfg.enabled and bool(cfg.api_key or cfg.base_url)
|
||||
except Exception:
|
||||
return False
|
||||
@@ -191,22 +158,8 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
"""Initialize Honcho session manager.
|
||||
|
||||
Handles: cron guard, recall_mode, session name resolution,
|
||||
peer memory mode, SOUL.md ai_peer sync, memory file migration,
|
||||
and pre-warming context at init.
|
||||
"""
|
||||
"""Initialize Honcho session manager."""
|
||||
try:
|
||||
# ----- Port #4053: cron guard -----
|
||||
agent_context = kwargs.get("agent_context", "")
|
||||
platform = kwargs.get("platform", "cli")
|
||||
if agent_context in ("cron", "flush") or platform == "cron":
|
||||
logger.debug("Honcho skipped: cron/flush context (agent_context=%s, platform=%s)",
|
||||
agent_context, platform)
|
||||
self._cron_skipped = True
|
||||
return
|
||||
|
||||
from plugins.memory.honcho.client import HonchoClientConfig, get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
@@ -216,78 +169,20 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
return
|
||||
|
||||
self._config = cfg
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# ----- B1: recall_mode from config -----
|
||||
self._recall_mode = cfg.recall_mode # "context", "tools", or "hybrid"
|
||||
logger.debug("Honcho recall_mode: %s", self._recall_mode)
|
||||
|
||||
# ----- B5: cost-awareness config -----
|
||||
try:
|
||||
raw = cfg.raw or {}
|
||||
self._injection_frequency = raw.get("injectionFrequency", "every-turn")
|
||||
self._context_cadence = int(raw.get("contextCadence", 1))
|
||||
self._dialectic_cadence = int(raw.get("dialecticCadence", 1))
|
||||
cap = raw.get("reasoningLevelCap")
|
||||
if cap and cap in ("minimal", "low", "mid", "high"):
|
||||
self._reasoning_level_cap = cap
|
||||
except Exception as e:
|
||||
logger.debug("Honcho cost-awareness config parse error: %s", e)
|
||||
|
||||
# ----- Port #1969: aiPeer sync from SOUL.md -----
|
||||
try:
|
||||
hermes_home = kwargs.get("hermes_home", "")
|
||||
if hermes_home and not cfg.raw.get("aiPeer"):
|
||||
soul_path = Path(hermes_home) / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
soul_text = soul_path.read_text(encoding="utf-8").strip()
|
||||
if soul_text:
|
||||
# Try YAML frontmatter: "name: Foo"
|
||||
first_line = soul_text.split("\n")[0].strip()
|
||||
if first_line.startswith("---"):
|
||||
# Look for name: in frontmatter
|
||||
for line in soul_text.split("\n")[1:]:
|
||||
line = line.strip()
|
||||
if line == "---":
|
||||
break
|
||||
if line.lower().startswith("name:"):
|
||||
name_val = line.split(":", 1)[1].strip().strip("\"'")
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md: %s", name_val)
|
||||
break
|
||||
elif first_line.startswith("# "):
|
||||
# Markdown heading: "# AgentName"
|
||||
name_val = first_line[2:].strip()
|
||||
if name_val:
|
||||
cfg.ai_peer = name_val
|
||||
logger.debug("Honcho ai_peer set from SOUL.md heading: %s", name_val)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho SOUL.md ai_peer sync failed: %s", e)
|
||||
|
||||
# ----- B2: peer_memory_mode gating (stub) -----
|
||||
try:
|
||||
ai_mode = cfg.peer_memory_mode(cfg.ai_peer)
|
||||
user_mode = cfg.peer_memory_mode(cfg.peer_name or "user")
|
||||
# "honcho" means Honcho owns memory; suppress built-in
|
||||
self._suppress_memory = (ai_mode == "honcho")
|
||||
self._suppress_user_profile = (user_mode == "honcho")
|
||||
logger.debug("Honcho peer_memory_mode: ai=%s (suppress_memory=%s), user=%s (suppress_user_profile=%s)",
|
||||
ai_mode, self._suppress_memory, user_mode, self._suppress_user_profile)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho peer_memory_mode check failed: %s", e)
|
||||
|
||||
# ----- Port #1957: lazy session init for tools-only mode -----
|
||||
if self._recall_mode == "tools":
|
||||
# Defer actual session creation until first tool call
|
||||
self._lazy_init_kwargs = kwargs
|
||||
self._lazy_init_session_id = session_id
|
||||
# Still need a client reference for _ensure_session
|
||||
self._config = cfg
|
||||
logger.debug("Honcho tools-only mode — deferring session init until first tool call")
|
||||
return
|
||||
|
||||
# ----- Eager init (context or hybrid mode) -----
|
||||
self._do_session_init(cfg, session_id, **kwargs)
|
||||
# Build session key from kwargs or session_id
|
||||
platform = kwargs.get("platform", "cli")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
if user_id:
|
||||
self._session_key = f"{platform}:{user_id}"
|
||||
else:
|
||||
self._session_key = session_id
|
||||
|
||||
except ImportError:
|
||||
logger.debug("honcho-ai package not installed — plugin inactive")
|
||||
@@ -295,180 +190,19 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.warning("Honcho init failed: %s", e)
|
||||
self._manager = None
|
||||
|
||||
def _do_session_init(self, cfg, session_id: str, **kwargs) -> None:
|
||||
"""Shared session initialization logic for both eager and lazy paths."""
|
||||
from plugins.memory.honcho.client import get_honcho_client
|
||||
from plugins.memory.honcho.session import HonchoSessionManager
|
||||
|
||||
client = get_honcho_client(cfg)
|
||||
self._manager = HonchoSessionManager(
|
||||
honcho=client,
|
||||
config=cfg,
|
||||
context_tokens=cfg.context_tokens,
|
||||
)
|
||||
|
||||
# ----- B3: resolve_session_name -----
|
||||
session_title = kwargs.get("session_title")
|
||||
self._session_key = (
|
||||
cfg.resolve_session_name(session_title=session_title, session_id=session_id)
|
||||
or session_id
|
||||
or "hermes-default"
|
||||
)
|
||||
logger.debug("Honcho session key resolved: %s", self._session_key)
|
||||
|
||||
# Create session eagerly
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
self._session_initialized = True
|
||||
|
||||
# ----- B6: Memory file migration (one-time, for new sessions) -----
|
||||
try:
|
||||
if not session.messages:
|
||||
from hermes_constants import get_hermes_home
|
||||
mem_dir = str(get_hermes_home() / "memories")
|
||||
self._manager.migrate_memory_files(self._session_key, mem_dir)
|
||||
logger.debug("Honcho memory file migration attempted for new session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho memory file migration skipped: %s", e)
|
||||
|
||||
# ----- B7: Pre-warming context at init -----
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key)
|
||||
self._manager.prefetch_dialectic(self._session_key, "What should I know about this user?")
|
||||
logger.debug("Honcho pre-warm threads started for session: %s", self._session_key)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho pre-warm failed: %s", e)
|
||||
|
||||
def _ensure_session(self) -> bool:
|
||||
"""Lazily initialize the Honcho session (for tools-only mode).
|
||||
|
||||
Returns True if the manager is ready, False otherwise.
|
||||
"""
|
||||
if self._manager and self._session_initialized:
|
||||
return True
|
||||
if self._cron_skipped:
|
||||
return False
|
||||
if not self._config or not self._lazy_init_kwargs:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._do_session_init(
|
||||
self._config,
|
||||
self._lazy_init_session_id or "hermes-default",
|
||||
**self._lazy_init_kwargs,
|
||||
)
|
||||
# Clear lazy refs
|
||||
self._lazy_init_kwargs = None
|
||||
self._lazy_init_session_id = None
|
||||
return self._manager is not None
|
||||
except Exception as e:
|
||||
logger.warning("Honcho lazy session init failed: %s", e)
|
||||
return False
|
||||
|
||||
def _format_first_turn_context(self, ctx: dict) -> str:
|
||||
"""Format the prefetch context dict into a readable system prompt block."""
|
||||
parts = []
|
||||
|
||||
rep = ctx.get("representation", "")
|
||||
if rep:
|
||||
parts.append(f"## User Representation\n{rep}")
|
||||
|
||||
card = ctx.get("card", "")
|
||||
if card:
|
||||
parts.append(f"## User Peer Card\n{card}")
|
||||
|
||||
ai_rep = ctx.get("ai_representation", "")
|
||||
if ai_rep:
|
||||
parts.append(f"## AI Self-Representation\n{ai_rep}")
|
||||
|
||||
ai_card = ctx.get("ai_card", "")
|
||||
if ai_card:
|
||||
parts.append(f"## AI Identity Card\n{ai_card}")
|
||||
|
||||
if not parts:
|
||||
return ""
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
"""Return system prompt text, adapted by recall_mode.
|
||||
|
||||
B4: On the FIRST call, fetch and bake the full Honcho context
|
||||
(user representation, peer card, AI representation, continuity synthesis).
|
||||
Subsequent calls return the cached block for prompt caching stability.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
if not self._manager or not self._session_key:
|
||||
# tools-only mode without session yet still returns a minimal block
|
||||
if self._recall_mode == "tools" and self._config:
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile, honcho_search, "
|
||||
"honcho_context, and honcho_conclude tools to access user memory."
|
||||
)
|
||||
return ""
|
||||
|
||||
# ----- B4: First-turn context baking -----
|
||||
first_turn_block = ""
|
||||
if self._recall_mode in ("context", "hybrid"):
|
||||
with self._first_turn_lock:
|
||||
if self._first_turn_context is None:
|
||||
# First call — fetch and cache
|
||||
try:
|
||||
ctx = self._manager.get_prefetch_context(self._session_key)
|
||||
self._first_turn_context = self._format_first_turn_context(ctx) if ctx else ""
|
||||
except Exception as e:
|
||||
logger.debug("Honcho first-turn context fetch failed: %s", e)
|
||||
self._first_turn_context = ""
|
||||
first_turn_block = self._first_turn_context
|
||||
|
||||
# ----- B1: adapt text based on recall_mode -----
|
||||
if self._recall_mode == "context":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (context-injection mode). Relevant user context is automatically "
|
||||
"injected before each turn. No memory tools are available — context is "
|
||||
"managed automatically."
|
||||
)
|
||||
elif self._recall_mode == "tools":
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (tools-only mode). Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user. "
|
||||
"No automatic context injection — you must use tools to access memory."
|
||||
)
|
||||
else: # hybrid
|
||||
header = (
|
||||
"# Honcho Memory\n"
|
||||
"Active (hybrid mode). Relevant context is auto-injected AND memory tools are available. "
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
)
|
||||
|
||||
if first_turn_block:
|
||||
return f"{header}\n\n{first_turn_block}"
|
||||
return header
|
||||
return (
|
||||
"# Honcho Memory\n"
|
||||
"Active. AI-native cross-session user modeling.\n"
|
||||
"Use honcho_profile for a quick factual snapshot, "
|
||||
"honcho_search for raw excerpts, honcho_context for synthesized answers, "
|
||||
"honcho_conclude to save facts about the user."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Return prefetched dialectic context from background thread.
|
||||
|
||||
B1: Returns empty when recall_mode is "tools" (no injection).
|
||||
B5: Respects injection_frequency — "first-turn" returns cached/empty after turn 0.
|
||||
Port #3265: Truncates to context_tokens budget.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return ""
|
||||
|
||||
# B1: tools-only mode — no auto-injection
|
||||
if self._recall_mode == "tools":
|
||||
return ""
|
||||
|
||||
# B5: injection_frequency — if "first-turn" and past first turn, return empty
|
||||
if self._injection_frequency == "first-turn" and self._turn_count > 0:
|
||||
return ""
|
||||
|
||||
"""Return prefetched dialectic context from background thread."""
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
@@ -476,49 +210,13 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
return ""
|
||||
|
||||
# ----- Port #3265: token budget enforcement -----
|
||||
result = self._truncate_to_budget(result)
|
||||
|
||||
return f"## Honcho Context\n{result}"
|
||||
|
||||
def _truncate_to_budget(self, text: str) -> str:
|
||||
"""Truncate text to fit within context_tokens budget if set."""
|
||||
if not self._config or not self._config.context_tokens:
|
||||
return text
|
||||
budget_chars = self._config.context_tokens * 4 # conservative char estimate
|
||||
if len(text) <= budget_chars:
|
||||
return text
|
||||
# Truncate at word boundary
|
||||
truncated = text[:budget_chars]
|
||||
last_space = truncated.rfind(" ")
|
||||
if last_space > budget_chars * 0.8:
|
||||
truncated = truncated[:last_space]
|
||||
return truncated + " …"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
"""Fire a background dialectic query for the upcoming turn.
|
||||
|
||||
B5: Checks cadence before firing background threads.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
"""Fire a background dialectic query for the upcoming turn."""
|
||||
if not self._manager or not self._session_key or not query:
|
||||
return
|
||||
|
||||
# B1: tools-only mode — no prefetch
|
||||
if self._recall_mode == "tools":
|
||||
return
|
||||
|
||||
# B5: cadence check — skip if too soon since last dialectic call
|
||||
if self._dialectic_cadence > 1:
|
||||
if (self._turn_count - self._last_dialectic_turn) < self._dialectic_cadence:
|
||||
logger.debug("Honcho dialectic prefetch skipped: cadence %d, turns since last: %d",
|
||||
self._dialectic_cadence, self._turn_count - self._last_dialectic_turn)
|
||||
return
|
||||
|
||||
self._last_dialectic_turn = self._turn_count
|
||||
|
||||
def _run():
|
||||
try:
|
||||
result = self._manager.dialectic_query(
|
||||
@@ -535,28 +233,14 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
)
|
||||
self._prefetch_thread.start()
|
||||
|
||||
# Also fire context prefetch if cadence allows
|
||||
if self._context_cadence <= 1 or (self._turn_count - self._last_context_turn) >= self._context_cadence:
|
||||
self._last_context_turn = self._turn_count
|
||||
try:
|
||||
self._manager.prefetch_context(self._session_key, query)
|
||||
except Exception as e:
|
||||
logger.debug("Honcho context prefetch failed: %s", e)
|
||||
|
||||
def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
|
||||
"""Track turn count for cadence and injection_frequency logic."""
|
||||
self._turn_count = turn_number
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Record the conversation turn in Honcho (non-blocking)."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
session = self._manager.get_or_create(self._session_key)
|
||||
session = self._manager.get_or_create_session(self._session_key)
|
||||
session.add_message("user", user_content[:4000])
|
||||
session.add_message("assistant", assistant_content[:4000])
|
||||
# Flush to Honcho API
|
||||
@@ -575,8 +259,6 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
"""Mirror built-in user profile writes as Honcho conclusions."""
|
||||
if action != "add" or target != "user" or not content:
|
||||
return
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager or not self._session_key:
|
||||
return
|
||||
|
||||
@@ -591,8 +273,6 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
|
||||
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Flush all pending messages to Honcho on session end."""
|
||||
if self._cron_skipped:
|
||||
return
|
||||
if not self._manager:
|
||||
return
|
||||
# Wait for pending sync
|
||||
@@ -604,26 +284,9 @@ class HonchoMemoryProvider(MemoryProvider):
|
||||
logger.debug("Honcho session-end flush failed: %s", e)
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""Return tool schemas, respecting recall_mode.
|
||||
|
||||
B1: context-only mode hides all tools.
|
||||
"""
|
||||
if self._cron_skipped:
|
||||
return []
|
||||
if self._recall_mode == "context":
|
||||
return []
|
||||
return list(ALL_TOOL_SCHEMAS)
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
"""Handle a Honcho tool call, with lazy session init for tools-only mode."""
|
||||
if self._cron_skipped:
|
||||
return json.dumps({"error": "Honcho is not active (cron context)."})
|
||||
|
||||
# Port #1957: ensure session is initialized for tools-only mode
|
||||
if not self._session_initialized:
|
||||
if not self._ensure_session():
|
||||
return json.dumps({"error": "Honcho session could not be initialized."})
|
||||
|
||||
if not self._manager or not self._session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
|
||||
|
||||
@@ -85,16 +85,6 @@ def _normalize_recall_mode(val: str) -> str:
|
||||
return val if val in _VALID_RECALL_MODES else "hybrid"
|
||||
|
||||
|
||||
_VALID_OBSERVATION_MODES = {"unified", "directional"}
|
||||
_OBSERVATION_MODE_ALIASES = {"shared": "unified", "separate": "directional", "cross": "directional"}
|
||||
|
||||
|
||||
def _normalize_observation_mode(val: str) -> str:
|
||||
"""Normalize observation mode values."""
|
||||
val = _OBSERVATION_MODE_ALIASES.get(val, val)
|
||||
return val if val in _VALID_OBSERVATION_MODES else "unified"
|
||||
|
||||
|
||||
def _resolve_memory_mode(
|
||||
global_val: str | dict,
|
||||
host_val: str | dict | None,
|
||||
@@ -164,10 +154,6 @@ class HonchoClientConfig:
|
||||
# "context" — auto-injected context only, Honcho tools removed
|
||||
# "tools" — Honcho tools only, no auto-injected context
|
||||
recall_mode: str = "hybrid"
|
||||
# Observation mode: how Honcho peers observe each other.
|
||||
# "unified" — user peer observes self; all agents share one observation pool
|
||||
# "directional" — AI peer observes user; each agent keeps its own view
|
||||
observation_mode: str = "unified"
|
||||
# Session resolution
|
||||
session_strategy: str = "per-directory"
|
||||
session_peer_prefix: bool = False
|
||||
@@ -327,11 +313,6 @@ class HonchoClientConfig:
|
||||
or raw.get("recallMode")
|
||||
or "hybrid"
|
||||
),
|
||||
observation_mode=_normalize_observation_mode(
|
||||
host_block.get("observationMode")
|
||||
or raw.get("observationMode")
|
||||
or "unified"
|
||||
),
|
||||
session_strategy=session_strategy,
|
||||
session_peer_prefix=session_peer_prefix,
|
||||
sessions=raw.get("sessions", {}),
|
||||
|
||||
@@ -110,9 +110,6 @@ class HonchoSessionManager:
|
||||
self._dialectic_max_chars: int = (
|
||||
config.dialectic_max_chars if config else 600
|
||||
)
|
||||
self._observation_mode: str = (
|
||||
config.observation_mode if config else "unified"
|
||||
)
|
||||
|
||||
# Async write queue — started lazily on first enqueue
|
||||
self._async_queue: queue.Queue | None = None
|
||||
@@ -162,18 +159,13 @@ class HonchoSessionManager:
|
||||
|
||||
session = self.honcho.session(session_id)
|
||||
|
||||
# Configure peer observation settings based on observation_mode.
|
||||
# Unified: user peer observes self, AI peer passive — all agents share
|
||||
# one observation pool via user self-observations.
|
||||
# Directional: AI peer observes user — each agent keeps its own view.
|
||||
# Configure peer observation settings.
|
||||
# observe_me=True for AI peer so Honcho watches what the agent says
|
||||
# and builds its representation over time — enabling identity formation.
|
||||
try:
|
||||
from honcho.session import SessionPeerConfig
|
||||
if self._observation_mode == "directional":
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=True)
|
||||
else: # unified (default)
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=False)
|
||||
ai_config = SessionPeerConfig(observe_me=False, observe_others=False)
|
||||
user_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
ai_config = SessionPeerConfig(observe_me=True, observe_others=True)
|
||||
|
||||
session.add_peers([(user_peer, user_config), (assistant_peer, ai_config)])
|
||||
except Exception as e:
|
||||
@@ -501,27 +493,12 @@ class HonchoSessionManager:
|
||||
if not session:
|
||||
return ""
|
||||
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
level = reasoning_level or self._dynamic_reasoning_level(query)
|
||||
|
||||
try:
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer queries about the user (cross-observation)
|
||||
if peer == "ai":
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(query, reasoning_level=level) or ""
|
||||
else:
|
||||
ai_peer_obj = self._get_or_create_peer(session.assistant_peer_id)
|
||||
result = ai_peer_obj.chat(
|
||||
query,
|
||||
target=session.user_peer_id,
|
||||
reasoning_level=level,
|
||||
) or ""
|
||||
else:
|
||||
# Unified: user peer queries self, or AI peer queries self
|
||||
peer_id = session.assistant_peer_id if peer == "ai" else session.user_peer_id
|
||||
target_peer = self._get_or_create_peer(peer_id)
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
|
||||
result = target_peer.chat(query, reasoning_level=level) or ""
|
||||
# Apply Hermes-side char cap before caching
|
||||
if result and self._dialectic_max_chars and len(result) > self._dialectic_max_chars:
|
||||
result = result[:self._dialectic_max_chars].rsplit(" ", 1)[0] + " …"
|
||||
@@ -918,16 +895,9 @@ class HonchoSessionManager:
|
||||
logger.warning("No session cached for '%s', skipping conclusion", session_key)
|
||||
return False
|
||||
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
try:
|
||||
if self._observation_mode == "directional":
|
||||
# AI peer creates conclusion about user (cross-observation)
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
else:
|
||||
# Unified: user peer creates self-conclusion
|
||||
user_peer = self._get_or_create_peer(session.user_peer_id)
|
||||
conclusions_scope = user_peer.conclusions_of(session.user_peer_id)
|
||||
|
||||
conclusions_scope = assistant_peer.conclusions_of(session.user_peer_id)
|
||||
conclusions_scope.create([{
|
||||
"content": content.strip(),
|
||||
"session_id": session.honcho_session_id,
|
||||
|
||||
179
run_agent.py
179
run_agent.py
@@ -375,58 +375,6 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
return found
|
||||
|
||||
|
||||
def _normalize_structured_content(content) -> tuple:
|
||||
"""Normalize Mistral-style structured content blocks to (text, reasoning).
|
||||
|
||||
Mistral's Magistral models (and mistral-large-2512+) return ``content`` as
|
||||
a list of typed blocks instead of a plain string::
|
||||
|
||||
[{"type": "thinking", "thinking": [{"type": "text", "text": "..."}]},
|
||||
{"type": "text", "text": "final answer"},
|
||||
{"type": "reference", ...}]
|
||||
|
||||
This also appears in streaming deltas (``delta.content`` is a list).
|
||||
|
||||
Returns:
|
||||
(text_content, thinking_content) — text is always a string (possibly
|
||||
empty), thinking is a string or None.
|
||||
"""
|
||||
if content is None:
|
||||
return ("", None)
|
||||
if isinstance(content, str):
|
||||
return (content, None)
|
||||
if not isinstance(content, list):
|
||||
return (str(content), None)
|
||||
|
||||
text_parts: list = []
|
||||
thinking_parts: list = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
continue
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
block_type = block.get("type", "")
|
||||
if block_type == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif block_type == "thinking":
|
||||
# "thinking" is itself a list of text blocks
|
||||
thinking = block.get("thinking", [])
|
||||
if isinstance(thinking, list):
|
||||
for t in thinking:
|
||||
if isinstance(t, dict) and t.get("type") == "text":
|
||||
thinking_parts.append(t.get("text", ""))
|
||||
elif isinstance(t, str):
|
||||
thinking_parts.append(t)
|
||||
elif isinstance(thinking, str):
|
||||
thinking_parts.append(thinking)
|
||||
# Other types (reference, image, document, audio, file) are skipped.
|
||||
|
||||
text = "\n".join(p for p in text_parts if p)
|
||||
thinking = "\n\n".join(p for p in thinking_parts if p) or None
|
||||
return (text, thinking)
|
||||
|
||||
|
||||
def _strip_budget_warnings_from_history(messages: list) -> None:
|
||||
"""Remove budget pressure warnings from tool-result messages in-place.
|
||||
|
||||
@@ -4159,43 +4107,30 @@ class AIAgent:
|
||||
_fire_first_delta()
|
||||
self._fire_reasoning_delta(reasoning_text)
|
||||
|
||||
# Accumulate text content — fire callback only when no tool calls.
|
||||
# Mistral Magistral models return delta.content as a list of
|
||||
# structured blocks instead of a plain string; normalize first.
|
||||
# Accumulate text content — fire callback only when no tool calls
|
||||
if delta and delta.content:
|
||||
_raw_delta_content = delta.content
|
||||
if isinstance(_raw_delta_content, list):
|
||||
_delta_text, _delta_thinking = _normalize_structured_content(_raw_delta_content)
|
||||
if _delta_thinking:
|
||||
reasoning_parts.append(_delta_thinking)
|
||||
_fire_first_delta()
|
||||
self._fire_reasoning_delta(_delta_thinking)
|
||||
content_parts.append(delta.content)
|
||||
if not tool_calls_acc:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(delta.content)
|
||||
deltas_were_sent["yes"] = True
|
||||
else:
|
||||
_delta_text = _raw_delta_content
|
||||
|
||||
if _delta_text:
|
||||
content_parts.append(_delta_text)
|
||||
if not tool_calls_acc:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(_delta_text)
|
||||
deltas_were_sent["yes"] = True
|
||||
else:
|
||||
# Tool calls suppress regular content streaming (avoids
|
||||
# displaying chatty "I'll use the tool..." text alongside
|
||||
# tool calls). But reasoning tags embedded in suppressed
|
||||
# content should still reach the display — otherwise the
|
||||
# reasoning box only appears as a post-response fallback,
|
||||
# rendering it confusingly after the already-streamed
|
||||
# response. Route suppressed content through the stream
|
||||
# delta callback so its tag extraction can fire the
|
||||
# reasoning display. Non-reasoning text is harmlessly
|
||||
# suppressed by the CLI's _stream_delta when the stream
|
||||
# box is already closed (tool boundary flush).
|
||||
if self.stream_delta_callback:
|
||||
try:
|
||||
self.stream_delta_callback(_delta_text)
|
||||
except Exception:
|
||||
pass
|
||||
# Tool calls suppress regular content streaming (avoids
|
||||
# displaying chatty "I'll use the tool..." text alongside
|
||||
# tool calls). But reasoning tags embedded in suppressed
|
||||
# content should still reach the display — otherwise the
|
||||
# reasoning box only appears as a post-response fallback,
|
||||
# rendering it confusingly after the already-streamed
|
||||
# response. Route suppressed content through the stream
|
||||
# delta callback so its tag extraction can fire the
|
||||
# reasoning display. Non-reasoning text is harmlessly
|
||||
# suppressed by the CLI's _stream_delta when the stream
|
||||
# box is already closed (tool boundary flush).
|
||||
if self.stream_delta_callback:
|
||||
try:
|
||||
self.stream_delta_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Accumulate tool call deltas — notify display on first name
|
||||
if delta and delta.tool_calls:
|
||||
@@ -5235,32 +5170,18 @@ class AIAgent:
|
||||
Handles reasoning extraction, reasoning_details, and optional tool_calls
|
||||
so both the tool-call path and the final-response path share one builder.
|
||||
"""
|
||||
# Normalize content early — Mistral Magistral models return content
|
||||
# as a list of structured blocks instead of a string.
|
||||
_raw_content = assistant_message.content
|
||||
_structured_thinking = None
|
||||
if isinstance(_raw_content, list):
|
||||
_raw_content, _structured_thinking = _normalize_structured_content(_raw_content)
|
||||
|
||||
reasoning_text = self._extract_reasoning(assistant_message)
|
||||
_from_structured = bool(reasoning_text)
|
||||
|
||||
# If the structured content included thinking blocks and
|
||||
# _extract_reasoning didn't find anything, use the structured thinking.
|
||||
if not reasoning_text and _structured_thinking:
|
||||
reasoning_text = _structured_thinking
|
||||
_from_structured = True
|
||||
|
||||
# Fallback: extract inline <think> blocks from content when no structured
|
||||
# reasoning fields are present (some models/providers embed thinking
|
||||
# directly in the content rather than returning separate API fields).
|
||||
if not reasoning_text:
|
||||
content = _raw_content or ""
|
||||
if isinstance(content, str):
|
||||
think_blocks = re.findall(r'<think>(.*?)</think>', content, flags=re.DOTALL)
|
||||
if think_blocks:
|
||||
combined = "\n\n".join(b.strip() for b in think_blocks if b.strip())
|
||||
reasoning_text = combined or None
|
||||
content = assistant_message.content or ""
|
||||
think_blocks = re.findall(r'<think>(.*?)</think>', content, flags=re.DOTALL)
|
||||
if think_blocks:
|
||||
combined = "\n\n".join(b.strip() for b in think_blocks if b.strip())
|
||||
reasoning_text = combined or None
|
||||
|
||||
if reasoning_text and self.verbose_logging:
|
||||
logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {reasoning_text}")
|
||||
@@ -5282,7 +5203,7 @@ class AIAgent:
|
||||
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": _raw_content or "",
|
||||
"content": assistant_message.content or "",
|
||||
"reasoning": reasoning_text,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
@@ -6735,21 +6656,10 @@ class AIAgent:
|
||||
if self.step_callback is not None:
|
||||
try:
|
||||
prev_tools = []
|
||||
for _idx, _m in enumerate(reversed(messages)):
|
||||
for _m in reversed(messages):
|
||||
if _m.get("role") == "assistant" and _m.get("tool_calls"):
|
||||
_fwd_start = len(messages) - _idx
|
||||
_results_by_id = {}
|
||||
for _tm in messages[_fwd_start:]:
|
||||
if _tm.get("role") != "tool":
|
||||
break
|
||||
_tcid = _tm.get("tool_call_id")
|
||||
if _tcid:
|
||||
_results_by_id[_tcid] = _tm.get("content", "")
|
||||
prev_tools = [
|
||||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
}
|
||||
tc["function"]["name"]
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
@@ -7101,9 +7011,6 @@ class AIAgent:
|
||||
if self.api_mode == "chat_completions":
|
||||
_trunc_msg = response.choices[0].message if (hasattr(response, "choices") and response.choices) else None
|
||||
_trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None
|
||||
# Mistral Magistral: content may be a list of blocks
|
||||
if isinstance(_trunc_content, list):
|
||||
_trunc_content, _ = _normalize_structured_content(_trunc_content)
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
# Anthropic response.content is a list of blocks
|
||||
_text_parts = []
|
||||
@@ -7158,10 +7065,7 @@ class AIAgent:
|
||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
messages.append(interim_msg)
|
||||
if assistant_message.content:
|
||||
_cont = assistant_message.content
|
||||
if isinstance(_cont, list):
|
||||
_cont, _ = _normalize_structured_content(_cont)
|
||||
truncated_response_prefix += _cont
|
||||
truncated_response_prefix += assistant_message.content
|
||||
|
||||
if length_continue_retries < 3:
|
||||
self._vprint(
|
||||
@@ -7876,22 +7780,21 @@ class AIAgent:
|
||||
# Normalize content to string — some OpenAI-compatible servers
|
||||
# (llama-server, etc.) return content as a dict or list instead
|
||||
# of a plain string, which crashes downstream .strip() calls.
|
||||
# Mistral Magistral models return a list of structured blocks
|
||||
# including {type: "thinking"} and {type: "text"}.
|
||||
if assistant_message.content is not None and not isinstance(assistant_message.content, str):
|
||||
raw = assistant_message.content
|
||||
if isinstance(raw, dict):
|
||||
assistant_message.content = raw.get("text", "") or raw.get("content", "") or json.dumps(raw)
|
||||
elif isinstance(raw, list):
|
||||
_norm_text, _norm_thinking = _normalize_structured_content(raw)
|
||||
assistant_message.content = _norm_text
|
||||
# Preserve extracted thinking as reasoning_content so
|
||||
# _extract_reasoning / _build_assistant_message picks it up.
|
||||
if _norm_thinking and not getattr(assistant_message, "reasoning_content", None):
|
||||
try:
|
||||
assistant_message.reasoning_content = _norm_thinking
|
||||
except (AttributeError, TypeError):
|
||||
pass # frozen/read-only SDK object
|
||||
# Multimodal content list — extract text parts
|
||||
parts = []
|
||||
for part in raw:
|
||||
if isinstance(part, str):
|
||||
parts.append(part)
|
||||
elif isinstance(part, dict) and part.get("type") == "text":
|
||||
parts.append(part.get("text", ""))
|
||||
elif isinstance(part, dict) and "text" in part:
|
||||
parts.append(str(part["text"]))
|
||||
assistant_message.content = "\n".join(parts)
|
||||
else:
|
||||
assistant_message.content = str(raw)
|
||||
|
||||
|
||||
@@ -205,47 +205,6 @@ class TestStepCallback:
|
||||
assert "read_file" not in tool_call_ids
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_result_passed_to_build_tool_complete(self, mock_conn, event_loop_fixture):
|
||||
"""Tool result from prev_tools dict is forwarded to build_tool_complete."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
# Provide a result string in the tool info dict
|
||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
"""When result is None (e.g. first iteration), None is passed through."""
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
|
||||
@@ -1,349 +0,0 @@
|
||||
"""End-to-end tests for ACP MCP server registration and tool-result reporting.
|
||||
|
||||
Exercises the full flow through the ACP server layer:
|
||||
new_session(mcpServers) → MCP tools registered → prompt() →
|
||||
tool_progress_callback (ToolCallStart) →
|
||||
step_callback with results (ToolCallUpdate with rawOutput) →
|
||||
session_update events arrive at the mock client
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import (
|
||||
EnvVariable,
|
||||
HttpHeader,
|
||||
McpServerHttp,
|
||||
McpServerStdio,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
TextContentBlock,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
)
|
||||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_manager():
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def acp_agent(mock_manager):
|
||||
return HermesACPAgent(session_manager=mock_manager)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2E: MCP registration → prompt → tool events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMcpRegistrationE2E:
|
||||
"""Full flow: session with MCP servers → prompt with tool calls → ACP events."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_with_mcp_servers_registers_tools(self, acp_agent, mock_manager):
|
||||
"""new_session with mcpServers converts them to Hermes config and registers."""
|
||||
servers = [
|
||||
McpServerStdio(
|
||||
name="test-fs",
|
||||
command="/usr/bin/mcp-fs",
|
||||
args=["--root", "/tmp"],
|
||||
env=[EnvVariable(name="DEBUG", value="1")],
|
||||
),
|
||||
McpServerHttp(
|
||||
name="test-api",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok123")],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search"]
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_test_fs_read"}},
|
||||
{"function": {"name": "mcp_test_fs_write"}},
|
||||
{"function": {"name": "mcp_test_api_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
assert isinstance(resp, NewSessionResponse)
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Verify stdio server was converted correctly
|
||||
assert "test-fs" in registered_configs
|
||||
fs_cfg = registered_configs["test-fs"]
|
||||
assert fs_cfg["command"] == "/usr/bin/mcp-fs"
|
||||
assert fs_cfg["args"] == ["--root", "/tmp"]
|
||||
assert fs_cfg["env"] == {"DEBUG": "1"}
|
||||
|
||||
# Verify HTTP server was converted correctly
|
||||
assert "test-api" in registered_configs
|
||||
api_cfg = registered_configs["test-api"]
|
||||
assert api_cfg["url"] == "https://api.example.com/mcp"
|
||||
assert api_cfg["headers"] == {"Authorization": "Bearer tok123"}
|
||||
|
||||
# Verify agent tool surface was refreshed
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {
|
||||
"mcp_test_fs_read", "mcp_test_fs_write", "mcp_test_api_search", "terminal"
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_with_tool_calls_emits_acp_events(self, acp_agent, mock_manager):
|
||||
"""Prompt → agent fires callbacks → ACP ToolCallStart + ToolCallUpdate events."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
# Wire up a mock ACP client connection
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run_conversation(user_message, conversation_history=None, task_id=None):
|
||||
"""Simulate an agent turn that calls terminal, gets a result, then responds."""
|
||||
agent = state.agent
|
||||
|
||||
# 1) Agent fires tool_progress_callback (ToolCallStart)
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback(
|
||||
"terminal", "$ echo hello", {"command": "echo hello"}
|
||||
)
|
||||
|
||||
# 2) Agent fires step_callback with tool results (ToolCallUpdate)
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "terminal", "result": '{"output": "hello\\n", "exit_code": 0}'}
|
||||
])
|
||||
|
||||
return {
|
||||
"final_response": "The command output 'hello'.",
|
||||
"messages": [
|
||||
{"role": "user", "content": user_message},
|
||||
{"role": "assistant", "content": "The command output 'hello'."},
|
||||
],
|
||||
}
|
||||
|
||||
state.agent.run_conversation = mock_run_conversation
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="run echo hello")]
|
||||
resp = await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
# Collect all session_update calls
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
# session_update(session_id, update) — grab the update
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
# Find tool_call (start) and tool_call_update (completion) events
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
# Should have at least one ToolCallStart for "terminal"
|
||||
assert len(starts) >= 1, f"Expected ToolCallStart, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
start_event = starts[0]
|
||||
assert isinstance(start_event, ToolCallStart)
|
||||
assert start_event.title.startswith("terminal:")
|
||||
|
||||
# Should have at least one ToolCallUpdate (completion) with rawOutput
|
||||
assert len(completions) >= 1, f"Expected ToolCallUpdate, got updates: {[getattr(u, 'session_update', '?') for u in updates]}"
|
||||
complete_event = completions[0]
|
||||
assert isinstance(complete_event, ToolCallProgress)
|
||||
assert complete_event.status == "completed"
|
||||
# rawOutput should contain the tool result string
|
||||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
session_id = resp.session_id
|
||||
state = mock_manager.get_session(session_id)
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
mock_conn.request_permission = AsyncMock()
|
||||
acp_agent._conn = mock_conn
|
||||
|
||||
def mock_run(user_message, conversation_history=None, task_id=None):
|
||||
agent = state.agent
|
||||
# Fire two tool calls
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback("read_file", "read: /etc/hosts", {"path": "/etc/hosts"})
|
||||
agent.tool_progress_callback("web_search", "web search: test", {"query": "test"})
|
||||
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
{"name": "read_file", "result": '{"content": "127.0.0.1 localhost"}'},
|
||||
{"name": "web_search", "result": '{"data": {"web": []}}'},
|
||||
])
|
||||
|
||||
return {"final_response": "Done.", "messages": []}
|
||||
|
||||
state.agent.run_conversation = mock_run
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="test")]
|
||||
await acp_agent.prompt(prompt=prompt, session_id=session_id)
|
||||
|
||||
updates = []
|
||||
for call in mock_conn.session_update.call_args_list:
|
||||
update_arg = call[1].get("update") or call[0][1]
|
||||
updates.append(update_arg)
|
||||
|
||||
starts = [u for u in updates if getattr(u, "session_update", None) == "tool_call"]
|
||||
completions = [u for u in updates if getattr(u, "session_update", None) == "tool_call_update"]
|
||||
|
||||
assert len(starts) == 2, f"Expected 2 starts, got {len(starts)}"
|
||||
assert len(completions) == 2, f"Expected 2 completions, got {len(completions)}"
|
||||
|
||||
# Each completion's toolCallId must match a start's toolCallId
|
||||
start_ids = {s.tool_call_id for s in starts}
|
||||
completion_ids = {c.tool_call_id for c in completions}
|
||||
assert start_ids == completion_ids, (
|
||||
f"IDs must match: starts={start_ids}, completions={completion_ids}"
|
||||
)
|
||||
|
||||
|
||||
class TestMcpSanitizationE2E:
|
||||
"""Verify server names with special chars work end-to-end."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slashed_server_name_registers_cleanly(self, acp_agent, mock_manager):
|
||||
"""Server name 'ai.exa/exa' should not crash — tools get sanitized names."""
|
||||
servers = [
|
||||
McpServerHttp(
|
||||
name="ai.exa/exa",
|
||||
url="https://exa.ai/mcp",
|
||||
headers=[],
|
||||
),
|
||||
]
|
||||
|
||||
registered_configs = {}
|
||||
def mock_register(config_map):
|
||||
registered_configs.update(config_map)
|
||||
return ["mcp_ai_exa_exa_search"]
|
||||
|
||||
fake_tools = [{"function": {"name": "mcp_ai_exa_exa_search"}}]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
resp = await acp_agent.new_session(cwd="/tmp", mcp_servers=servers)
|
||||
|
||||
state = mock_manager.get_session(resp.session_id)
|
||||
|
||||
# Raw server name preserved as config key
|
||||
assert "ai.exa/exa" in registered_configs
|
||||
# Agent tools refreshed with sanitized name
|
||||
assert "mcp_ai_exa_exa_search" in state.agent.valid_tool_names
|
||||
|
||||
|
||||
class TestSessionLifecycleMcpE2E:
|
||||
"""Verify MCP servers are registered on all session lifecycle methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""load_session re-registers MCP servers (spec says agents may not retain them)."""
|
||||
# Create a session first
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv", command="/bin/test", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""resume_session re-registers MCP servers."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerStdio(name="srv2", command="/bin/test2", args=[], env=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
state = mock_manager.get_session(sid)
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await acp_agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=servers)
|
||||
|
||||
assert "srv2" in registered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_registers_mcp(self, acp_agent, mock_manager):
|
||||
"""fork_session registers MCP servers on the new forked session."""
|
||||
create_resp = await acp_agent.new_session(cwd="/tmp")
|
||||
sid = create_resp.session_id
|
||||
|
||||
servers = [
|
||||
McpServerHttp(name="api", url="https://api.test/mcp", headers=[]),
|
||||
]
|
||||
|
||||
registered = {}
|
||||
def mock_register(config_map):
|
||||
registered.update(config_map)
|
||||
return []
|
||||
|
||||
# Need to set up the forked session's agent too
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=mock_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
fork_resp = await acp_agent.fork_session(
|
||||
cwd="/tmp", session_id=sid, mcp_servers=servers
|
||||
)
|
||||
|
||||
assert fork_resp.session_id != ""
|
||||
assert "api" in registered
|
||||
@@ -505,179 +505,3 @@ class TestSlashCommands:
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _register_session_mcp_servers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterSessionMcpServers:
|
||||
"""Tests for ACP MCP server registration in session lifecycle."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_no_servers(self, agent, mock_manager):
|
||||
"""No-op when mcp_servers is None or empty."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, None)
|
||||
await agent._register_session_mcp_servers(state, [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_stdio_servers(self, agent, mock_manager):
|
||||
"""McpServerStdio servers are converted and passed to register_mcp_servers."""
|
||||
from acp.schema import McpServerStdio, EnvVariable
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
# Give the mock agent the attributes _register_session_mcp_servers reads
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerStdio(
|
||||
name="test-server",
|
||||
command="/usr/bin/test",
|
||||
args=["--flag"],
|
||||
env=[EnvVariable(name="KEY", value="val")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return ["mcp_test_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "test-server" in registered_config
|
||||
cfg = registered_config["test-server"]
|
||||
assert cfg["command"] == "/usr/bin/test"
|
||||
assert cfg["args"] == ["--flag"]
|
||||
assert cfg["env"] == {"KEY": "val"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registers_http_servers(self, agent, mock_manager):
|
||||
"""McpServerHttp servers are converted correctly."""
|
||||
from acp.schema import McpServerHttp, HttpHeader
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
|
||||
server = McpServerHttp(
|
||||
name="http-server",
|
||||
url="https://api.example.com/mcp",
|
||||
headers=[HttpHeader(name="Authorization", value="Bearer tok")],
|
||||
)
|
||||
|
||||
registered_config = {}
|
||||
def capture_register(config_map):
|
||||
registered_config.update(config_map)
|
||||
return []
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=capture_register), \
|
||||
patch("model_tools.get_tool_definitions", return_value=[]):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert "http-server" in registered_config
|
||||
cfg = registered_config["http-server"]
|
||||
assert cfg["url"] == "https://api.example.com/mcp"
|
||||
assert cfg["headers"] == {"Authorization": "Bearer tok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refreshes_agent_tool_surface(self, agent, mock_manager):
|
||||
"""After MCP registration, agent.tools and valid_tool_names are refreshed."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.enabled_toolsets = ["hermes-acp"]
|
||||
state.agent.disabled_toolsets = None
|
||||
state.agent.tools = []
|
||||
state.agent.valid_tool_names = set()
|
||||
state.agent._cached_system_prompt = "old prompt"
|
||||
|
||||
server = McpServerStdio(
|
||||
name="srv",
|
||||
command="/bin/test",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
fake_tools = [
|
||||
{"function": {"name": "mcp_srv_search"}},
|
||||
{"function": {"name": "terminal"}},
|
||||
]
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", return_value=["mcp_srv_search"]), \
|
||||
patch("model_tools.get_tool_definitions", return_value=fake_tools):
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
assert state.agent.tools == fake_tools
|
||||
assert state.agent.valid_tool_names == {"mcp_srv_search", "terminal"}
|
||||
# _invalidate_system_prompt should have been called
|
||||
state.agent._invalidate_system_prompt.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_failure_logs_warning(self, agent, mock_manager):
|
||||
"""If register_mcp_servers raises, warning is logged but no crash."""
|
||||
from acp.schema import McpServerStdio
|
||||
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
server = McpServerStdio(
|
||||
name="bad",
|
||||
command="/nonexistent",
|
||||
args=[],
|
||||
env=[],
|
||||
)
|
||||
|
||||
with patch("tools.mcp_tool.register_mcp_servers", side_effect=RuntimeError("boom")):
|
||||
# Should not raise
|
||||
await agent._register_session_mcp_servers(state, [server])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_calls_register(self, agent, mock_manager):
|
||||
"""new_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.new_session(cwd="/tmp", mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
# Second arg should be the mcp_servers list
|
||||
assert mock_reg.call_args[0][1] == ["fake"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_calls_register(self, agent, mock_manager):
|
||||
"""load_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
# Create a session first so load can find it
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.load_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_calls_register(self, agent, mock_manager):
|
||||
"""resume_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.resume_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session_calls_register(self, agent, mock_manager):
|
||||
"""fork_session passes mcp_servers to _register_session_mcp_servers."""
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
sid = state.session_id
|
||||
|
||||
with patch.object(agent, "_register_session_mcp_servers", new_callable=AsyncMock) as mock_reg:
|
||||
resp = await agent.fork_session(cwd="/tmp", session_id=sid, mcp_servers=["fake"])
|
||||
assert resp is not None
|
||||
mock_reg.assert_called_once()
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
"""Tests for step_callback backward compatibility.
|
||||
|
||||
Verifies that the gateway's step_callback normalization keeps
|
||||
``tool_names`` as a list of strings for backward-compatible hooks,
|
||||
while also providing the enriched ``tools`` list with results.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStepCallbackNormalization:
|
||||
"""The gateway's _step_callback_sync normalizes prev_tools from run_agent."""
|
||||
|
||||
def _extract_step_callback(self):
|
||||
"""Build a minimal _step_callback_sync using the same logic as gateway/run.py.
|
||||
|
||||
We replicate the closure so we can test normalisation in isolation
|
||||
without spinning up the full gateway.
|
||||
"""
|
||||
captured_events = []
|
||||
|
||||
class FakeHooks:
|
||||
async def emit(self, event_type, data):
|
||||
captured_events.append((event_type, data))
|
||||
|
||||
hooks_ref = FakeHooks()
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
_names: list[str] = []
|
||||
for _t in (prev_tools or []):
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
hooks_ref.emit("agent:step", {
|
||||
"iteration": iteration,
|
||||
"tool_names": _names,
|
||||
"tools": prev_tools,
|
||||
}),
|
||||
loop,
|
||||
)
|
||||
|
||||
return _step_callback_sync, captured_events, loop
|
||||
|
||||
def test_dict_prev_tools_produce_string_tool_names(self):
|
||||
"""When prev_tools is list[dict], tool_names should be list[str]."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
# Simulate the enriched format from run_agent.py
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": '{"output": "hello"}'},
|
||||
{"name": "read_file", "result": '{"content": "..."}'},
|
||||
]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0)) # prime the loop
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
# tool_names must be strings for backward compat
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
assert all(isinstance(n, str) for n in data["tool_names"])
|
||||
# tools should be the enriched dicts
|
||||
assert data["tools"] == prev_tools
|
||||
|
||||
def test_string_prev_tools_still_work(self):
|
||||
"""When prev_tools is list[str] (legacy), tool_names should pass through."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
prev_tools = ["terminal", "read_file"]
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(2, prev_tools))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == ["terminal", "read_file"]
|
||||
|
||||
def test_empty_prev_tools(self):
|
||||
"""Empty or None prev_tools should produce empty tool_names."""
|
||||
cb, events, loop = self._extract_step_callback()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.sleep(0))
|
||||
import threading
|
||||
t = threading.Thread(target=cb, args=(1, []))
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
loop.run_until_complete(asyncio.sleep(0.1))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert len(events) == 1
|
||||
_, data = events[0]
|
||||
assert data["tool_names"] == []
|
||||
|
||||
def test_joinable_for_hook_example(self):
|
||||
"""The documented hook example: ', '.join(tool_names) should work."""
|
||||
# This is the exact pattern from the docs
|
||||
prev_tools = [
|
||||
{"name": "terminal", "result": "ok"},
|
||||
{"name": "web_search", "result": None},
|
||||
]
|
||||
|
||||
_names = []
|
||||
for _t in prev_tools:
|
||||
if isinstance(_t, dict):
|
||||
_names.append(_t.get("name") or "")
|
||||
else:
|
||||
_names.append(str(_t))
|
||||
|
||||
# This must not raise — documented hook pattern
|
||||
result = ", ".join(_names)
|
||||
assert result == "terminal, web_search"
|
||||
@@ -191,60 +191,6 @@ class TestHistoryDisplay:
|
||||
assert "A" * 250 in output
|
||||
assert "A" * 250 + "..." not in output
|
||||
|
||||
def test_history_shows_recent_sessions_when_current_chat_is_empty(self, capsys):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current"
|
||||
cli._session_db = MagicMock()
|
||||
cli._session_db.list_sessions_rich.return_value = [
|
||||
{
|
||||
"id": "current",
|
||||
"title": "Current",
|
||||
"preview": "Current preview",
|
||||
"last_active": 0,
|
||||
},
|
||||
{
|
||||
"id": "20260401_201329_d85961",
|
||||
"title": "Checking Running Hermes Agent",
|
||||
"preview": "check running gateways for hermes agent",
|
||||
"last_active": 0,
|
||||
},
|
||||
]
|
||||
|
||||
cli.show_history()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "No messages in the current chat yet" in output
|
||||
assert "Checking Running Hermes Agent" in output
|
||||
assert "20260401_201329_d85961" in output
|
||||
assert "/resume" in output
|
||||
assert "Current preview" not in output
|
||||
|
||||
def test_resume_without_target_lists_recent_sessions(self, capsys):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current"
|
||||
cli._session_db = MagicMock()
|
||||
cli._session_db.list_sessions_rich.return_value = [
|
||||
{
|
||||
"id": "current",
|
||||
"title": "Current",
|
||||
"preview": "Current preview",
|
||||
"last_active": 0,
|
||||
},
|
||||
{
|
||||
"id": "20260401_201329_d85961",
|
||||
"title": "Checking Running Hermes Agent",
|
||||
"preview": "check running gateways for hermes agent",
|
||||
"last_active": 0,
|
||||
},
|
||||
]
|
||||
|
||||
cli._handle_resume_command("/resume")
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Recent sessions" in output
|
||||
assert "Checking Running Hermes Agent" in output
|
||||
assert "Use /resume <session id or title> to continue" in output
|
||||
|
||||
|
||||
class TestRootLevelProviderOverride:
|
||||
"""Root-level provider/base_url in config.yaml must NOT override model.provider."""
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
"""Tests for Mistral Magistral structured content handling.
|
||||
|
||||
Mistral's Magistral reasoning models return ``content`` as a list of typed
|
||||
blocks instead of a plain string (both in streaming deltas and non-streaming
|
||||
responses). This test suite verifies that:
|
||||
|
||||
1. _normalize_structured_content() correctly extracts text and thinking parts.
|
||||
2. The streaming path handles list-valued delta.content without crashing.
|
||||
3. The non-streaming path normalizes list content and extracts thinking.
|
||||
4. _build_assistant_message handles list content correctly.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Ensure HERMES_HOME is set before importing run_agent ──────────────
|
||||
if not os.environ.get("HERMES_HOME"):
|
||||
import tempfile
|
||||
|
||||
_tmp = tempfile.mkdtemp(prefix="hermes_test_")
|
||||
os.environ["HERMES_HOME"] = _tmp
|
||||
|
||||
from run_agent import AIAgent, _normalize_structured_content
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
"""Build minimal tool definitions matching get_tool_definitions output."""
|
||||
return [
|
||||
{"type": "function", "function": {"name": n, "description": n, "parameters": {}}}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent():
|
||||
"""Minimal AIAgent for testing _build_assistant_message."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
ag = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
model="mistral/magistral-medium-latest",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
ag.client = MagicMock()
|
||||
ag.verbose_logging = False
|
||||
ag.reasoning_callback = None
|
||||
ag.stream_delta_callback = None
|
||||
return ag
|
||||
|
||||
|
||||
# ── Sample data matching Mistral's API format ─────────────────────────
|
||||
|
||||
MAGISTRAL_CONTENT_BLOCKS = [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": [
|
||||
{"type": "text", "text": "Let me think about this step by step."},
|
||||
{"type": "text", "text": "The capital of France is Paris."},
|
||||
],
|
||||
},
|
||||
{"type": "text", "text": "The capital of France is Paris."},
|
||||
]
|
||||
|
||||
MAGISTRAL_TEXT_ONLY_BLOCKS = [
|
||||
{"type": "text", "text": "Hello, how can I help?"},
|
||||
]
|
||||
|
||||
MAGISTRAL_WITH_REFERENCE = [
|
||||
{"type": "thinking", "thinking": [{"type": "text", "text": "Checking references."}]},
|
||||
{"type": "text", "text": "Here is the answer."},
|
||||
{"type": "reference", "url": "https://example.com"},
|
||||
]
|
||||
|
||||
STREAMING_THINKING_DELTA = [
|
||||
{"type": "thinking", "thinking": [{"type": "text", "text": "Okay"}]},
|
||||
]
|
||||
|
||||
STREAMING_TEXT_DELTA = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
]
|
||||
|
||||
|
||||
# ── Tests: _normalize_structured_content ──────────────────────────────
|
||||
|
||||
class TestNormalizeStructuredContent:
|
||||
"""Tests for the _normalize_structured_content helper."""
|
||||
|
||||
def test_string_passthrough(self):
|
||||
text, thinking = _normalize_structured_content("Hello world")
|
||||
assert text == "Hello world"
|
||||
assert thinking is None
|
||||
|
||||
def test_none_returns_empty_string(self):
|
||||
text, thinking = _normalize_structured_content(None)
|
||||
assert text == ""
|
||||
assert thinking is None
|
||||
|
||||
def test_non_list_non_string_coerced(self):
|
||||
text, thinking = _normalize_structured_content(42)
|
||||
assert text == "42"
|
||||
assert thinking is None
|
||||
|
||||
def test_magistral_full_response(self):
|
||||
text, thinking = _normalize_structured_content(MAGISTRAL_CONTENT_BLOCKS)
|
||||
assert text == "The capital of France is Paris."
|
||||
assert "step by step" in thinking
|
||||
assert "capital of France is Paris" in thinking
|
||||
|
||||
def test_text_only_blocks(self):
|
||||
text, thinking = _normalize_structured_content(MAGISTRAL_TEXT_ONLY_BLOCKS)
|
||||
assert text == "Hello, how can I help?"
|
||||
assert thinking is None
|
||||
|
||||
def test_with_reference_blocks(self):
|
||||
"""Reference blocks should be skipped, not cause errors."""
|
||||
text, thinking = _normalize_structured_content(MAGISTRAL_WITH_REFERENCE)
|
||||
assert text == "Here is the answer."
|
||||
assert thinking == "Checking references."
|
||||
|
||||
def test_streaming_thinking_delta(self):
|
||||
text, thinking = _normalize_structured_content(STREAMING_THINKING_DELTA)
|
||||
assert text == ""
|
||||
assert thinking == "Okay"
|
||||
|
||||
def test_streaming_text_delta(self):
|
||||
text, thinking = _normalize_structured_content(STREAMING_TEXT_DELTA)
|
||||
assert text == "Hello"
|
||||
assert thinking is None
|
||||
|
||||
def test_empty_list(self):
|
||||
text, thinking = _normalize_structured_content([])
|
||||
assert text == ""
|
||||
assert thinking is None
|
||||
|
||||
def test_mixed_string_and_dict_blocks(self):
|
||||
"""Some providers might mix raw strings with typed blocks."""
|
||||
content = ["raw text", {"type": "text", "text": "typed text"}]
|
||||
text, thinking = _normalize_structured_content(content)
|
||||
assert "raw text" in text
|
||||
assert "typed text" in text
|
||||
|
||||
def test_thinking_as_plain_string(self):
|
||||
"""Handle edge case where thinking value is a string not a list."""
|
||||
content = [{"type": "thinking", "thinking": "I'm thinking..."}]
|
||||
text, thinking = _normalize_structured_content(content)
|
||||
assert text == ""
|
||||
assert thinking == "I'm thinking..."
|
||||
|
||||
def test_multiple_text_blocks_joined(self):
|
||||
content = [
|
||||
{"type": "text", "text": "First paragraph."},
|
||||
{"type": "text", "text": "Second paragraph."},
|
||||
]
|
||||
text, thinking = _normalize_structured_content(content)
|
||||
assert "First paragraph." in text
|
||||
assert "Second paragraph." in text
|
||||
assert "\n" in text # joined with newline
|
||||
|
||||
def test_empty_thinking_block(self):
|
||||
"""Thinking block with no text should result in thinking=None."""
|
||||
content = [
|
||||
{"type": "thinking", "thinking": []},
|
||||
{"type": "text", "text": "Answer"},
|
||||
]
|
||||
text, thinking = _normalize_structured_content(content)
|
||||
assert text == "Answer"
|
||||
assert thinking is None
|
||||
|
||||
|
||||
# ── Tests: _build_assistant_message with structured content ────────────
|
||||
|
||||
class TestBuildAssistantMessageStructuredContent:
|
||||
"""Tests that _build_assistant_message correctly handles Mistral list content."""
|
||||
|
||||
def test_list_content_normalized_to_string(self, agent):
|
||||
msg = SimpleNamespace(
|
||||
content=MAGISTRAL_CONTENT_BLOCKS,
|
||||
tool_calls=None,
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "stop")
|
||||
assert isinstance(result["content"], str)
|
||||
assert "The capital of France is Paris." in result["content"]
|
||||
|
||||
def test_list_content_thinking_extracted(self, agent):
|
||||
msg = SimpleNamespace(
|
||||
content=MAGISTRAL_CONTENT_BLOCKS,
|
||||
tool_calls=None,
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "stop")
|
||||
assert result["reasoning"] is not None
|
||||
assert "step by step" in result["reasoning"]
|
||||
|
||||
def test_string_content_unchanged(self, agent):
|
||||
msg = SimpleNamespace(
|
||||
content="Normal string response",
|
||||
tool_calls=None,
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "stop")
|
||||
assert result["content"] == "Normal string response"
|
||||
|
||||
def test_list_content_with_tool_calls(self, agent):
|
||||
tool_call = SimpleNamespace(
|
||||
id="call_123",
|
||||
type="function",
|
||||
function=SimpleNamespace(name="web_search", arguments='{"query": "test"}'),
|
||||
)
|
||||
msg = SimpleNamespace(
|
||||
content=MAGISTRAL_CONTENT_BLOCKS,
|
||||
tool_calls=[tool_call],
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "tool_calls")
|
||||
assert isinstance(result["content"], str)
|
||||
assert "tool_calls" in result
|
||||
|
||||
def test_text_only_blocks_no_reasoning(self, agent):
|
||||
msg = SimpleNamespace(
|
||||
content=MAGISTRAL_TEXT_ONLY_BLOCKS,
|
||||
tool_calls=None,
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "stop")
|
||||
assert result["content"] == "Hello, how can I help?"
|
||||
assert result["reasoning"] is None
|
||||
|
||||
def test_structured_thinking_not_duplicated_with_reasoning_content(self, agent):
|
||||
"""When reasoning_content is set AND content has thinking blocks,
|
||||
don't duplicate the reasoning."""
|
||||
msg = SimpleNamespace(
|
||||
content=MAGISTRAL_CONTENT_BLOCKS,
|
||||
tool_calls=None,
|
||||
reasoning_content="Already extracted reasoning",
|
||||
)
|
||||
result = agent._build_assistant_message(msg, "stop")
|
||||
# Should use the already-set reasoning_content, not duplicate
|
||||
assert result["reasoning"] == "Already extracted reasoning"
|
||||
|
||||
|
||||
# ── Tests: Non-streaming content normalization ─────────────────────────
|
||||
|
||||
class TestNonStreamingContentNormalization:
|
||||
"""Tests for the non-streaming content normalization block in the agent loop."""
|
||||
|
||||
def test_list_content_normalized(self, agent):
|
||||
"""Simulate the normalization block that runs after getting the
|
||||
assistant_message from response.choices[0].message."""
|
||||
msg = SimpleNamespace(content=MAGISTRAL_CONTENT_BLOCKS, tool_calls=None)
|
||||
|
||||
# Simulate the normalization block from run_agent.py
|
||||
if msg.content is not None and not isinstance(msg.content, str):
|
||||
raw = msg.content
|
||||
if isinstance(raw, list):
|
||||
text, thinking = _normalize_structured_content(raw)
|
||||
msg.content = text
|
||||
if thinking and not getattr(msg, "reasoning_content", None):
|
||||
msg.reasoning_content = thinking
|
||||
|
||||
assert isinstance(msg.content, str)
|
||||
assert "The capital of France is Paris." in msg.content
|
||||
assert hasattr(msg, "reasoning_content")
|
||||
assert "step by step" in msg.reasoning_content
|
||||
|
||||
def test_dict_content_handled(self, agent):
|
||||
"""Dict content (from llama-server etc.) should still work."""
|
||||
msg = SimpleNamespace(content={"text": "Hello from dict"}, tool_calls=None)
|
||||
|
||||
if msg.content is not None and not isinstance(msg.content, str):
|
||||
raw = msg.content
|
||||
if isinstance(raw, dict):
|
||||
msg.content = raw.get("text", "") or raw.get("content", "") or str(raw)
|
||||
|
||||
assert msg.content == "Hello from dict"
|
||||
|
||||
|
||||
# ── Tests: Streaming delta normalization ───────────────────────────────
|
||||
|
||||
class TestStreamingDeltaNormalization:
|
||||
"""Tests for the streaming delta content normalization."""
|
||||
|
||||
def test_list_delta_content_split(self):
|
||||
"""When delta.content is a list, text goes to content_parts
|
||||
and thinking goes to reasoning_parts."""
|
||||
content_parts = []
|
||||
reasoning_parts = []
|
||||
|
||||
# Simulate the streaming normalization block
|
||||
delta_content = MAGISTRAL_CONTENT_BLOCKS
|
||||
if isinstance(delta_content, list):
|
||||
text, thinking = _normalize_structured_content(delta_content)
|
||||
if thinking:
|
||||
reasoning_parts.append(thinking)
|
||||
else:
|
||||
text = delta_content
|
||||
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
# Verify text and thinking are separated
|
||||
assert len(content_parts) == 1
|
||||
assert "The capital of France is Paris." in content_parts[0]
|
||||
assert len(reasoning_parts) == 1
|
||||
assert "step by step" in reasoning_parts[0]
|
||||
|
||||
# Verify join succeeds (this was the original crash)
|
||||
full_content = "".join(content_parts)
|
||||
assert isinstance(full_content, str)
|
||||
|
||||
def test_string_delta_passthrough(self):
|
||||
"""Normal string deltas should work unchanged."""
|
||||
content_parts = []
|
||||
delta_content = "Hello"
|
||||
|
||||
if isinstance(delta_content, list):
|
||||
text, _ = _normalize_structured_content(delta_content)
|
||||
else:
|
||||
text = delta_content
|
||||
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
full_content = "".join(content_parts)
|
||||
assert full_content == "Hello"
|
||||
|
||||
def test_thinking_only_delta(self):
|
||||
"""Streaming delta with only thinking and no text."""
|
||||
content_parts = []
|
||||
reasoning_parts = []
|
||||
|
||||
delta_content = STREAMING_THINKING_DELTA
|
||||
if isinstance(delta_content, list):
|
||||
text, thinking = _normalize_structured_content(delta_content)
|
||||
if thinking:
|
||||
reasoning_parts.append(thinking)
|
||||
else:
|
||||
text = delta_content
|
||||
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
# No text content, only reasoning
|
||||
assert len(content_parts) == 0
|
||||
assert len(reasoning_parts) == 1
|
||||
assert reasoning_parts[0] == "Okay"
|
||||
|
||||
# Join should succeed (empty list)
|
||||
full_content = "".join(content_parts) or None
|
||||
assert full_content is None
|
||||
|
||||
def test_multiple_streaming_chunks_joined(self):
|
||||
"""Multiple streaming chunks with mixed list and string content."""
|
||||
content_parts = []
|
||||
reasoning_parts = []
|
||||
|
||||
chunks = [
|
||||
STREAMING_THINKING_DELTA, # list: thinking only
|
||||
STREAMING_TEXT_DELTA, # list: text only
|
||||
"more text", # string
|
||||
]
|
||||
|
||||
for delta_content in chunks:
|
||||
if isinstance(delta_content, list):
|
||||
text, thinking = _normalize_structured_content(delta_content)
|
||||
if thinking:
|
||||
reasoning_parts.append(thinking)
|
||||
else:
|
||||
text = delta_content
|
||||
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
full_content = "".join(content_parts)
|
||||
full_reasoning = "".join(reasoning_parts) or None
|
||||
|
||||
assert full_content == "Hellomore text"
|
||||
assert full_reasoning == "Okay"
|
||||
@@ -2900,164 +2900,3 @@ class TestMCPBuiltinCollisionGuard:
|
||||
assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv"
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sanitize_mcp_name_component
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeMcpNameComponent:
|
||||
"""Verify sanitize_mcp_name_component handles all edge cases."""
|
||||
|
||||
def test_hyphens_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my-server") == "my_server"
|
||||
|
||||
def test_dots_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa") == "ai_exa"
|
||||
|
||||
def test_slashes_replaced(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa"
|
||||
|
||||
def test_mixed_special_characters(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2"
|
||||
|
||||
def test_alphanumeric_and_underscores_preserved(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("my_server_123") == "my_server_123"
|
||||
|
||||
def test_empty_string(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component("") == ""
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
assert sanitize_mcp_name_component(None) == ""
|
||||
|
||||
def test_slash_in_convert_mcp_schema(self):
|
||||
"""Server names with slashes produce valid tool names via _convert_mcp_schema."""
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(name="search")
|
||||
schema = _convert_mcp_schema("ai.exa/exa", mcp_tool)
|
||||
assert schema["name"] == "mcp_ai_exa_exa_search"
|
||||
# Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$
|
||||
import re
|
||||
assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"])
|
||||
|
||||
def test_slash_in_build_utility_schemas(self):
|
||||
"""Server names with slashes produce valid utility tool names."""
|
||||
from tools.mcp_tool import _build_utility_schemas
|
||||
|
||||
schemas = _build_utility_schemas("ai.exa/exa")
|
||||
for s in schemas:
|
||||
name = s["schema"]["name"]
|
||||
assert "/" not in name
|
||||
assert "." not in name
|
||||
|
||||
def test_slash_in_sync_mcp_toolsets(self):
|
||||
"""_sync_mcp_toolsets uses sanitize consistently with _convert_mcp_schema."""
|
||||
from tools.mcp_tool import sanitize_mcp_name_component
|
||||
|
||||
# Verify the prefix generation matches what _convert_mcp_schema produces
|
||||
server_name = "ai.exa/exa"
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
assert safe_prefix == "mcp_ai_exa_exa_"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_mcp_servers public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegisterMcpServers:
|
||||
"""Verify the new register_mcp_servers() public API."""
|
||||
|
||||
def test_empty_servers_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True):
|
||||
result = register_mcp_servers({})
|
||||
assert result == []
|
||||
|
||||
def test_mcp_not_available_returns_empty(self):
|
||||
from tools.mcp_tool import register_mcp_servers
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
|
||||
result = register_mcp_servers({"srv": {"command": "test"}})
|
||||
assert result == []
|
||||
|
||||
def test_skips_already_connected_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
mock_server = _make_mock_server("existing")
|
||||
_servers["existing"] = mock_server
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]):
|
||||
result = register_mcp_servers({"existing": {"command": "test"}})
|
||||
assert result == ["mcp_existing_tool"]
|
||||
finally:
|
||||
_servers.pop("existing", None)
|
||||
|
||||
def test_skips_disabled_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers
|
||||
|
||||
try:
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||
result = register_mcp_servers({"srv": {"command": "test", "enabled": False}})
|
||||
assert result == []
|
||||
finally:
|
||||
_servers.pop("srv", None)
|
||||
|
||||
def test_connects_new_servers(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"my_server": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_my_server_tool1"]
|
||||
_servers[name] = server
|
||||
return ["mcp_my_server_tool1"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]):
|
||||
_ensure_mcp_loop()
|
||||
result = register_mcp_servers(fake_config)
|
||||
|
||||
assert "mcp_my_server_tool1" in result
|
||||
_servers.pop("my_server", None)
|
||||
|
||||
def test_logs_summary_on_success(self):
|
||||
from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop
|
||||
|
||||
fake_config = {"srv": {"command": "npx", "args": ["test"]}}
|
||||
|
||||
async def fake_register(name, cfg):
|
||||
server = _make_mock_server(name)
|
||||
server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
_servers[name] = server
|
||||
return ["mcp_srv_t1", "mcp_srv_t2"]
|
||||
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]):
|
||||
_ensure_mcp_loop()
|
||||
|
||||
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||
register_mcp_servers(fake_config)
|
||||
|
||||
info_calls = [str(c) for c in mock_logger.info.call_args_list]
|
||||
assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), (
|
||||
f"Summary should report 2 tools from 1 server, got: {info_calls}"
|
||||
)
|
||||
|
||||
_servers.pop("srv", None)
|
||||
|
||||
@@ -1406,17 +1406,6 @@ def _normalize_mcp_input_schema(schema: dict | None) -> dict:
|
||||
return schema
|
||||
|
||||
|
||||
def sanitize_mcp_name_component(value: str) -> str:
|
||||
"""Return an MCP name component safe for tool and prefix generation.
|
||||
|
||||
Preserves Hermes's historical behavior of converting hyphens to
|
||||
underscores, and also replaces any other character outside
|
||||
``[A-Za-z0-9_]`` with ``_`` so generated tool names are compatible with
|
||||
provider validation rules.
|
||||
"""
|
||||
return re.sub(r"[^A-Za-z0-9_]", "_", str(value or ""))
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1428,8 +1417,9 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
Returns:
|
||||
A dict suitable for ``registry.register(schema=...)``.
|
||||
"""
|
||||
safe_tool_name = sanitize_mcp_name_component(mcp_tool.name)
|
||||
safe_server_name = sanitize_mcp_name_component(server_name)
|
||||
# Sanitize: replace hyphens and dots with underscores for LLM API compatibility
|
||||
safe_tool_name = mcp_tool.name.replace("-", "_").replace(".", "_")
|
||||
safe_server_name = server_name.replace("-", "_").replace(".", "_")
|
||||
prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}"
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
@@ -1459,7 +1449,7 @@ def _sync_mcp_toolsets(server_names: Optional[List[str]] = None) -> None:
|
||||
all_mcp_tools: List[str] = []
|
||||
|
||||
for server_name in server_names:
|
||||
safe_prefix = f"mcp_{sanitize_mcp_name_component(server_name)}_"
|
||||
safe_prefix = f"mcp_{server_name.replace('-', '_').replace('.', '_')}_"
|
||||
server_tools = sorted(
|
||||
t for t in existing if t.startswith(safe_prefix)
|
||||
)
|
||||
@@ -1495,7 +1485,7 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
with keys: schema, handler_key.
|
||||
"""
|
||||
safe_name = sanitize_mcp_name_component(server_name)
|
||||
safe_name = server_name.replace("-", "_").replace(".", "_")
|
||||
return [
|
||||
{
|
||||
"schema": {
|
||||
@@ -1782,86 +1772,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def register_mcp_servers(servers: Dict[str, dict]) -> List[str]:
|
||||
"""Connect to explicit MCP servers and register their tools.
|
||||
|
||||
Idempotent for already-connected server names. Servers with
|
||||
``enabled: false`` are skipped without disconnecting existing sessions.
|
||||
|
||||
Args:
|
||||
servers: Mapping of ``{server_name: server_config}``.
|
||||
|
||||
Returns:
|
||||
List of all currently registered MCP tool names.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
logger.debug("MCP SDK not available -- skipping explicit MCP registration")
|
||||
return []
|
||||
|
||||
if not servers:
|
||||
logger.debug("No explicit MCP servers provided")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Log a summary so ACP callers get visibility into what was registered.
|
||||
with _lock:
|
||||
connected = [n for n in new_servers if n in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[n], "_registered_tool_names", []))
|
||||
for n in connected
|
||||
)
|
||||
failed = len(new_servers) - len(connected)
|
||||
if new_tool_count or failed:
|
||||
summary = f"MCP: registered {new_tool_count} tool(s) from {len(connected)} server(s)"
|
||||
if failed:
|
||||
summary += f" ({failed} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
@@ -1883,32 +1793,69 @@ def discover_mcp_tools() -> List[str]:
|
||||
logger.debug("No MCP servers configured")
|
||||
return []
|
||||
|
||||
# Only attempt servers that aren't already connected and are enabled
|
||||
# (enabled: false skips the server entirely without removing its config)
|
||||
with _lock:
|
||||
new_server_names = [
|
||||
name
|
||||
for name, cfg in servers.items()
|
||||
if name not in _servers and _parse_boolish(cfg.get("enabled", True), default=True)
|
||||
]
|
||||
new_servers = {
|
||||
k: v
|
||||
for k, v in servers.items()
|
||||
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
|
||||
tool_names = register_mcp_servers(servers)
|
||||
if not new_server_names:
|
||||
return tool_names
|
||||
if not new_servers:
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
return _existing_tool_names()
|
||||
|
||||
with _lock:
|
||||
connected_server_names = [name for name in new_server_names if name in _servers]
|
||||
new_tool_count = sum(
|
||||
len(getattr(_servers[name], "_registered_tool_names", []))
|
||||
for name in connected_server_names
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
return await _discover_and_register_server(name, cfg)
|
||||
|
||||
async def _discover_all():
|
||||
nonlocal failed_count
|
||||
server_names = list(new_servers.keys())
|
||||
# Connect to all servers in PARALLEL
|
||||
results = await asyncio.gather(
|
||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for name, result in zip(server_names, results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
command = new_servers.get(name, {}).get("command")
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s'%s: %s",
|
||||
name,
|
||||
f" (command={command})" if command else "",
|
||||
_format_connect_error(result),
|
||||
)
|
||||
elif isinstance(result, list):
|
||||
all_tools.extend(result)
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
failed_count = len(new_server_names) - len(connected_server_names)
|
||||
if new_tool_count or failed_count:
|
||||
summary = f" MCP: {new_tool_count} tool(s) from {len(connected_server_names)} server(s)"
|
||||
# Per-server timeouts are handled inside _discover_and_register_server.
|
||||
# The outer timeout is generous: 120s total for parallel discovery.
|
||||
_run_on_mcp_loop(_discover_all(), timeout=120)
|
||||
|
||||
_sync_mcp_toolsets(list(servers.keys()))
|
||||
|
||||
# Print summary
|
||||
total_servers = len(new_servers)
|
||||
ok_servers = total_servers - failed_count
|
||||
if all_tools or failed_count:
|
||||
summary = f" MCP: {len(all_tools)} tool(s) from {ok_servers} server(s)"
|
||||
if failed_count:
|
||||
summary += f" ({failed_count} failed)"
|
||||
logger.info(summary)
|
||||
|
||||
return tool_names
|
||||
# Return ALL registered tools (existing + newly discovered)
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
|
||||
@@ -127,12 +127,8 @@ def is_stt_enabled(stt_config: Optional[dict] = None) -> bool:
|
||||
|
||||
|
||||
def _has_openai_audio_backend() -> bool:
|
||||
"""Return True when OpenAI audio can use config credentials, env credentials, or the managed gateway."""
|
||||
try:
|
||||
_resolve_openai_audio_client_config()
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
"""Return True when OpenAI audio can use direct credentials or the managed gateway."""
|
||||
return bool(resolve_openai_audio_api_key() or resolve_managed_tool_gateway("openai-audio"))
|
||||
|
||||
|
||||
def _find_binary(binary_name: str) -> Optional[str]:
|
||||
@@ -581,20 +577,13 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
||||
|
||||
def _resolve_openai_audio_client_config() -> tuple[str, str]:
|
||||
"""Return direct OpenAI audio config or a managed gateway fallback."""
|
||||
stt_config = _load_stt_config()
|
||||
openai_cfg = stt_config.get("openai", {})
|
||||
cfg_api_key = openai_cfg.get("api_key", "")
|
||||
cfg_base_url = openai_cfg.get("base_url", "")
|
||||
if cfg_api_key:
|
||||
return cfg_api_key, (cfg_base_url or OPENAI_BASE_URL)
|
||||
|
||||
direct_api_key = resolve_openai_audio_api_key()
|
||||
if direct_api_key:
|
||||
return direct_api_key, OPENAI_BASE_URL
|
||||
|
||||
managed_gateway = resolve_managed_tool_gateway("openai-audio")
|
||||
if managed_gateway is None:
|
||||
message = "Neither stt.openai.api_key in config nor VOICE_TOOLS_OPENAI_KEY/OPENAI_API_KEY is set"
|
||||
message = "Neither VOICE_TOOLS_OPENAI_KEY nor OPENAI_API_KEY is set"
|
||||
if managed_nous_tools_enabled():
|
||||
message += ", and the managed OpenAI audio gateway is unavailable"
|
||||
raise ValueError(message)
|
||||
|
||||
Reference in New Issue
Block a user