diff --git a/agent/builtin_memory_provider.py b/agent/builtin_memory_provider.py new file mode 100644 index 0000000000..0be94a1f5f --- /dev/null +++ b/agent/builtin_memory_provider.py @@ -0,0 +1,113 @@ +"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider. + +Always registered as the first provider. Cannot be disabled or removed. +This is the existing Hermes memory system exposed through the provider +interface for compatibility with the MemoryManager. + +The actual storage logic lives in tools/memory_tool.py (MemoryStore). +This provider is a thin adapter that delegates to MemoryStore and +exposes the memory tool schema. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Optional + +from agent.memory_provider import MemoryProvider + +logger = logging.getLogger(__name__) + + +class BuiltinMemoryProvider(MemoryProvider): + """Built-in file-backed memory (MEMORY.md + USER.md). + + Always active, never disabled by other providers. The `memory` tool + is handled by run_agent.py's agent-level tool interception (not through + the normal registry), so get_tool_schemas() returns an empty list — + the memory tool is already wired separately. + """ + + def __init__( + self, + memory_store=None, + memory_enabled: bool = False, + user_profile_enabled: bool = False, + ): + self._store = memory_store + self._memory_enabled = memory_enabled + self._user_profile_enabled = user_profile_enabled + + @property + def name(self) -> str: + return "builtin" + + def is_available(self) -> bool: + """Built-in memory is always available.""" + return True + + def initialize(self, session_id: str, **kwargs) -> None: + """Load memory from disk if not already loaded.""" + if self._store is not None: + self._store.load_from_disk() + + def system_prompt_block(self) -> str: + """Return MEMORY.md and USER.md content for the system prompt. + + Uses the frozen snapshot captured at load time. This ensures the + system prompt stays stable throughout a session (preserving the + prompt cache), even though the live entries may change via tool calls. + """ + if not self._store: + return "" + + parts = [] + if self._memory_enabled: + mem_block = self._store.format_for_system_prompt("memory") + if mem_block: + parts.append(mem_block) + if self._user_profile_enabled: + user_block = self._store.format_for_system_prompt("user") + if user_block: + parts.append(user_block) + + return "\n\n".join(parts) + + def prefetch(self, query: str) -> str: + """Built-in memory doesn't do query-based recall — it's injected via system_prompt_block.""" + return "" + + def sync_turn(self, user_content: str, assistant_content: str) -> None: + """Built-in memory doesn't auto-sync turns — writes happen via the memory tool.""" + + def get_tool_schemas(self) -> List[Dict[str, Any]]: + """Return empty list. + + The `memory` tool is an agent-level intercepted tool, handled + specially in run_agent.py before normal tool dispatch. It's not + part of the standard tool registry. We don't duplicate it here. + """ + return [] + + def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str: + """Not used — the memory tool is intercepted in run_agent.py.""" + return json.dumps({"error": "Built-in memory tool is handled by the agent loop"}) + + def shutdown(self) -> None: + """No cleanup needed — files are saved on every write.""" + + # -- Property access for backward compatibility -------------------------- + + @property + def store(self): + """Access the underlying MemoryStore for legacy code paths.""" + return self._store + + @property + def memory_enabled(self) -> bool: + return self._memory_enabled + + @property + def user_profile_enabled(self) -> bool: + return self._user_profile_enabled diff --git a/agent/memory_manager.py b/agent/memory_manager.py new file mode 100644 index 0000000000..af999c0271 --- /dev/null +++ b/agent/memory_manager.py @@ -0,0 +1,281 @@ +"""MemoryManager — orchestrates multiple memory providers. + +Single integration point in run_agent.py. Replaces scattered per-backend +code with one manager that delegates to all registered providers. + +The BuiltinMemoryProvider is always registered first and cannot be removed. +External providers are additive — they never disable the built-in store. + +Usage in run_agent.py: + self._memory_manager = MemoryManager() + self._memory_manager.add_provider(BuiltinMemoryProvider(...)) + if honcho_configured: + self._memory_manager.add_provider(HonchoProvider(...)) + # Plugin providers are added via register_memory_provider() + + # System prompt + prompt_parts.append(self._memory_manager.build_system_prompt()) + + # Pre-turn + context = self._memory_manager.prefetch_all(user_message) + + # Post-turn + self._memory_manager.sync_all(user_msg, assistant_response) + self._memory_manager.queue_prefetch_all(user_msg) +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Optional + +from agent.memory_provider import MemoryProvider + +logger = logging.getLogger(__name__) + + +class MemoryManager: + """Orchestrates multiple memory providers. + + Providers are called in registration order. The builtin provider + is always first. Failures in one provider never block others. + """ + + def __init__(self) -> None: + self._providers: List[MemoryProvider] = [] + self._tool_to_provider: Dict[str, MemoryProvider] = {} + + # -- Registration -------------------------------------------------------- + + def add_provider(self, provider: MemoryProvider) -> None: + """Register a memory provider. + + Providers are called in registration order for all operations. + Tool name conflicts are resolved first-registered-wins. + """ + self._providers.append(provider) + + # Index tool names → provider for routing + for schema in provider.get_tool_schemas(): + tool_name = schema.get("name", "") + if tool_name and tool_name not in self._tool_to_provider: + self._tool_to_provider[tool_name] = provider + elif tool_name in self._tool_to_provider: + logger.warning( + "Memory tool name conflict: '%s' already registered by %s, " + "ignoring from %s", + tool_name, + self._tool_to_provider[tool_name].name, + provider.name, + ) + + logger.info( + "Memory provider '%s' registered (%d tools)", + provider.name, + len(provider.get_tool_schemas()), + ) + + @property + def providers(self) -> List[MemoryProvider]: + """All registered providers in order.""" + return list(self._providers) + + @property + def provider_names(self) -> List[str]: + """Names of all registered providers.""" + return [p.name for p in self._providers] + + def get_provider(self, name: str) -> Optional[MemoryProvider]: + """Get a provider by name, or None if not registered.""" + for p in self._providers: + if p.name == name: + return p + return None + + # -- System prompt ------------------------------------------------------- + + def build_system_prompt(self) -> str: + """Collect system prompt blocks from all providers. + + Returns combined text, or empty string if no providers contribute. + Each non-empty block is labeled with the provider name. + """ + blocks = [] + for provider in self._providers: + try: + block = provider.system_prompt_block() + if block and block.strip(): + blocks.append(block) + except Exception as e: + logger.warning( + "Memory provider '%s' system_prompt_block() failed: %s", + provider.name, e, + ) + return "\n\n".join(blocks) + + # -- Prefetch / recall --------------------------------------------------- + + def prefetch_all(self, query: str) -> str: + """Collect prefetch context from all providers. + + Returns merged context text labeled by provider. Empty providers + are skipped. Failures in one provider don't block others. + """ + parts = [] + for provider in self._providers: + try: + result = provider.prefetch(query) + if result and result.strip(): + parts.append(result) + except Exception as e: + logger.debug( + "Memory provider '%s' prefetch failed (non-fatal): %s", + provider.name, e, + ) + return "\n\n".join(parts) + + def queue_prefetch_all(self, query: str) -> None: + """Queue background prefetch on all providers for the next turn.""" + for provider in self._providers: + try: + provider.queue_prefetch(query) + except Exception as e: + logger.debug( + "Memory provider '%s' queue_prefetch failed (non-fatal): %s", + provider.name, e, + ) + + # -- Sync ---------------------------------------------------------------- + + def sync_all(self, user_content: str, assistant_content: str) -> None: + """Sync a completed turn to all providers.""" + for provider in self._providers: + try: + provider.sync_turn(user_content, assistant_content) + except Exception as e: + logger.warning( + "Memory provider '%s' sync_turn failed: %s", + provider.name, e, + ) + + # -- Tools --------------------------------------------------------------- + + def get_all_tool_schemas(self) -> List[Dict[str, Any]]: + """Collect tool schemas from all providers.""" + schemas = [] + seen = set() + for provider in self._providers: + try: + for schema in provider.get_tool_schemas(): + name = schema.get("name", "") + if name and name not in seen: + schemas.append(schema) + seen.add(name) + except Exception as e: + logger.warning( + "Memory provider '%s' get_tool_schemas() failed: %s", + provider.name, e, + ) + return schemas + + def get_all_tool_names(self) -> set: + """Return set of all tool names across all providers.""" + return set(self._tool_to_provider.keys()) + + def has_tool(self, tool_name: str) -> bool: + """Check if any provider handles this tool.""" + return tool_name in self._tool_to_provider + + def handle_tool_call( + self, tool_name: str, args: Dict[str, Any], **kwargs + ) -> str: + """Route a tool call to the correct provider. + + Returns JSON string result. Raises ValueError if no provider + handles the tool. + """ + provider = self._tool_to_provider.get(tool_name) + if provider is None: + return json.dumps({"error": f"No memory provider handles tool '{tool_name}'"}) + try: + return provider.handle_tool_call(tool_name, args, **kwargs) + except Exception as e: + logger.error( + "Memory provider '%s' handle_tool_call(%s) failed: %s", + provider.name, tool_name, e, + ) + return json.dumps({"error": f"Memory tool '{tool_name}' failed: {e}"}) + + # -- Lifecycle hooks ----------------------------------------------------- + + def on_turn_start(self, turn_number: int, message: str) -> None: + """Notify all providers of a new turn.""" + for provider in self._providers: + try: + provider.on_turn_start(turn_number, message) + except Exception as e: + logger.debug( + "Memory provider '%s' on_turn_start failed: %s", + provider.name, e, + ) + + def on_session_end(self, messages: List[Dict[str, Any]]) -> None: + """Notify all providers of session end.""" + for provider in self._providers: + try: + provider.on_session_end(messages) + except Exception as e: + logger.debug( + "Memory provider '%s' on_session_end failed: %s", + provider.name, e, + ) + + def on_pre_compress(self, messages: List[Dict[str, Any]]) -> None: + """Notify all providers before context compression.""" + for provider in self._providers: + try: + provider.on_pre_compress(messages) + except Exception as e: + logger.debug( + "Memory provider '%s' on_pre_compress failed: %s", + provider.name, e, + ) + + def on_memory_write(self, action: str, target: str, content: str) -> None: + """Notify external providers when the built-in memory tool writes. + + Skips the builtin provider itself (it's the source of the write). + """ + for provider in self._providers: + if provider.name == "builtin": + continue + try: + provider.on_memory_write(action, target, content) + except Exception as e: + logger.debug( + "Memory provider '%s' on_memory_write failed: %s", + provider.name, e, + ) + + def shutdown_all(self) -> None: + """Shut down all providers (reverse order for clean teardown).""" + for provider in reversed(self._providers): + try: + provider.shutdown() + except Exception as e: + logger.warning( + "Memory provider '%s' shutdown failed: %s", + provider.name, e, + ) + + def initialize_all(self, session_id: str, **kwargs) -> None: + """Initialize all providers.""" + for provider in self._providers: + try: + provider.initialize(session_id=session_id, **kwargs) + except Exception as e: + logger.warning( + "Memory provider '%s' initialize failed: %s", + provider.name, e, + ) diff --git a/agent/memory_provider.py b/agent/memory_provider.py new file mode 100644 index 0000000000..2ee45431b1 --- /dev/null +++ b/agent/memory_provider.py @@ -0,0 +1,153 @@ +"""Abstract base class for pluggable memory providers. + +Memory providers give the agent persistent recall across sessions. Multiple +providers can be active simultaneously — the MemoryManager orchestrates them. + +Built-in memory (MEMORY.md / USER.md) is always active as the first provider. +External providers (Honcho, Hindsight, Mem0, etc.) are additive — they never +disable the built-in store. + +Three registration paths: + 1. Built-in: BuiltinMemoryProvider — always present, not removable. + 2. First-party: Ship with the repo, activated by config (e.g. Honcho). + 3. Plugin: External packages register via ctx.register_memory_provider(). + +Lifecycle (called by MemoryManager, wired in run_agent.py): + initialize() — connect, create resources, warm up + system_prompt_block() — static text for the system prompt + prefetch(query) — background recall before each turn + sync_turn(user, asst) — async write after each turn + get_tool_schemas() — tool schemas to expose to the model + handle_tool_call() — dispatch a tool call + shutdown() — clean exit + +Optional hooks (override to opt in): + on_turn_start(turn, message) — per-turn tick (scope cooling, etc.) + on_session_end(messages) — end-of-session extraction + on_pre_compress(messages) — extract before context compression + on_memory_write(action, target, content) — mirror built-in memory writes +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class MemoryProvider(ABC): + """Abstract base class for memory providers.""" + + @property + @abstractmethod + def name(self) -> str: + """Short identifier for this provider (e.g. 'builtin', 'honcho', 'hindsight').""" + + # -- Core lifecycle (implement these) ------------------------------------ + + @abstractmethod + def is_available(self) -> bool: + """Return True if this provider is configured, has credentials, and is ready. + + Called during agent init to decide whether to activate the provider. + Should not make network calls — just check config and installed deps. + """ + + @abstractmethod + def initialize(self, session_id: str, **kwargs) -> None: + """Initialize for a session. + + Called once at agent startup. May create resources (banks, tables), + establish connections, start background threads, etc. + + kwargs may include: platform, model, user_id, and other session context. + """ + + def system_prompt_block(self) -> str: + """Return text to include in the system prompt. + + Called during system prompt assembly. Return empty string to skip. + This is for STATIC provider info (instructions, status). Prefetched + recall context is injected separately via prefetch(). + """ + return "" + + def prefetch(self, query: str) -> str: + """Recall relevant context for the upcoming turn. + + Called before each API call. Return formatted text to inject as + context, or empty string if nothing relevant. Implementations + should be fast — use background threads for the actual recall + and return cached results here. + """ + return "" + + def queue_prefetch(self, query: str) -> None: + """Queue a background recall for the NEXT turn. + + Called after each turn completes. The result will be consumed + by prefetch() on the next turn. Default is no-op — providers + that do background prefetching should override this. + """ + + def sync_turn(self, user_content: str, assistant_content: str) -> None: + """Persist a completed turn to the backend. + + Called after each turn. Should be non-blocking — queue for + background processing if the backend has latency. + """ + + @abstractmethod + def get_tool_schemas(self) -> List[Dict[str, Any]]: + """Return tool schemas this provider exposes. + + Each schema follows the OpenAI function calling format: + {"name": "...", "description": "...", "parameters": {...}} + + Return empty list if this provider has no tools (context-only). + """ + + def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str: + """Handle a tool call for one of this provider's tools. + + Must return a JSON string (the tool result). + Only called for tool names returned by get_tool_schemas(). + """ + raise NotImplementedError(f"Provider {self.name} does not handle tool {tool_name}") + + def shutdown(self) -> None: + """Clean shutdown — flush queues, close connections.""" + + # -- Optional hooks (override to opt in) --------------------------------- + + def on_turn_start(self, turn_number: int, message: str) -> None: + """Called at the start of each turn with the user message. + + Use for turn-counting, scope management, periodic maintenance. + """ + + def on_session_end(self, messages: List[Dict[str, Any]]) -> None: + """Called when a session ends (explicit exit or timeout). + + Use for end-of-session fact extraction, summarization, etc. + messages is the full conversation history. + """ + + def on_pre_compress(self, messages: List[Dict[str, Any]]) -> None: + """Called before context compression discards old messages. + + Use to extract insights from messages about to be compressed. + messages is the list that will be summarized/discarded. + """ + + def on_memory_write(self, action: str, target: str, content: str) -> None: + """Called when the built-in memory tool writes an entry. + + action: 'add', 'replace', or 'remove' + target: 'memory' or 'user' + content: the entry content + + Use to mirror built-in memory writes to your backend. + """ diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index c72bc59e7c..b70a1d3abb 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -152,6 +152,28 @@ class PluginContext: self._manager._plugin_tool_names.add(name) logger.debug("Plugin %s registered tool: %s", self.manifest.name, name) + # -- memory provider registration ---------------------------------------- + + def register_memory_provider(self, provider) -> None: + """Register a memory provider (must implement MemoryProvider ABC). + + The provider will be added to the MemoryManager during agent init. + Providers registered this way are additive — they never disable + the built-in MEMORY.md/USER.md store. + + Example plugin __init__.py:: + + from my_memory_backend import MyMemoryProvider + + def register(ctx): + ctx.register_memory_provider(MyMemoryProvider()) + """ + self._manager._memory_providers.append(provider) + logger.debug( + "Plugin %s registered memory provider: %s", + self.manifest.name, getattr(provider, "name", "unknown"), + ) + # -- hook registration -------------------------------------------------- def register_hook(self, hook_name: str, callback: Callable) -> None: @@ -183,6 +205,7 @@ class PluginManager: self._plugins: Dict[str, LoadedPlugin] = {} self._hooks: Dict[str, List[Callable]] = {} self._plugin_tool_names: Set[str] = set() + self._memory_providers: List = [] # MemoryProvider instances from plugins self._discovered: bool = False # ----------------------------------------------------------------------- @@ -528,3 +551,13 @@ def get_plugin_toolsets() -> List[tuple]: result.append((ts_key, label, desc)) return result + + +def get_plugin_memory_providers() -> List: + """Return MemoryProvider instances registered by plugins. + + Called during AIAgent init to add plugin memory providers to + the MemoryManager alongside built-in providers. + """ + manager = get_plugin_manager() + return list(manager._memory_providers) diff --git a/run_agent.py b/run_agent.py index 026e22c454..ef89b14a2c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1043,7 +1043,20 @@ class AIAgent: self._memory_store.load_from_disk() except Exception: pass # Memory is optional -- don't break agent init - + + # Memory provider manager — orchestrates built-in + plugin providers. + # Existing Honcho code stays as-is (will migrate in a follow-up PR). + # The manager provides the extension point for plugin memory backends. + from agent.memory_manager import MemoryManager + from agent.builtin_memory_provider import BuiltinMemoryProvider + self._memory_manager = MemoryManager() + self._memory_manager.add_provider(BuiltinMemoryProvider( + memory_store=self._memory_store, + memory_enabled=self._memory_enabled, + user_profile_enabled=self._user_profile_enabled, + )) + # Plugin memory providers are added after Honcho init (below). + # Honcho AI-native memory (cross-session user modeling) # Reads $HERMES_HOME/honcho.json (instance) or ~/.honcho/config.json (global). self._honcho = None # HonchoSessionManager | None @@ -1114,6 +1127,28 @@ class AIAgent: self._user_profile_enabled = False logger.debug("peer %s memory_mode=honcho: local USER.md writes disabled", _hcfg.peer_name or "user") + # Register plugin memory providers with the manager. + # Plugins call ctx.register_memory_provider() during discover_and_load(). + if not skip_memory: + try: + from hermes_cli.plugins import get_plugin_memory_providers + for plugin_provider in get_plugin_memory_providers(): + try: + if plugin_provider.is_available(): + self._memory_manager.add_provider(plugin_provider) + plugin_provider.initialize( + session_id=self.session_id or "", + platform=self.platform, + model=self.model, + ) + except Exception as e: + logger.warning( + "Plugin memory provider '%s' init failed: %s", + getattr(plugin_provider, "name", "unknown"), e, + ) + except Exception as e: + logger.debug("Plugin memory provider loading skipped: %s", e) + # Skills config: nudge interval for skill creation reminders self._skill_nudge_interval = 10 try: @@ -2659,6 +2694,19 @@ class AIAgent: if user_block: prompt_parts.append(user_block) + # Plugin memory providers contribute system prompt blocks. + # (Builtin provider's block is already handled above via _memory_store.) + if hasattr(self, "_memory_manager"): + for provider in self._memory_manager.providers: + if provider.name == "builtin": + continue # already handled above + try: + block = provider.system_prompt_block() + if block and block.strip(): + prompt_parts.append(block) + except Exception as e: + logger.debug("Memory provider '%s' prompt block failed: %s", provider.name, e) + has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage']) if has_skills_tools: avail_toolsets = { @@ -6099,6 +6147,23 @@ class AIAgent: except Exception as e: logger.debug("Honcho prefetch failed (non-fatal): %s", e) + # Plugin memory provider prefetch (non-builtin providers only). + self._plugin_memory_context = "" + self._plugin_memory_turn_context = "" + if hasattr(self, "_memory_manager"): + for provider in self._memory_manager.providers: + if provider.name == "builtin": + continue + try: + ctx = provider.prefetch(original_user_message) + if ctx and ctx.strip(): + if not conversation_history: + self._plugin_memory_context += ("\n\n" + ctx if self._plugin_memory_context else ctx) + else: + self._plugin_memory_turn_context += ("\n\n" + ctx if self._plugin_memory_turn_context else ctx) + except Exception as e: + logger.debug("Memory provider '%s' prefetch failed: %s", provider.name, e) + # Add user message user_msg = {"role": "user", "content": user_message} messages.append(user_msg) @@ -6142,6 +6207,11 @@ class AIAgent: self._cached_system_prompt = ( self._cached_system_prompt + "\n\n" + self._honcho_context ).strip() + # Plugin memory provider context (first turn → bake into system prompt) + if self._plugin_memory_context: + self._cached_system_prompt = ( + self._cached_system_prompt + "\n\n" + self._plugin_memory_context + ).strip() # Plugin hook: on_session_start # Fired once when a brand-new session is created (not on @@ -6312,6 +6382,15 @@ class AIAgent: api_msg.get("content", ""), self._honcho_turn_context ) + # Plugin memory provider turn context injection + if idx == current_turn_user_idx and msg.get("role") == "user" and self._plugin_memory_turn_context: + existing = api_msg.get("content", "") + api_msg["content"] = ( + existing + + "\n\n[Plugin memory context — relevant memories for this turn.]\n\n" + + self._plugin_memory_turn_context + ) + # For ALL assistant messages, pass reasoning back to the API # This ensures multi-turn reasoning context is preserved if msg.get("role") == "assistant": @@ -7949,6 +8028,17 @@ class AIAgent: self._honcho_sync(original_user_message, final_response) self._queue_honcho_prefetch(original_user_message) + # Sync to plugin memory providers and queue prefetch for next turn + if final_response and not interrupted and hasattr(self, "_memory_manager"): + for provider in self._memory_manager.providers: + if provider.name == "builtin": + continue + try: + provider.sync_turn(original_user_message, final_response) + provider.queue_prefetch(original_user_message) + except Exception as e: + logger.debug("Memory provider '%s' post-turn failed: %s", provider.name, e) + # Plugin hook: post_llm_call # Fired once per turn after the tool-calling loop completes. # Plugins can use this to persist conversation data (e.g. sync diff --git a/tests/agent/test_memory_plugin_e2e.py b/tests/agent/test_memory_plugin_e2e.py new file mode 100644 index 0000000000..ba06d40c67 --- /dev/null +++ b/tests/agent/test_memory_plugin_e2e.py @@ -0,0 +1,341 @@ +"""End-to-end test: a SQLite-backed memory plugin exercising the full interface. + +This proves a real plugin can register as a MemoryProvider and get wired +into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no +external deps, no API keys). +""" + +import json +import os +import sqlite3 +import tempfile +import pytest +from unittest.mock import patch, MagicMock + +from agent.memory_provider import MemoryProvider +from agent.memory_manager import MemoryManager +from agent.builtin_memory_provider import BuiltinMemoryProvider + + +# --------------------------------------------------------------------------- +# SQLite FTS5 memory provider — a real, minimal plugin implementation +# --------------------------------------------------------------------------- + + +class SQLiteMemoryProvider(MemoryProvider): + """Minimal SQLite + FTS5 memory provider for testing. + + Demonstrates the full MemoryProvider interface with a real backend. + No external dependencies — just stdlib sqlite3. + """ + + def __init__(self, db_path: str = ":memory:"): + self._db_path = db_path + self._conn = None + + @property + def name(self) -> str: + return "sqlite_memory" + + def is_available(self) -> bool: + return True # SQLite is always available + + def initialize(self, session_id: str, **kwargs) -> None: + self._conn = sqlite3.connect(self._db_path) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS memories + USING fts5(content, context, session_id) + """) + self._session_id = session_id + + def system_prompt_block(self) -> str: + if not self._conn: + return "" + count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + if count == 0: + return "" + return ( + f"# SQLite Memory Plugin\n" + f"Active. {count} memories stored.\n" + f"Use sqlite_recall to search, sqlite_retain to store." + ) + + def prefetch(self, query: str) -> str: + if not self._conn or not query: + return "" + # FTS5 search + try: + rows = self._conn.execute( + "SELECT content FROM memories WHERE memories MATCH ? LIMIT 5", + (query,) + ).fetchall() + if not rows: + return "" + results = [row[0] for row in rows] + return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results) + except sqlite3.OperationalError: + return "" + + def sync_turn(self, user_content: str, assistant_content: str) -> None: + if not self._conn: + return + combined = f"User: {user_content}\nAssistant: {assistant_content}" + self._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + (combined, "conversation", self._session_id), + ) + self._conn.commit() + + def get_tool_schemas(self): + return [ + { + "name": "sqlite_retain", + "description": "Store a fact to SQLite memory.", + "parameters": { + "type": "object", + "properties": { + "content": {"type": "string", "description": "What to remember"}, + "context": {"type": "string", "description": "Category/context"}, + }, + "required": ["content"], + }, + }, + { + "name": "sqlite_recall", + "description": "Search SQLite memory.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + }, + ] + + def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str: + if tool_name == "sqlite_retain": + content = args.get("content", "") + context = args.get("context", "explicit") + if not content: + return json.dumps({"error": "content is required"}) + self._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + (content, context, self._session_id), + ) + self._conn.commit() + return json.dumps({"result": "Stored."}) + + elif tool_name == "sqlite_recall": + query = args.get("query", "") + if not query: + return json.dumps({"error": "query is required"}) + try: + rows = self._conn.execute( + "SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10", + (query,) + ).fetchall() + results = [{"content": r[0], "context": r[1]} for r in rows] + return json.dumps({"results": results}) + except sqlite3.OperationalError: + return json.dumps({"results": []}) + + return json.dumps({"error": f"Unknown tool: {tool_name}"}) + + def on_memory_write(self, action, target, content): + """Mirror built-in memory writes to SQLite.""" + if action == "add" and self._conn: + self._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + (content, f"builtin_{target}", self._session_id), + ) + self._conn.commit() + + def shutdown(self): + if self._conn: + self._conn.close() + self._conn = None + + +# --------------------------------------------------------------------------- +# End-to-end tests +# --------------------------------------------------------------------------- + + +class TestSQLiteMemoryPlugin: + """Full lifecycle test with the SQLite provider.""" + + def test_full_lifecycle(self): + """Exercise init → store → recall → sync → prefetch → shutdown.""" + mgr = MemoryManager() + builtin = BuiltinMemoryProvider() + sqlite_mem = SQLiteMemoryProvider() + + mgr.add_provider(builtin) + mgr.add_provider(sqlite_mem) + + # Initialize + mgr.initialize_all(session_id="test-session-1", platform="cli") + assert sqlite_mem._conn is not None + + # System prompt — empty at first + prompt = mgr.build_system_prompt() + assert "SQLite Memory Plugin" not in prompt + + # Store via tool call + result = json.loads(mgr.handle_tool_call( + "sqlite_retain", {"content": "User prefers dark mode", "context": "preference"} + )) + assert result["result"] == "Stored." + + # System prompt now shows count + prompt = mgr.build_system_prompt() + assert "1 memories stored" in prompt + + # Recall via tool call + result = json.loads(mgr.handle_tool_call( + "sqlite_recall", {"query": "dark mode"} + )) + assert len(result["results"]) == 1 + assert "dark mode" in result["results"][0]["content"] + + # Sync a turn (auto-stores conversation) + mgr.sync_all("What's my theme?", "You prefer dark mode.") + count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + assert count == 2 # 1 explicit + 1 synced + + # Prefetch for next turn + prefetched = mgr.prefetch_all("dark mode") + assert "dark mode" in prefetched + + # Memory bridge — mirroring builtin writes + mgr.on_memory_write("add", "user", "Timezone: US Pacific") + count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + assert count == 3 + + # Shutdown + mgr.shutdown_all() + assert sqlite_mem._conn is None + + def test_tool_routing_with_builtin(self): + """Verify builtin + plugin tools coexist without conflict.""" + mgr = MemoryManager() + builtin = BuiltinMemoryProvider() + sqlite_mem = SQLiteMemoryProvider() + mgr.add_provider(builtin) + mgr.add_provider(sqlite_mem) + mgr.initialize_all(session_id="test-2") + + # Builtin has no tools + assert len(builtin.get_tool_schemas()) == 0 + # SQLite has 2 tools + schemas = mgr.get_all_tool_schemas() + names = {s["name"] for s in schemas} + assert names == {"sqlite_retain", "sqlite_recall"} + + # Routing works + assert mgr.has_tool("sqlite_retain") + assert mgr.has_tool("sqlite_recall") + assert not mgr.has_tool("memory") # builtin doesn't register this + + def test_multiple_plugins_coexist(self): + """Two plugin providers can run simultaneously.""" + mgr = MemoryManager() + p1 = SQLiteMemoryProvider() + p2 = SQLiteMemoryProvider() + # Hack name for p2 + p2._name_override = "sqlite_memory_2" + original_name = p2.__class__.name + type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory')) + + mgr.add_provider(p1) + mgr.add_provider(p2) + mgr.initialize_all(session_id="test-3") + + # Store in p1 + p1._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + ("fact from p1", "test", "test-3"), + ) + p1._conn.commit() + + # Store in p2 + p2._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + ("fact from p2", "test", "test-3"), + ) + p2._conn.commit() + + # Prefetch merges both + result = mgr.prefetch_all("fact") + assert "fact from p1" in result + assert "fact from p2" in result + + # Sync goes to both + mgr.sync_all("user msg", "assistant msg") + count1 = p1._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + count2 = p2._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + assert count1 == 2 # 1 explicit + 1 synced + assert count2 == 2 + + # Restore class + type(p2).name = original_name + mgr.shutdown_all() + + def test_provider_failure_isolation(self): + """Failing provider doesn't break others.""" + mgr = MemoryManager() + good = SQLiteMemoryProvider() + bad = SQLiteMemoryProvider() + + mgr.add_provider(good) + mgr.add_provider(bad) + mgr.initialize_all(session_id="test-4") + + # Break bad provider's connection + bad._conn.close() + bad._conn = None + + # Good provider still works + good._conn.execute( + "INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)", + ("still works", "test", "test-4"), + ) + good._conn.commit() + + # Sync — bad fails silently, good succeeds + mgr.sync_all("user", "assistant") + count = good._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0] + assert count == 2 + + mgr.shutdown_all() + + def test_plugin_registration_flow(self): + """Simulate the full plugin registration → agent init path.""" + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest + + # Simulate plugin discovery + manager = PluginManager() + manifest = PluginManifest(name="sqlite-memory", source="test") + ctx = PluginContext(manifest, manager) + + # Plugin registers its provider + provider = SQLiteMemoryProvider() + ctx.register_memory_provider(provider) + + assert len(manager._memory_providers) == 1 + + # Simulate what AIAgent.__init__ does + mem_mgr = MemoryManager() + mem_mgr.add_provider(BuiltinMemoryProvider()) + for plugin_provider in manager._memory_providers: + if plugin_provider.is_available(): + mem_mgr.add_provider(plugin_provider) + plugin_provider.initialize(session_id="agent-session") + + assert len(mem_mgr.providers) == 2 + assert mem_mgr.provider_names == ["builtin", "sqlite_memory"] + assert provider._conn is not None # initialized = connection established + + mem_mgr.shutdown_all() diff --git a/tests/agent/test_memory_provider.py b/tests/agent/test_memory_provider.py new file mode 100644 index 0000000000..22942043bc --- /dev/null +++ b/tests/agent/test_memory_provider.py @@ -0,0 +1,478 @@ +"""Tests for the memory provider interface, manager, and builtin provider.""" + +import json +import pytest +from unittest.mock import MagicMock, patch + +from agent.memory_provider import MemoryProvider +from agent.memory_manager import MemoryManager +from agent.builtin_memory_provider import BuiltinMemoryProvider + + +# --------------------------------------------------------------------------- +# Concrete test provider +# --------------------------------------------------------------------------- + + +class FakeMemoryProvider(MemoryProvider): + """Minimal concrete provider for testing.""" + + def __init__(self, name="fake", available=True, tools=None): + self._name = name + self._available = available + self._tools = tools or [] + self.initialized = False + self.synced_turns = [] + self.prefetch_queries = [] + self.queued_prefetches = [] + self.turn_starts = [] + self.session_end_called = False + self.pre_compress_called = False + self.memory_writes = [] + self.shutdown_called = False + self._prefetch_result = "" + self._prompt_block = "" + + @property + def name(self) -> str: + return self._name + + def is_available(self) -> bool: + return self._available + + def initialize(self, session_id, **kwargs): + self.initialized = True + self._init_kwargs = {"session_id": session_id, **kwargs} + + def system_prompt_block(self) -> str: + return self._prompt_block + + def prefetch(self, query): + self.prefetch_queries.append(query) + return self._prefetch_result + + def queue_prefetch(self, query): + self.queued_prefetches.append(query) + + def sync_turn(self, user_content, assistant_content): + self.synced_turns.append((user_content, assistant_content)) + + def get_tool_schemas(self): + return self._tools + + def handle_tool_call(self, tool_name, args, **kwargs): + return json.dumps({"handled": tool_name, "args": args}) + + def shutdown(self): + self.shutdown_called = True + + def on_turn_start(self, turn_number, message): + self.turn_starts.append((turn_number, message)) + + def on_session_end(self, messages): + self.session_end_called = True + + def on_pre_compress(self, messages): + self.pre_compress_called = True + + def on_memory_write(self, action, target, content): + self.memory_writes.append((action, target, content)) + + +# --------------------------------------------------------------------------- +# MemoryProvider ABC tests +# --------------------------------------------------------------------------- + + +class TestMemoryProviderABC: + def test_cannot_instantiate_abstract(self): + """ABC cannot be instantiated directly.""" + with pytest.raises(TypeError): + MemoryProvider() + + def test_concrete_provider_works(self): + """Concrete implementation can be instantiated.""" + p = FakeMemoryProvider() + assert p.name == "fake" + assert p.is_available() + + def test_default_optional_hooks_are_noop(self): + """Optional hooks have default no-op implementations.""" + p = FakeMemoryProvider() + # These should not raise + p.on_turn_start(1, "hello") + p.on_session_end([]) + p.on_pre_compress([]) + p.on_memory_write("add", "memory", "test") + p.queue_prefetch("query") + p.sync_turn("user", "assistant") + p.shutdown() + + +# --------------------------------------------------------------------------- +# MemoryManager tests +# --------------------------------------------------------------------------- + + +class TestMemoryManager: + def test_empty_manager(self): + mgr = MemoryManager() + assert mgr.providers == [] + assert mgr.provider_names == [] + assert mgr.get_all_tool_schemas() == [] + assert mgr.build_system_prompt() == "" + assert mgr.prefetch_all("test") == "" + + def test_add_provider(self): + mgr = MemoryManager() + p = FakeMemoryProvider("test1") + mgr.add_provider(p) + assert len(mgr.providers) == 1 + assert mgr.provider_names == ["test1"] + + def test_get_provider_by_name(self): + mgr = MemoryManager() + p = FakeMemoryProvider("test1") + mgr.add_provider(p) + assert mgr.get_provider("test1") is p + assert mgr.get_provider("nonexistent") is None + + def test_multiple_providers(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p2 = FakeMemoryProvider("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + assert mgr.provider_names == ["p1", "p2"] + + def test_system_prompt_merges_blocks(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1._prompt_block = "Block from p1" + p2 = FakeMemoryProvider("p2") + p2._prompt_block = "Block from p2" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.build_system_prompt() + assert "Block from p1" in result + assert "Block from p2" in result + + def test_system_prompt_skips_empty(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1._prompt_block = "Has content" + p2 = FakeMemoryProvider("p2") + p2._prompt_block = "" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.build_system_prompt() + assert result == "Has content" + + def test_prefetch_merges_results(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1._prefetch_result = "Memory from p1" + p2 = FakeMemoryProvider("p2") + p2._prefetch_result = "Memory from p2" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.prefetch_all("what do you know?") + assert "Memory from p1" in result + assert "Memory from p2" in result + assert p1.prefetch_queries == ["what do you know?"] + assert p2.prefetch_queries == ["what do you know?"] + + def test_prefetch_skips_empty(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1._prefetch_result = "Has memories" + p2 = FakeMemoryProvider("p2") + p2._prefetch_result = "" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.prefetch_all("query") + assert result == "Has memories" + + def test_queue_prefetch_all(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p2 = FakeMemoryProvider("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + + mgr.queue_prefetch_all("next turn") + assert p1.queued_prefetches == ["next turn"] + assert p2.queued_prefetches == ["next turn"] + + def test_sync_all(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p2 = FakeMemoryProvider("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + + mgr.sync_all("user msg", "assistant msg") + assert p1.synced_turns == [("user msg", "assistant msg")] + assert p2.synced_turns == [("user msg", "assistant msg")] + + def test_sync_failure_doesnt_block_others(self): + """If one provider's sync fails, others still run.""" + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1.sync_turn = MagicMock(side_effect=RuntimeError("boom")) + p2 = FakeMemoryProvider("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + + mgr.sync_all("user", "assistant") + # p1 failed but p2 still synced + assert p2.synced_turns == [("user", "assistant")] + + # -- Tool routing ------------------------------------------------------- + + def test_tool_schemas_collected(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1", tools=[ + {"name": "recall_p1", "description": "P1 recall", "parameters": {}} + ]) + p2 = FakeMemoryProvider("p2", tools=[ + {"name": "recall_p2", "description": "P2 recall", "parameters": {}} + ]) + mgr.add_provider(p1) + mgr.add_provider(p2) + + schemas = mgr.get_all_tool_schemas() + names = {s["name"] for s in schemas} + assert names == {"recall_p1", "recall_p2"} + + def test_tool_name_conflict_first_wins(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1", tools=[ + {"name": "shared_tool", "description": "From P1", "parameters": {}} + ]) + p2 = FakeMemoryProvider("p2", tools=[ + {"name": "shared_tool", "description": "From P2", "parameters": {}} + ]) + mgr.add_provider(p1) + mgr.add_provider(p2) + + assert mgr.has_tool("shared_tool") + result = json.loads(mgr.handle_tool_call("shared_tool", {"q": "test"})) + assert result["handled"] == "shared_tool" + # Should be handled by p1 (first registered) + + def test_handle_unknown_tool(self): + mgr = MemoryManager() + result = json.loads(mgr.handle_tool_call("nonexistent", {})) + assert "error" in result + + def test_tool_routing(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1", tools=[ + {"name": "p1_tool", "description": "P1", "parameters": {}} + ]) + p2 = FakeMemoryProvider("p2", tools=[ + {"name": "p2_tool", "description": "P2", "parameters": {}} + ]) + mgr.add_provider(p1) + mgr.add_provider(p2) + + r1 = json.loads(mgr.handle_tool_call("p1_tool", {"a": 1})) + assert r1["handled"] == "p1_tool" + r2 = json.loads(mgr.handle_tool_call("p2_tool", {"b": 2})) + assert r2["handled"] == "p2_tool" + + # -- Lifecycle hooks ----------------------------------------------------- + + def test_on_turn_start(self): + mgr = MemoryManager() + p = FakeMemoryProvider("p") + mgr.add_provider(p) + mgr.on_turn_start(3, "hello") + assert p.turn_starts == [(3, "hello")] + + def test_on_session_end(self): + mgr = MemoryManager() + p = FakeMemoryProvider("p") + mgr.add_provider(p) + mgr.on_session_end([{"role": "user", "content": "hi"}]) + assert p.session_end_called + + def test_on_pre_compress(self): + mgr = MemoryManager() + p = FakeMemoryProvider("p") + mgr.add_provider(p) + mgr.on_pre_compress([{"role": "user", "content": "old"}]) + assert p.pre_compress_called + + def test_on_memory_write_skips_builtin(self): + """on_memory_write should skip the builtin provider.""" + mgr = MemoryManager() + builtin = BuiltinMemoryProvider() + external = FakeMemoryProvider("external") + mgr.add_provider(builtin) + mgr.add_provider(external) + + mgr.on_memory_write("add", "memory", "test fact") + assert external.memory_writes == [("add", "memory", "test fact")] + + def test_shutdown_all_reverse_order(self): + mgr = MemoryManager() + order = [] + p1 = FakeMemoryProvider("p1") + p1.shutdown = lambda: order.append("p1") + p2 = FakeMemoryProvider("p2") + p2.shutdown = lambda: order.append("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + + mgr.shutdown_all() + assert order == ["p2", "p1"] # reverse order + + def test_initialize_all(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p2 = FakeMemoryProvider("p2") + mgr.add_provider(p1) + mgr.add_provider(p2) + + mgr.initialize_all(session_id="test-123", platform="cli") + assert p1.initialized + assert p2.initialized + assert p1._init_kwargs["session_id"] == "test-123" + assert p1._init_kwargs["platform"] == "cli" + + # -- Error resilience --------------------------------------------------- + + def test_prefetch_failure_doesnt_block(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1.prefetch = MagicMock(side_effect=RuntimeError("network error")) + p2 = FakeMemoryProvider("p2") + p2._prefetch_result = "p2 memory" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.prefetch_all("query") + assert "p2 memory" in result + + def test_system_prompt_failure_doesnt_block(self): + mgr = MemoryManager() + p1 = FakeMemoryProvider("p1") + p1.system_prompt_block = MagicMock(side_effect=RuntimeError("broken")) + p2 = FakeMemoryProvider("p2") + p2._prompt_block = "works fine" + mgr.add_provider(p1) + mgr.add_provider(p2) + + result = mgr.build_system_prompt() + assert result == "works fine" + + +# --------------------------------------------------------------------------- +# BuiltinMemoryProvider tests +# --------------------------------------------------------------------------- + + +class TestBuiltinMemoryProvider: + def test_name(self): + p = BuiltinMemoryProvider() + assert p.name == "builtin" + + def test_always_available(self): + p = BuiltinMemoryProvider() + assert p.is_available() + + def test_no_tools(self): + """Builtin provider exposes no tools (memory tool is agent-level).""" + p = BuiltinMemoryProvider() + assert p.get_tool_schemas() == [] + + def test_system_prompt_with_store(self): + store = MagicMock() + store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}" if t == "memory" else f"BLOCK_{t}" + + p = BuiltinMemoryProvider( + memory_store=store, + memory_enabled=True, + user_profile_enabled=True, + ) + block = p.system_prompt_block() + assert "BLOCK_memory" in block + assert "BLOCK_user" in block + + def test_system_prompt_memory_disabled(self): + store = MagicMock() + store.format_for_system_prompt.return_value = "content" + + p = BuiltinMemoryProvider( + memory_store=store, + memory_enabled=False, + user_profile_enabled=False, + ) + assert p.system_prompt_block() == "" + + def test_system_prompt_no_store(self): + p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True) + assert p.system_prompt_block() == "" + + def test_prefetch_returns_empty(self): + p = BuiltinMemoryProvider() + assert p.prefetch("anything") == "" + + def test_store_property(self): + store = MagicMock() + p = BuiltinMemoryProvider(memory_store=store) + assert p.store is store + + def test_initialize_loads_from_disk(self): + store = MagicMock() + p = BuiltinMemoryProvider(memory_store=store) + p.initialize(session_id="test") + store.load_from_disk.assert_called_once() + + +# --------------------------------------------------------------------------- +# Plugin registration tests +# --------------------------------------------------------------------------- + + +class TestPluginMemoryProviderRegistration: + def test_register_memory_provider(self): + """PluginContext.register_memory_provider adds to manager list.""" + from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest + + manager = PluginManager() + manifest = PluginManifest( + name="test-plugin", + version="1.0.0", + description="Test", + source="test", + ) + ctx = PluginContext(manifest, manager) + + fake_provider = FakeMemoryProvider("test-mem") + ctx.register_memory_provider(fake_provider) + + assert len(manager._memory_providers) == 1 + assert manager._memory_providers[0] is fake_provider + + def test_get_plugin_memory_providers(self): + """get_plugin_memory_providers returns registered providers.""" + from hermes_cli.plugins import PluginManager, get_plugin_memory_providers + + with patch("hermes_cli.plugins.get_plugin_manager") as mock_get: + mgr = PluginManager() + p1 = FakeMemoryProvider("p1") + p2 = FakeMemoryProvider("p2") + mgr._memory_providers = [p1, p2] + mock_get.return_value = mgr + + result = get_plugin_memory_providers() + assert len(result) == 2 + assert result[0] is p1 + assert result[1] is p2