feat(mcp): make selective tool loading capability-aware

Extend the salvaged MCP filtering work so utility tools are also governed by policy and server capabilities. Store the registered tool subset per server so rediscovery and status reporting stay accurate after filtering.
This commit is contained in:
teknium1
2026-03-14 06:22:02 -07:00
parent 3198cc8fd9
commit 04e151714f
2 changed files with 287 additions and 67 deletions

View File

@@ -688,7 +688,7 @@ class MCPServerTask:
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"_sampling",
"_sampling", "_registered_tool_names",
)
def __init__(self, name: str):
@@ -702,6 +702,7 @@ class MCPServerTask:
self._error: Optional[Exception] = None
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
self._registered_tool_names: list[str] = []
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@@ -1308,16 +1309,81 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
]
def _normalize_name_filter(value: Any, label: str) -> set[str]:
"""Normalize include/exclude config to a set of tool names."""
if value is None:
return set()
if isinstance(value, str):
return {value}
if isinstance(value, (list, tuple, set)):
return {str(item) for item in value}
logger.warning("MCP config %s must be a string or list of strings; ignoring %r", label, value)
return set()
def _parse_boolish(value: Any, default: bool = True) -> bool:
"""Parse a bool-like config value with safe fallback."""
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
lowered = value.strip().lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
logger.warning("MCP config expected a boolean-ish value, got %r; using default=%s", value, default)
return default
_UTILITY_CAPABILITY_METHODS = {
"list_resources": "list_resources",
"read_resource": "read_resource",
"list_prompts": "list_prompts",
"get_prompt": "get_prompt",
}
def _select_utility_schemas(server_name: str, server: MCPServerTask, config: dict) -> List[dict]:
    """Return the utility schema entries that should be registered for a server.

    An entry from ``_build_utility_schemas(server_name)`` is kept only when:
      * its group is not switched off in ``config["tools"]`` (the
        ``resources`` and ``prompts`` toggles, each defaulting to enabled), and
      * ``server.session`` has the attribute listed for its handler key in
        ``_UTILITY_CAPABILITY_METHODS``.

    Every skipped entry is logged at debug level with the reason.

    NOTE(review): the second check inspects the session object only; whether
    that reflects what the remote server actually advertises is not visible
    here -- confirm.
    """
    tool_cfg = config.get("tools") or {}
    allow_resources = _parse_boolish(tool_cfg.get("resources"), default=True)
    allow_prompts = _parse_boolish(tool_cfg.get("prompts"), default=True)

    resource_keys = {"list_resources", "read_resource"}
    prompt_keys = {"list_prompts", "get_prompt"}

    kept: List[dict] = []
    for candidate in _build_utility_schemas(server_name):
        key = candidate["handler_key"]

        # Config toggles are checked first; the two key groups are disjoint.
        if not allow_resources and key in resource_keys:
            logger.debug("MCP server '%s': skipping utility '%s' (resources disabled)", server_name, key)
            continue
        if not allow_prompts and key in prompt_keys:
            logger.debug("MCP server '%s': skipping utility '%s' (prompts disabled)", server_name, key)
            continue

        method_name = _UTILITY_CAPABILITY_METHODS[key]
        if hasattr(server.session, method_name):
            kept.append(candidate)
        else:
            logger.debug(
                "MCP server '%s': skipping utility '%s' (session lacks %s)",
                server_name,
                key,
                method_name,
            )
    return kept
def _existing_tool_names() -> List[str]:
    """Return the tool names recorded for every server in ``_servers``.

    When a server carries ``_registered_tool_names`` that list is used as-is,
    so the result matches what was actually registered after filtering.
    Servers without the attribute fall back to the converted schema name of
    each entry in ``server._tools``.
    """
    names: List[str] = []
    for server in _servers.values():
        if hasattr(server, "_registered_tool_names"):
            names.extend(server._registered_tool_names)
            continue
        # Fallback for server objects that lack the registered-name list:
        # derive names from the raw tool list.
        names.extend(
            _convert_mcp_schema(server.name, mcp_tool)["name"]
            for mcp_tool in server._tools
        )
    return names
@@ -1347,11 +1413,11 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
# Rules (matching issue #690 spec):
# tools.include — whitelist: only these tool names are registered
# tools.exclude — blacklist: all tools EXCEPT these are registered
# include and exclude are mutually exclusive; include takes precedence
# include takes precedence over exclude
# Neither set → register all tools (backward-compatible default)
tools_filter = config.get("tools") or {}
include_set = set(tools_filter.get("include") or [])
exclude_set = set(tools_filter.get("exclude") or [])
include_set = _normalize_name_filter(tools_filter.get("include"), f"mcp_servers.{name}.tools.include")
exclude_set = _normalize_name_filter(tools_filter.get("exclude"), f"mcp_servers.{name}.tools.exclude")
def _should_register(tool_name: str) -> bool:
if include_set:
@@ -1378,7 +1444,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
)
registered_names.append(tool_name_prefixed)
# Register MCP Resources & Prompts utility tools
# Register MCP Resources & Prompts utility tools, filtered by config and
# only when the server actually supports the corresponding capability.
_handler_factories = {
"list_resources": _make_list_resources_handler,
"read_resource": _make_read_resource_handler,
@@ -1386,7 +1453,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
"get_prompt": _make_get_prompt_handler,
}
check_fn = _make_check_fn(name)
for entry in _build_utility_schemas(name):
for entry in _select_utility_schemas(name, server, config):
schema = entry["schema"]
handler_key = entry["handler_key"]
handler = _handler_factories[handler_key](name, server.tool_timeout)
@@ -1402,6 +1469,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
)
registered_names.append(schema["name"])
server._registered_tool_names = list(registered_names)
# Create a custom toolset so these tools are discoverable
if registered_names:
create_custom_toolset(
@@ -1448,8 +1517,9 @@ def discover_mcp_tools() -> List[str]:
# (enabled: false skips the server entirely without removing its config)
with _lock:
new_servers = {
k: v for k, v in servers.items()
if k not in _servers and v.get("enabled", True) is not False
k: v
for k, v in servers.items()
if k not in _servers and _parse_boolish(v.get("enabled", True), default=True)
}
if not new_servers:
@@ -1537,7 +1607,7 @@ def get_mcp_status() -> List[dict]:
entry = {
"name": name,
"transport": transport,
"tools": len(server._tools),
"tools": len(server._registered_tool_names) if hasattr(server, "_registered_tool_names") else len(server._tools),
"connected": True,
}
if server._sampling: