mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 15:31:38 +08:00
Compare commits
7 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
229be62098 | ||
|
|
525a859b8f | ||
|
|
48364a011f | ||
|
|
521a1df587 | ||
|
|
44b7df4090 | ||
|
|
1452c81941 | ||
|
|
13871e9a8e |
113
agent/builtin_memory_provider.py
Normal file
113
agent/builtin_memory_provider.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider.
|
||||
|
||||
Always registered as the first provider. Cannot be disabled or removed.
|
||||
This is the existing Hermes memory system exposed through the provider
|
||||
interface for compatibility with the MemoryManager.
|
||||
|
||||
The actual storage logic lives in tools/memory_tool.py (MemoryStore).
|
||||
This provider is a thin adapter that delegates to MemoryStore and
|
||||
exposes the memory tool schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinMemoryProvider(MemoryProvider):
    """File-backed built-in memory (MEMORY.md + USER.md).

    Always active; external providers never disable it. The `memory` tool
    is intercepted at the agent level in run_agent.py rather than routed
    through the normal tool registry, which is why get_tool_schemas()
    reports no tools — the wiring already exists elsewhere.
    """

    def __init__(
        self,
        memory_store=None,
        memory_enabled: bool = False,
        user_profile_enabled: bool = False,
    ):
        # Thin adapter state: the real storage logic lives in MemoryStore.
        self._store = memory_store
        self._memory_enabled = memory_enabled
        self._user_profile_enabled = user_profile_enabled

    @property
    def name(self) -> str:
        return "builtin"

    def is_available(self) -> bool:
        """The built-in provider is unconditionally available."""
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        """Load persisted memory from disk (no-op when no store is attached)."""
        store = self._store
        if store is not None:
            store.load_from_disk()

    def system_prompt_block(self) -> str:
        """Return MEMORY.md and USER.md content for the system prompt.

        Delegates to the store's frozen load-time snapshot so the system
        prompt stays stable for the whole session (preserving the prompt
        cache) even while live entries change through tool calls.
        """
        if not self._store:
            return ""

        blocks = []
        # Order matters: memory first, then the user profile.
        for enabled, target in (
            (self._memory_enabled, "memory"),
            (self._user_profile_enabled, "user"),
        ):
            if not enabled:
                continue
            text = self._store.format_for_system_prompt(target)
            if text:
                blocks.append(text)

        return "\n\n".join(blocks)

    def prefetch(self, query: str) -> str:
        """No query-based recall — content is injected via system_prompt_block()."""
        return ""

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        """No per-turn auto-sync — all writes go through the memory tool."""

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return an empty list.

        The `memory` tool is handled by agent-level interception in
        run_agent.py before normal tool dispatch; duplicating its schema
        here would register it twice.
        """
        return []

    def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
        """Never dispatched here — the memory tool is intercepted in run_agent.py."""
        return json.dumps({"error": "Built-in memory tool is handled by the agent loop"})

    def shutdown(self) -> None:
        """Nothing to clean up — every write is flushed to disk immediately."""

    # -- Property access for backward compatibility --------------------------

    @property
    def store(self):
        """Underlying MemoryStore, exposed for legacy code paths."""
        return self._store

    @property
    def memory_enabled(self) -> bool:
        return self._memory_enabled

    @property
    def user_profile_enabled(self) -> bool:
        return self._user_profile_enabled
|
||||
281
agent/memory_manager.py
Normal file
281
agent/memory_manager.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""MemoryManager — orchestrates multiple memory providers.
|
||||
|
||||
Single integration point in run_agent.py. Replaces scattered per-backend
|
||||
code with one manager that delegates to all registered providers.
|
||||
|
||||
The BuiltinMemoryProvider is always registered first and cannot be removed.
|
||||
External providers are additive — they never disable the built-in store.
|
||||
|
||||
Usage in run_agent.py:
|
||||
self._memory_manager = MemoryManager()
|
||||
self._memory_manager.add_provider(BuiltinMemoryProvider(...))
|
||||
if honcho_configured:
|
||||
self._memory_manager.add_provider(HonchoProvider(...))
|
||||
# Plugin providers are added via register_memory_provider()
|
||||
|
||||
# System prompt
|
||||
prompt_parts.append(self._memory_manager.build_system_prompt())
|
||||
|
||||
# Pre-turn
|
||||
context = self._memory_manager.prefetch_all(user_message)
|
||||
|
||||
# Post-turn
|
||||
self._memory_manager.sync_all(user_msg, assistant_response)
|
||||
self._memory_manager.queue_prefetch_all(user_msg)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""Orchestrates multiple memory providers.
|
||||
|
||||
Providers are called in registration order. The builtin provider
|
||||
is always first. Failures in one provider never block others.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._providers: List[MemoryProvider] = []
|
||||
self._tool_to_provider: Dict[str, MemoryProvider] = {}
|
||||
|
||||
# -- Registration --------------------------------------------------------
|
||||
|
||||
def add_provider(self, provider: MemoryProvider) -> None:
|
||||
"""Register a memory provider.
|
||||
|
||||
Providers are called in registration order for all operations.
|
||||
Tool name conflicts are resolved first-registered-wins.
|
||||
"""
|
||||
self._providers.append(provider)
|
||||
|
||||
# Index tool names → provider for routing
|
||||
for schema in provider.get_tool_schemas():
|
||||
tool_name = schema.get("name", "")
|
||||
if tool_name and tool_name not in self._tool_to_provider:
|
||||
self._tool_to_provider[tool_name] = provider
|
||||
elif tool_name in self._tool_to_provider:
|
||||
logger.warning(
|
||||
"Memory tool name conflict: '%s' already registered by %s, "
|
||||
"ignoring from %s",
|
||||
tool_name,
|
||||
self._tool_to_provider[tool_name].name,
|
||||
provider.name,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Memory provider '%s' registered (%d tools)",
|
||||
provider.name,
|
||||
len(provider.get_tool_schemas()),
|
||||
)
|
||||
|
||||
@property
|
||||
def providers(self) -> List[MemoryProvider]:
|
||||
"""All registered providers in order."""
|
||||
return list(self._providers)
|
||||
|
||||
@property
|
||||
def provider_names(self) -> List[str]:
|
||||
"""Names of all registered providers."""
|
||||
return [p.name for p in self._providers]
|
||||
|
||||
def get_provider(self, name: str) -> Optional[MemoryProvider]:
|
||||
"""Get a provider by name, or None if not registered."""
|
||||
for p in self._providers:
|
||||
if p.name == name:
|
||||
return p
|
||||
return None
|
||||
|
||||
# -- System prompt -------------------------------------------------------
|
||||
|
||||
def build_system_prompt(self) -> str:
|
||||
"""Collect system prompt blocks from all providers.
|
||||
|
||||
Returns combined text, or empty string if no providers contribute.
|
||||
Each non-empty block is labeled with the provider name.
|
||||
"""
|
||||
blocks = []
|
||||
for provider in self._providers:
|
||||
try:
|
||||
block = provider.system_prompt_block()
|
||||
if block and block.strip():
|
||||
blocks.append(block)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Memory provider '%s' system_prompt_block() failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
return "\n\n".join(blocks)
|
||||
|
||||
# -- Prefetch / recall ---------------------------------------------------
|
||||
|
||||
def prefetch_all(self, query: str) -> str:
|
||||
"""Collect prefetch context from all providers.
|
||||
|
||||
Returns merged context text labeled by provider. Empty providers
|
||||
are skipped. Failures in one provider don't block others.
|
||||
"""
|
||||
parts = []
|
||||
for provider in self._providers:
|
||||
try:
|
||||
result = provider.prefetch(query)
|
||||
if result and result.strip():
|
||||
parts.append(result)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' prefetch failed (non-fatal): %s",
|
||||
provider.name, e,
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def queue_prefetch_all(self, query: str) -> None:
|
||||
"""Queue background prefetch on all providers for the next turn."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.queue_prefetch(query)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' queue_prefetch failed (non-fatal): %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
# -- Sync ----------------------------------------------------------------
|
||||
|
||||
def sync_all(self, user_content: str, assistant_content: str) -> None:
|
||||
"""Sync a completed turn to all providers."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.sync_turn(user_content, assistant_content)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Memory provider '%s' sync_turn failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
# -- Tools ---------------------------------------------------------------
|
||||
|
||||
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""Collect tool schemas from all providers."""
|
||||
schemas = []
|
||||
seen = set()
|
||||
for provider in self._providers:
|
||||
try:
|
||||
for schema in provider.get_tool_schemas():
|
||||
name = schema.get("name", "")
|
||||
if name and name not in seen:
|
||||
schemas.append(schema)
|
||||
seen.add(name)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Memory provider '%s' get_tool_schemas() failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
return schemas
|
||||
|
||||
def get_all_tool_names(self) -> set:
|
||||
"""Return set of all tool names across all providers."""
|
||||
return set(self._tool_to_provider.keys())
|
||||
|
||||
def has_tool(self, tool_name: str) -> bool:
|
||||
"""Check if any provider handles this tool."""
|
||||
return tool_name in self._tool_to_provider
|
||||
|
||||
def handle_tool_call(
|
||||
self, tool_name: str, args: Dict[str, Any], **kwargs
|
||||
) -> str:
|
||||
"""Route a tool call to the correct provider.
|
||||
|
||||
Returns JSON string result. Raises ValueError if no provider
|
||||
handles the tool.
|
||||
"""
|
||||
provider = self._tool_to_provider.get(tool_name)
|
||||
if provider is None:
|
||||
return json.dumps({"error": f"No memory provider handles tool '{tool_name}'"})
|
||||
try:
|
||||
return provider.handle_tool_call(tool_name, args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Memory provider '%s' handle_tool_call(%s) failed: %s",
|
||||
provider.name, tool_name, e,
|
||||
)
|
||||
return json.dumps({"error": f"Memory tool '{tool_name}' failed: {e}"})
|
||||
|
||||
# -- Lifecycle hooks -----------------------------------------------------
|
||||
|
||||
def on_turn_start(self, turn_number: int, message: str) -> None:
|
||||
"""Notify all providers of a new turn."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.on_turn_start(turn_number, message)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_turn_start failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Notify all providers of session end."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.on_session_end(messages)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_session_end failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Notify all providers before context compression."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.on_pre_compress(messages)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_pre_compress failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
||||
"""Notify external providers when the built-in memory tool writes.
|
||||
|
||||
Skips the builtin provider itself (it's the source of the write).
|
||||
"""
|
||||
for provider in self._providers:
|
||||
if provider.name == "builtin":
|
||||
continue
|
||||
try:
|
||||
provider.on_memory_write(action, target, content)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_memory_write failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def shutdown_all(self) -> None:
|
||||
"""Shut down all providers (reverse order for clean teardown)."""
|
||||
for provider in reversed(self._providers):
|
||||
try:
|
||||
provider.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Memory provider '%s' shutdown failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def initialize_all(self, session_id: str, **kwargs) -> None:
|
||||
"""Initialize all providers."""
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.initialize(session_id=session_id, **kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Memory provider '%s' initialize failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
171
agent/memory_provider.py
Normal file
171
agent/memory_provider.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Abstract base class for pluggable memory providers.
|
||||
|
||||
Memory providers give the agent persistent recall across sessions. Multiple
|
||||
providers can be active simultaneously — the MemoryManager orchestrates them.
|
||||
|
||||
Built-in memory (MEMORY.md / USER.md) is always active as the first provider.
|
||||
External providers (Honcho, Hindsight, Mem0, etc.) are additive — they never
|
||||
disable the built-in store.
|
||||
|
||||
Three registration paths:
|
||||
1. Built-in: BuiltinMemoryProvider — always present, not removable.
|
||||
2. First-party: Ship with the repo, activated by config (e.g. Honcho).
|
||||
3. Plugin: External packages register via ctx.register_memory_provider().
|
||||
|
||||
Lifecycle (called by MemoryManager, wired in run_agent.py):
|
||||
initialize() — connect, create resources, warm up
|
||||
system_prompt_block() — static text for the system prompt
|
||||
prefetch(query) — background recall before each turn
|
||||
sync_turn(user, asst) — async write after each turn
|
||||
get_tool_schemas() — tool schemas to expose to the model
|
||||
handle_tool_call() — dispatch a tool call
|
||||
shutdown() — clean exit
|
||||
|
||||
Optional hooks (override to opt in):
|
||||
on_turn_start(turn, message) — per-turn tick (scope cooling, etc.)
|
||||
on_session_end(messages) — end-of-session extraction
|
||||
on_pre_compress(messages) — extract before context compression
|
||||
on_memory_write(action, target, content) — mirror built-in memory writes
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryProvider(ABC):
    """Abstract base class for memory providers."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this provider (e.g. 'builtin', 'honcho', 'hindsight')."""

    # -- Core lifecycle (implement these) ------------------------------------

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the provider is configured and ready to activate.

        Checked during agent init. Implementations should only inspect
        config and installed dependencies — no network calls here.
        """

    @abstractmethod
    def initialize(self, session_id: str, **kwargs) -> None:
        """Set up the provider for a session.

        Runs once at agent startup; a suitable place to create backend
        resources (banks, tables), open connections, or spawn background
        threads. kwargs may carry extra session context such as platform,
        model, and user_id.
        """

    def system_prompt_block(self) -> str:
        """Static text contributed to the system prompt.

        Invoked while the system prompt is assembled. Intended for fixed
        provider info (instructions, status); dynamic recall context is
        delivered through prefetch() instead. An empty string contributes
        nothing.
        """
        return ""

    def prefetch(self, query: str) -> str:
        """Return recalled context relevant to the upcoming turn.

        Runs before each API call, so it must be fast: perform the heavy
        recall on a background thread and hand back cached text here.
        Empty string means nothing relevant was found.
        """
        return ""

    def queue_prefetch(self, query: str) -> None:
        """Start a background recall whose result feeds the NEXT turn.

        Invoked after each completed turn; prefetch() consumes the result
        on the following turn. Default is a no-op — providers that do
        background prefetching override this.
        """

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        """Persist a finished turn to the backend.

        Runs after every turn and should not block — queue the work for
        background processing if the backend has latency.
        """

    @abstractmethod
    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Tool schemas this provider exposes to the model.

        Each entry follows the OpenAI function-calling shape:
            {"name": "...", "description": "...", "parameters": {...}}

        Context-only providers return an empty list.
        """

    def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
        """Execute one of this provider's tools.

        Must return a JSON string (the tool result). Only invoked for
        names previously returned by get_tool_schemas().
        """
        raise NotImplementedError(f"Provider {self.name} does not handle tool {tool_name}")

    def shutdown(self) -> None:
        """Clean shutdown — flush queues, close connections."""

    # -- Optional hooks (override to opt in) ---------------------------------

    def on_turn_start(self, turn_number: int, message: str) -> None:
        """Hook fired at the start of every turn with the user message.

        Useful for turn counting, scope management, periodic maintenance.
        """

    def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
        """Hook fired when a session ends (explicit exit or timeout).

        messages holds the full conversation history — suitable for
        end-of-session fact extraction or summarization.
        """

    def on_pre_compress(self, messages: List[Dict[str, Any]]) -> None:
        """Hook fired before context compression discards old messages.

        messages is the slice about to be summarized/discarded; extract
        anything worth keeping now.
        """

    def get_config_schema(self) -> List[Dict[str, Any]]:
        """Describe the config fields this provider needs for setup.

        Consumed by 'hermes memory setup' to walk the user through
        configuration. Each field dict may contain:
            key:         config key name (e.g. 'api_key', 'mode')
            description: human-readable description
            secret:      True if this should go to .env (default: False)
            required:    True if required (default: False)
            default:     default value (optional)
            choices:     list of valid values (optional)
            url:         URL where user can get this credential (optional)
            env_var:     explicit env var name for secrets (default: auto-generated)

        Providers needing no setup (e.g. local-only) return an empty list.
        """
        return []

    def on_memory_write(self, action: str, target: str, content: str) -> None:
        """Hook fired when the built-in memory tool writes an entry.

        action:  'add', 'replace', or 'remove'
        target:  'memory' or 'user'
        content: the entry content

        Override to mirror built-in memory writes into your backend.
        """
|
||||
@@ -346,6 +346,11 @@ DEFAULT_CONFIG = {
|
||||
"user_profile_enabled": True,
|
||||
"memory_char_limit": 2200, # ~800 tokens at 2.75 chars/token
|
||||
"user_char_limit": 1375, # ~500 tokens at 2.75 chars/token
|
||||
# External memory provider (plugin). At most one active at a time.
|
||||
# Set to the provider name (e.g. "holographic", "hindsight", "mem0")
|
||||
# or leave empty for built-in only. Auto-detected from plugins that
|
||||
# call ctx.register_memory_provider().
|
||||
"provider": "",
|
||||
},
|
||||
|
||||
# Subagent delegation — override the provider:model used by delegate_task
|
||||
|
||||
@@ -4147,6 +4147,30 @@ For more help on a command:
|
||||
|
||||
plugins_parser.set_defaults(func=cmd_plugins)
|
||||
|
||||
# =========================================================================
|
||||
# memory command
|
||||
# =========================================================================
|
||||
memory_parser = subparsers.add_parser(
|
||||
"memory",
|
||||
help="Manage memory provider plugins",
|
||||
description=(
|
||||
"Configure which memory provider plugin is active.\n\n"
|
||||
"Memory providers give the agent persistent recall across sessions.\n"
|
||||
"Built-in memory (MEMORY.md / USER.md) is always active.\n"
|
||||
"One external provider can be active at a time."
|
||||
),
|
||||
formatter_class=__import__("argparse").RawDescriptionHelpFormatter,
|
||||
)
|
||||
memory_subparsers = memory_parser.add_subparsers(dest="memory_command")
|
||||
memory_subparsers.add_parser("setup", help="Interactive setup wizard")
|
||||
memory_subparsers.add_parser("status", help="Show current provider and config")
|
||||
|
||||
def cmd_memory(args):
    """Handle the `hermes memory` subcommand by delegating to memory_command."""
    # Local import defers loading hermes_cli.memory_setup (and its
    # curses/getpass machinery) until the command actually runs.
    from hermes_cli.memory_setup import memory_command
    memory_command(args)
|
||||
|
||||
memory_parser.set_defaults(func=cmd_memory)
|
||||
|
||||
# =========================================================================
|
||||
# honcho command
|
||||
# =========================================================================
|
||||
|
||||
357
hermes_cli/memory_setup.py
Normal file
357
hermes_cli/memory_setup.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""hermes memory setup|status — configure memory provider plugins.
|
||||
|
||||
Auto-detects installed memory providers via the plugin system.
|
||||
Interactive curses-based UI for provider selection, then walks through
|
||||
the provider's config schema. Writes config to config.yaml + .env.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Curses-based interactive picker (same pattern as hermes tools)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -> int:
    """Interactive single-select with arrow keys.

    items: list of (label, description) tuples.
    Returns selected index, or default on escape/quit.
    """
    try:
        import curses
        # One-element list so the nested _menu closure can write the
        # selection back (a bare local couldn't be rebound from inside).
        result = [default]

        def _menu(stdscr):
            curses.curs_set(0)
            if curses.has_colors():
                curses.start_color()
                # -1 background keeps the terminal's default background.
                curses.use_default_colors()
                curses.init_pair(1, curses.COLOR_GREEN, -1)
                curses.init_pair(2, curses.COLOR_YELLOW, -1)
                curses.init_pair(3, curses.COLOR_CYAN, -1)
            cursor = default

            while True:
                stdscr.clear()
                max_y, max_x = stdscr.getmaxyx()

                # Title — addnstr can still raise on tiny terminals, so guard it.
                try:
                    stdscr.addnstr(0, 0, title, max_x - 1,
                                   curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
                    stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
                                   curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
                except curses.error:
                    pass

                for i, (label, desc) in enumerate(items):
                    # Rows start at y=3 (two header lines + one blank).
                    y = i + 3
                    if y >= max_y - 1:
                        # Stop drawing items that would fall off-screen.
                        break
                    arrow = "→" if i == cursor else " "
                    line = f" {arrow} {label}"
                    if desc:
                        line += f" {desc}"

                    attr = curses.A_NORMAL
                    if i == cursor:
                        attr = curses.A_BOLD
                        if curses.has_colors():
                            attr |= curses.color_pair(1)
                    try:
                        stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
                    except curses.error:
                        pass

                stdscr.refresh()
                key = stdscr.getch()

                # vim-style j/k as well as arrow keys; wrap around at ends.
                if key in (curses.KEY_UP, ord('k')):
                    cursor = (cursor - 1) % len(items)
                elif key in (curses.KEY_DOWN, ord('j')):
                    cursor = (cursor + 1) % len(items)
                elif key in (curses.KEY_ENTER, 10, 13):
                    result[0] = cursor
                    return
                elif key in (27, ord('q')):
                    # Escape / quit: leave result at the default.
                    return

        curses.wrapper(_menu)
        return result[0]

    except Exception:
        # Fallback: numbered input (e.g. no TTY or curses init failure —
        # broad except is deliberate here so any UI failure degrades to text)
        print(f"\n {title}\n")
        for i, (label, desc) in enumerate(items):
            marker = "→" if i == default else " "
            d = f" {desc}" if desc else ""
            print(f" {marker} {i + 1}. {label}{d}")
        while True:
            try:
                val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
                if not val:
                    return default
                idx = int(val) - 1
                if 0 <= idx < len(items):
                    return idx
                # Out-of-range numbers fall through and reprompt.
            except (ValueError, EOFError):
                # Non-numeric input or closed stdin: settle for the default.
                return default
|
||||
|
||||
|
||||
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
    """Prompt for a value with optional default and secret masking.

    Secrets are read via getpass (no echo) when stdin is a TTY; in all
    other cases the value is read as a plain line from stdin. Blank
    input falls back to *default* (or the empty string).
    """
    suffix = f" [{default}]" if default else ""
    sys.stdout.write(f" {label}{suffix}: ")
    sys.stdout.flush()

    # getpass only when we both want masking and actually have a terminal;
    # piped/non-interactive stdin always gets a plain readline.
    if secret and sys.stdin.isatty():
        val = getpass.getpass(prompt="")
    else:
        val = sys.stdin.readline().strip()

    return val or (default or "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_available_providers() -> list:
|
||||
"""Discover memory providers from installed plugins.
|
||||
|
||||
Returns list of (name, description, provider_instance) tuples.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_memory_providers
|
||||
providers = get_plugin_memory_providers()
|
||||
except Exception:
|
||||
providers = []
|
||||
|
||||
results = []
|
||||
for p in providers:
|
||||
name = getattr(p, "name", "unknown")
|
||||
schema = p.get_config_schema() if hasattr(p, "get_config_schema") else []
|
||||
has_secrets = any(f.get("secret") for f in schema)
|
||||
if has_secrets:
|
||||
desc = "requires API key"
|
||||
elif not schema:
|
||||
desc = "no setup needed"
|
||||
else:
|
||||
desc = "local"
|
||||
results.append((name, desc, p))
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup wizard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_setup(args) -> None:
    """Interactive memory provider setup wizard.

    Discovers plugin providers, lets the user pick one (or built-in only),
    walks through the chosen provider's config schema, then writes
    non-secret values to config.yaml and secrets to ~/.hermes/.env.
    """
    from hermes_cli.config import load_config, save_config

    providers = _get_available_providers()

    if not providers:
        print("\n No memory provider plugins detected.")
        print(" Install a plugin to ~/.hermes/plugins/ and try again.\n")
        return

    # Build picker items — one row per plugin plus a final "Built-in only" row.
    items = []
    for name, desc, _ in providers:
        items.append((name, f"— {desc}"))
    items.append(("Built-in only", "— MEMORY.md / USER.md (default)"))

    builtin_idx = len(items) - 1
    selected = _curses_select("Memory provider setup", items, default=builtin_idx)

    config = load_config()
    if not isinstance(config.get("memory"), dict):
        # Guard against a scalar/None "memory" entry in a hand-edited config.
        config["memory"] = {}

    # Built-in only — covers the explicit last row and any out-of-range index.
    if selected >= len(providers) or selected < 0:
        config["memory"]["provider"] = ""
        save_config(config)
        print("\n ✓ Memory provider: built-in only")
        print(" Saved to config.yaml\n")
        return

    name, _, provider = providers[selected]
    schema = provider.get_config_schema() if hasattr(provider, "get_config_schema") else []

    # Provider config section — start from any existing values so reruns
    # of the wizard preserve previously configured fields.
    provider_config = config["memory"].get(name, {})
    if not isinstance(provider_config, dict):
        provider_config = {}

    env_path = Path(os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))) / ".env"
    # Secrets collected here are written to .env, never to config.yaml.
    env_writes = {}

    if schema:
        print(f"\n Configuring {name}:\n")

    for field in schema:
        key = field["key"]
        desc = field.get("description", key)
        default = field.get("default")
        is_secret = field.get("secret", False)
        choices = field.get("choices")
        env_var = field.get("env_var")
        url = field.get("url")

        if choices and not is_secret:
            # Use curses picker for choice fields
            choice_items = [(c, "") for c in choices]
            current = provider_config.get(key, default)
            current_idx = 0
            if current and current in choices:
                # Pre-select the currently configured value.
                current_idx = choices.index(current)
            sel = _curses_select(f" {desc}", choice_items, default=current_idx)
            provider_config[key] = choices[sel]
        elif is_secret:
            # Prompt for secret — blank input keeps any existing env value.
            existing = os.environ.get(env_var, "") if env_var else ""
            if existing:
                masked = f"...{existing[-4:]}" if len(existing) > 4 else "set"
                val = _prompt(f"{desc} (current: {masked}, blank to keep)", secret=True)
            else:
                hint = f" Get yours at {url}" if url else ""
                if hint:
                    print(hint)
                val = _prompt(desc, secret=True)
            if val and env_var:
                env_writes[env_var] = val
        else:
            # Regular text prompt — existing config value wins over schema default.
            current = provider_config.get(key)
            effective_default = current or default
            val = _prompt(desc, default=str(effective_default) if effective_default else None)
            if val:
                provider_config[key] = val

    # Write config
    config["memory"]["provider"] = name
    config["memory"][name] = provider_config
    save_config(config)

    # Write secrets to .env
    if env_writes:
        _write_env_vars(env_path, env_writes)

    print(f"\n ✓ Memory provider: {name}")
    print(f" ✓ Config saved to config.yaml")
    if env_writes:
        print(f" ✓ API keys saved to .env")
    print(f"\n Start a new session to activate.\n")
|
||||
|
||||
|
||||
def _write_env_vars(env_path: Path, env_writes: dict) -> None:
|
||||
"""Append or update env vars in .env file."""
|
||||
env_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
existing_lines = []
|
||||
if env_path.exists():
|
||||
existing_lines = env_path.read_text().splitlines()
|
||||
|
||||
updated_keys = set()
|
||||
new_lines = []
|
||||
for line in existing_lines:
|
||||
key_match = line.split("=", 1)[0].strip() if "=" in line else ""
|
||||
if key_match in env_writes:
|
||||
new_lines.append(f"{key_match}={env_writes[key_match]}")
|
||||
updated_keys.add(key_match)
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
for key, val in env_writes.items():
|
||||
if key not in updated_keys:
|
||||
new_lines.append(f"{key}={val}")
|
||||
|
||||
env_path.write_text("\n".join(new_lines) + "\n")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_status(args) -> None:
    """Show current memory provider config.

    Prints: the always-active built-in store, the configured external
    provider (if any) with its stored config, whether that provider's
    plugin is installed and available (listing missing secret env vars
    when not), and finally every installed memory-provider plugin.

    Fix: the original scanned the plugin directory twice via
    _get_available_providers() and located the active provider with an
    any() check plus a second loop; the scan is now done once and the
    lookup uses next(). Output is unchanged.
    """
    from hermes_cli.config import load_config

    config = load_config()
    mem_config = config.get("memory", {})
    provider_name = mem_config.get("provider", "")

    # Scan installed plugins once; reused for both the active-provider
    # lookup and the "Installed plugins" listing below.
    providers = _get_available_providers()

    print("\nMemory status\n" + "─" * 40)
    print(" Built-in: always active")
    print(f" Provider: {provider_name or '(none — built-in only)'}")

    if provider_name:
        provider_config = mem_config.get(provider_name, {})
        if provider_config:
            print(f"\n {provider_name} config:")
            for key, val in provider_config.items():
                print(f" {key}: {val}")

        # Locate the active provider among installed plugins (None if absent).
        provider = next((p for pname, _, p in providers if pname == provider_name), None)
        if provider is not None:
            print("\n Plugin: installed ✓")
            if provider.is_available():
                print(" Status: available ✓")
            else:
                print(" Status: not available ✗")
                # List secret config fields so the user can see which
                # env vars still need to be set (with a signup URL if known).
                schema = provider.get_config_schema() if hasattr(provider, "get_config_schema") else []
                secrets = [f for f in schema if f.get("secret")]
                if secrets:
                    print(" Missing:")
                    for s in secrets:
                        env_var = s.get("env_var", "")
                        url = s.get("url", "")
                        is_set = bool(os.environ.get(env_var))
                        mark = "✓" if is_set else "✗"
                        line = f" {mark} {env_var}"
                        if url and not is_set:
                            line += f" → {url}"
                        print(line)
        else:
            print("\n Plugin: NOT installed ✗")
            print(f" Install the '{provider_name}' memory plugin to ~/.hermes/plugins/")

    if providers:
        print("\n Installed plugins:")
        for pname, desc, _ in providers:
            active = " ← active" if pname == provider_name else ""
            print(f" • {pname} ({desc}){active}")

    print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def memory_command(args) -> None:
    """Route memory subcommands to their handlers.

    ``setup`` runs the interactive configurator; anything else (including
    an explicit ``status`` or no subcommand at all) shows the status view.
    """
    if getattr(args, "memory_command", None) == "setup":
        cmd_setup(args)
    else:
        cmd_status(args)
|
||||
@@ -152,6 +152,28 @@ class PluginContext:
|
||||
self._manager._plugin_tool_names.add(name)
|
||||
logger.debug("Plugin %s registered tool: %s", self.manifest.name, name)
|
||||
|
||||
# -- memory provider registration ----------------------------------------
|
||||
|
||||
def register_memory_provider(self, provider) -> None:
|
||||
"""Register a memory provider (must implement MemoryProvider ABC).
|
||||
|
||||
The provider will be added to the MemoryManager during agent init.
|
||||
Providers registered this way are additive — they never disable
|
||||
the built-in MEMORY.md/USER.md store.
|
||||
|
||||
Example plugin __init__.py::
|
||||
|
||||
from my_memory_backend import MyMemoryProvider
|
||||
|
||||
def register(ctx):
|
||||
ctx.register_memory_provider(MyMemoryProvider())
|
||||
"""
|
||||
self._manager._memory_providers.append(provider)
|
||||
logger.debug(
|
||||
"Plugin %s registered memory provider: %s",
|
||||
self.manifest.name, getattr(provider, "name", "unknown"),
|
||||
)
|
||||
|
||||
# -- hook registration --------------------------------------------------
|
||||
|
||||
def register_hook(self, hook_name: str, callback: Callable) -> None:
|
||||
@@ -183,6 +205,7 @@ class PluginManager:
|
||||
self._plugins: Dict[str, LoadedPlugin] = {}
|
||||
self._hooks: Dict[str, List[Callable]] = {}
|
||||
self._plugin_tool_names: Set[str] = set()
|
||||
self._memory_providers: List = [] # MemoryProvider instances from plugins
|
||||
self._discovered: bool = False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -528,3 +551,13 @@ def get_plugin_toolsets() -> List[tuple]:
|
||||
result.append((ts_key, label, desc))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_plugin_memory_providers() -> List:
    """Return MemoryProvider instances registered by plugins.

    Called during AIAgent init so plugin memory providers can be added to
    the MemoryManager alongside the built-in ones. A shallow copy is
    returned so callers cannot mutate the manager's internal list.
    """
    return list(get_plugin_manager()._memory_providers)
|
||||
|
||||
1
optional-skills/communication/DESCRIPTION.md
Normal file
1
optional-skills/communication/DESCRIPTION.md
Normal file
@@ -0,0 +1 @@
|
||||
Communication and decision-making frameworks — structured response formats for proposals, trade-off analysis, and stakeholder-ready recommendations.
|
||||
103
optional-skills/communication/one-three-one-rule/SKILL.md
Normal file
103
optional-skills/communication/one-three-one-rule/SKILL.md
Normal file
@@ -0,0 +1,103 @@
|
||||
---
|
||||
name: one-three-one-rule
|
||||
description: >
|
||||
Structured decision-making framework for technical proposals and trade-off analysis.
|
||||
When the user faces a choice between multiple approaches (architecture decisions,
|
||||
tool selection, refactoring strategies, migration paths), this skill produces a
|
||||
1-3-1 format: one clear problem statement, three distinct options with pros/cons,
|
||||
and one concrete recommendation with definition of done and implementation plan.
|
||||
Use when the user asks for a "1-3-1", says "give me options", or needs help
|
||||
choosing between competing approaches.
|
||||
version: 1.0.0
|
||||
author: Willard Moore
|
||||
license: MIT
|
||||
category: communication
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [communication, decision-making, proposals, trade-offs]
|
||||
---
|
||||
|
||||
# 1-3-1 Communication Rule
|
||||
|
||||
Structured decision-making format for when a task has multiple viable approaches and the user needs a clear recommendation. Produces a concise problem framing, three options with trade-offs, and an actionable plan for the recommended path.
|
||||
|
||||
## When to Use
|
||||
|
||||
- The user explicitly asks for a "1-3-1" response.
|
||||
- The user says "give me options" or "what are my choices" for a technical decision.
|
||||
- A task has multiple viable approaches with meaningful trade-offs (architecture, tooling, migration strategy).
|
||||
- The user needs a proposal they can forward to a team or stakeholder.
|
||||
|
||||
Do NOT use for simple questions with one obvious answer, debugging sessions, or tasks where the user has already decided on an approach.
|
||||
|
||||
## Procedure
|
||||
|
||||
1. **Problem** (one sentence)
|
||||
- State the core decision or desired outcome in a single concise sentence.
|
||||
- Focus on the *what*, not the *how* — no implementation details, tool names, or specific technologies.
|
||||
- Keep it tight. If you need "and", you're describing two problems.
|
||||
|
||||
2. **Options** (exactly three)
|
||||
- Present three distinct, viable approaches labeled A, B, C.
|
||||
- Each option gets a brief description, pros, and cons.
|
||||
- Options should represent genuinely different strategies, not minor variations of the same approach.
|
||||
|
||||
3. **Recommendation** (one option)
|
||||
- State which option you recommend and why, based on the user's context and priorities.
|
||||
- Be direct — this is your professional judgment, not a hedge.
|
||||
|
||||
4. **Definition of Done**
|
||||
- List the specific success criteria for the recommended option.
|
||||
- These are concrete, verifiable outcomes — not vague aspirations.
|
||||
- If the user picks a different option, revise this section to match.
|
||||
|
||||
5. **Implementation Plan**
|
||||
- Concrete steps to execute the recommended option.
|
||||
- Include specific commands, tools, or actions where applicable.
|
||||
- If the user picks a different option, revise this section to match.
|
||||
|
||||
## Verification
|
||||
|
||||
- Response has exactly one Problem sentence.
|
||||
- Response has exactly three Options (A, B, C) with pros and cons for each.
|
||||
- Response has a single Recommendation that picks one option with reasoning.
|
||||
- Definition of Done and Implementation Plan align with the recommended option.
|
||||
- If the user selects a different option, Recommendation, DoD, and Implementation Plan update accordingly.
|
||||
|
||||
## Example
|
||||
|
||||
User: "Give me a 1-3-1 for adding retry logic to our API client."
|
||||
|
||||
Agent:
|
||||
|
||||
**Problem:** API calls to external services fail intermittently under load and need graceful retry handling.
|
||||
|
||||
**Options:**
|
||||
|
||||
- **Option A: Exponential backoff with jitter in a shared utility.**
|
||||
Pros: Simple, well-understood pattern; single implementation shared across all callers.
|
||||
Cons: No circuit-breaking; callers still block during retries.
|
||||
|
||||
- **Option B: Retry library (tenacity) as a decorator on each endpoint call.**
|
||||
Pros: Configurable per-endpoint; built-in support for circuit breakers and rate limits.
|
||||
Cons: Adds a dependency; decorator pattern can obscure control flow.
|
||||
|
||||
- **Option C: Queue-based async retry with dead letter handling.**
|
||||
Pros: Non-blocking; failed calls are tracked and replayable; scales well.
|
||||
Cons: Significant complexity increase; requires a queue backend; overkill for low-volume services.
|
||||
|
||||
**Recommendation:** Option A. The call volume doesn't justify queue infrastructure, and a shared utility with exponential backoff + jitter covers the failure mode cleanly without new dependencies.
|
||||
|
||||
**Definition of Done:**
|
||||
- `api_retry()` utility handles transient HTTP errors (429, 502, 503, 504).
|
||||
- Retries up to 3 times with exponential backoff and jitter.
|
||||
- All external API calls route through the retry utility.
|
||||
- Logs each retry attempt with endpoint and attempt number.
|
||||
- Tests cover retry exhaustion, immediate success, and transient-then-success scenarios.
|
||||
|
||||
**Implementation Plan:**
|
||||
1. Create `utils/api_retry.py` with configurable max retries, base delay, and retryable status codes.
|
||||
2. Add jitter using `random.uniform(0, base_delay)` to prevent thundering herd.
|
||||
3. Wrap existing API calls in `api_client.py` with the retry utility.
|
||||
4. Add unit tests mocking HTTP responses for each retry scenario.
|
||||
5. Verify under load with a simple stress test against a flaky endpoint mock.
|
||||
384
plugins/cognitive-memory/__init__.py
Normal file
384
plugins/cognitive-memory/__init__.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""Cognitive memory plugin — MemoryProvider interface.
|
||||
|
||||
Semantic memory with vector embeddings (via litellm), auto-classification,
|
||||
contradiction detection, importance decay, and time-based forgetting.
|
||||
Local SQLite storage with binary-packed float32 embeddings.
|
||||
|
||||
Original PR #727 by 0xbyt4, adapted to MemoryProvider ABC.
|
||||
|
||||
Requires: litellm (for embeddings via any provider — OpenAI, Cohere, etc.)
|
||||
Config via environment: uses litellm's standard env vars (OPENAI_API_KEY, etc.)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import struct
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DB_DIR = Path(os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))) / "cognitive_memory"
|
||||
_EMBEDDING_DIM = 1536 # text-embedding-3-small default
|
||||
_SIMILARITY_DEDUP_THRESHOLD = 0.95
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_embedding(text: str, model: str = "text-embedding-3-small") -> Optional[List[float]]:
    """Return the embedding vector for *text* via litellm, or None on failure.

    Args:
        text: The text to embed.
        model: Embedding model in litellm format. Previously hard-coded
            even though the provider advertises an ``embedding_model``
            config option; parameterized (default preserves old behavior)
            so callers can honor that setting.

    Returns:
        The embedding as a list of floats, or ``None`` when litellm is not
        installed or the provider call fails (logged at debug level only —
        embeddings are best-effort and callers handle ``None``).
    """
    try:
        import litellm

        resp = litellm.embedding(model=model, input=[text])
        return resp.data[0]["embedding"]
    except Exception as e:
        logger.debug("Embedding failed: %s", e)
        return None
|
||||
|
||||
|
||||
def _cosine_similarity(a: List[float], b: List[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0 or mag_b == 0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
|
||||
|
||||
|
||||
def _pack_embedding(emb: List[float]) -> bytes:
|
||||
return struct.pack(f"{len(emb)}f", *emb)
|
||||
|
||||
|
||||
def _unpack_embedding(data: bytes) -> List[float]:
|
||||
n = len(data) // 4
|
||||
return list(struct.unpack(f"{n}f", data))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CATEGORY_PATTERNS = {
|
||||
"preference": [r"\b(?:prefer|like|love|hate|dislike|favorite)\b"],
|
||||
"correction": [r"\b(?:actually|no,|wrong|incorrect|not right)\b"],
|
||||
"fact": [r"\b(?:is|are|was|were|has|have)\b"],
|
||||
"procedure": [r"\b(?:first|then|step|always|never|usually)\b"],
|
||||
"environment": [r"\b(?:running|using|installed|version|os|platform)\b"],
|
||||
}
|
||||
|
||||
|
||||
def _classify(content: str) -> str:
|
||||
content_lower = content.lower()
|
||||
for category, patterns in _CATEGORY_PATTERNS.items():
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, content_lower):
|
||||
return category
|
||||
return "general"
|
||||
|
||||
|
||||
def _estimate_importance(content: str, category: str) -> float:
|
||||
base = {"correction": 0.9, "preference": 0.7, "procedure": 0.6}.get(category, 0.5)
|
||||
# Longer content slightly more important
|
||||
length_bonus = min(len(content) / 500, 0.2)
|
||||
return min(base + length_bonus, 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# OpenAI-style function schema for the cognitive_recall tool exposed via
# CognitiveMemoryProvider.get_tool_schemas(). Dispatch happens in
# handle_tool_call(); "category" is auto-detected by _classify when omitted.
COGNITIVE_RECALL_SCHEMA = {
    "name": "cognitive_recall",
    "description": (
        "Semantic memory with automatic classification and importance scoring. "
        "Actions: recall (search by meaning), store (add a fact), "
        "forget (remove by ID), status (memory stats)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["recall", "store", "forget", "status"],
                "description": "Action to perform.",
            },
            "query": {"type": "string", "description": "Search query (for 'recall')."},
            "content": {"type": "string", "description": "Fact to store (for 'store')."},
            "category": {
                "type": "string",
                "enum": ["preference", "fact", "procedure", "environment", "correction", "general"],
                "description": "Category (auto-detected if omitted).",
            },
            "memory_id": {"type": "integer", "description": "Memory ID (for 'forget')."},
        },
        "required": ["action"],
    },
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CognitiveMemoryProvider(MemoryProvider):
    """Semantic memory with embeddings, classification, and forgetting.

    Facts live in a local SQLite database with float32 embeddings packed
    as BLOBs. Retrieval blends cosine similarity, importance, and recency;
    importance decays on a rate-limited schedule so unused memories fade
    and are eventually soft-deleted (``deleted = 1``).

    Fix vs. original: ``prefetch()`` computed its recency term from the
    row *id* (``row[0] * 86400``) because ``created_at`` was never
    selected; it now selects and uses ``created_at``, matching ``_recall``.
    """

    def __init__(self):
        # SQLite connection; opened lazily in initialize().
        self._conn = None
        # Importance decay half-life in days (0 disables decay).
        self._decay_half_life = 30
        # Timestamp of the last decay pass (rate-limited to ~hourly).
        self._last_decay = 0

    @property
    def name(self) -> str:
        return "cognitive"

    def get_config_schema(self):
        """Describe user-configurable options for the setup wizard."""
        return [
            {"key": "embedding_model", "description": "Embedding model (litellm format)", "default": "text-embedding-3-small"},
            {"key": "decay_half_life", "description": "Importance decay half-life in days (0=disabled)", "default": "30"},
        ]

    def is_available(self) -> bool:
        """Available only when litellm (used for embeddings) is importable."""
        try:
            import litellm  # noqa: F401
            return True
        except ImportError:
            return False

    def initialize(self, session_id: str, **kwargs) -> None:
        """Open (creating if needed) the SQLite store under the hermes home dir."""
        _DB_DIR.mkdir(parents=True, exist_ok=True)
        db_path = _DB_DIR / "cognitive.db"
        self._conn = sqlite3.connect(str(db_path))
        # WAL improves concurrent read behavior for a long-lived session.
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT NOT NULL,
                category TEXT DEFAULT 'general',
                importance REAL DEFAULT 0.5,
                embedding BLOB,
                retrieval_count INTEGER DEFAULT 0,
                helpful_count INTEGER DEFAULT 0,
                created_at REAL,
                updated_at REAL,
                deleted INTEGER DEFAULT 0
            );
            CREATE INDEX IF NOT EXISTS idx_mem_importance ON memories(importance DESC);
            CREATE INDEX IF NOT EXISTS idx_mem_category ON memories(category);
        """)
        self._conn.commit()

    def system_prompt_block(self) -> str:
        """One-paragraph status block for the system prompt; empty when unused."""
        if not self._conn:
            return ""
        try:
            count = self._conn.execute(
                "SELECT COUNT(*) FROM memories WHERE deleted = 0"
            ).fetchone()[0]
        except Exception:
            count = 0
        if count == 0:
            return ""
        return (
            f"# Cognitive Memory\n"
            f"Active. {count} memories with semantic recall and importance scoring.\n"
            f"Use cognitive_recall to search, store facts, or check status.\n"
            f"Memories decay over time — frequently used facts persist, unused ones fade."
        )

    def prefetch(self, query: str) -> str:
        """Return up to 5 relevant memories for *query* as a markdown block.

        Scoring: 0.5 * cosine similarity + 0.3 * importance + 0.2 * recency,
        where recency falls off linearly over 30 days from created_at.
        Candidates below similarity 0.3 are dropped. Returns "" when the
        store is uninitialized, the query is empty, embedding fails, or
        nothing matches.
        """
        if not self._conn or not query:
            return ""
        emb = _get_embedding(query)
        if not emb:
            return ""
        try:
            rows = self._conn.execute(
                "SELECT id, content, importance, embedding, created_at FROM memories "
                "WHERE deleted = 0 AND embedding IS NOT NULL "
                "ORDER BY importance DESC LIMIT 50"
            ).fetchall()
            now = time.time()
            scored = []
            for _mem_id, content, importance, emb_blob, created_at in rows:
                sim = _cosine_similarity(emb, _unpack_embedding(emb_blob))
                if sim <= 0.3:
                    continue
                # BUGFIX: use created_at (NULL-safe, as in _recall) — the
                # original mistakenly used the row id as a timestamp.
                recency = max(0, 1 - (now - (created_at or now)) / (30 * 86400))
                score = 0.5 * sim + 0.3 * importance + 0.2 * recency
                scored.append((score, content))
            scored.sort(reverse=True)
            if not scored:
                return ""
            lines = [f"- {content}" for _, content in scored[:5]]
            return "## Cognitive Memory\n" + "\n".join(lines)
        except Exception as e:
            logger.debug("Cognitive prefetch failed: %s", e)
            return ""

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        """Per-turn hook: no per-turn writes, just a rate-limited decay pass."""
        self._maybe_decay()

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        return [COGNITIVE_RECALL_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Dispatch a cognitive_recall tool call; always returns a JSON string."""
        if tool_name != "cognitive_recall":
            return json.dumps({"error": f"Unknown tool: {tool_name}"})

        action = args.get("action", "")
        if action == "store":
            return self._store(args)
        if action == "recall":
            return self._recall(args)
        if action == "forget":
            return self._forget(args)
        if action == "status":
            return self._status()
        return json.dumps({"error": f"Unknown action: {action}"})

    def on_memory_write(self, action: str, target: str, content: str) -> None:
        """Mirror built-in memory-tool 'add' writes into the semantic store.

        USER.md writes are treated as preferences; everything else is
        auto-classified. No-op for non-add actions or before initialize().
        """
        if action == "add" and self._conn and content:
            category = "preference" if target == "user" else _classify(content)
            importance = _estimate_importance(content, category)
            emb = _get_embedding(content)
            now = time.time()
            self._conn.execute(
                "INSERT INTO memories (content, category, importance, embedding, created_at, updated_at) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                (content, category, importance, _pack_embedding(emb) if emb else None, now, now),
            )
            self._conn.commit()

    def shutdown(self) -> None:
        """Close the SQLite connection."""
        if self._conn:
            self._conn.close()
            self._conn = None

    # -- Internal methods ----------------------------------------------------

    def _store(self, args: dict) -> str:
        """Insert a new fact; rejects near-duplicates by embedding similarity."""
        content = args.get("content", "")
        if not content:
            return json.dumps({"error": "content is required"})

        category = args.get("category") or _classify(content)
        importance = _estimate_importance(content, category)
        emb = _get_embedding(content)

        # Dedup: refuse to store if an existing memory is nearly identical.
        if emb:
            rows = self._conn.execute(
                "SELECT id, embedding FROM memories WHERE deleted = 0 AND embedding IS NOT NULL"
            ).fetchall()
            for row in rows:
                existing_emb = _unpack_embedding(row[1])
                if _cosine_similarity(emb, existing_emb) > _SIMILARITY_DEDUP_THRESHOLD:
                    return json.dumps({"error": "Very similar memory already exists", "existing_id": row[0]})

        now = time.time()
        cur = self._conn.execute(
            "INSERT INTO memories (content, category, importance, embedding, created_at, updated_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (content, category, importance, _pack_embedding(emb) if emb else None, now, now),
        )
        self._conn.commit()
        return json.dumps({"id": cur.lastrowid, "category": category, "importance": round(importance, 2)})

    def _recall(self, args: dict) -> str:
        """Semantic search: top-10 memories by blended similarity/importance/recency."""
        query = args.get("query", "")
        if not query:
            return json.dumps({"error": "query is required"})

        emb = _get_embedding(query)
        if not emb:
            return json.dumps({"error": "Embedding generation failed"})

        rows = self._conn.execute(
            "SELECT id, content, category, importance, embedding, created_at FROM memories "
            "WHERE deleted = 0 AND embedding IS NOT NULL "
            "ORDER BY importance DESC LIMIT 50"
        ).fetchall()

        now = time.time()
        results = []
        for row in rows:
            mem_emb = _unpack_embedding(row[4])
            sim = _cosine_similarity(emb, mem_emb)
            days_old = (now - (row[5] or now)) / 86400
            recency = max(0, 1 - days_old / 90)
            score = 0.5 * sim + 0.3 * row[3] + 0.2 * recency
            if sim > 0.2:
                results.append({
                    "id": row[0], "content": row[1], "category": row[2],
                    "score": round(score, 3), "similarity": round(sim, 3),
                })

        results.sort(key=lambda x: x["score"], reverse=True)
        # Bump retrieval counts so frequently-recalled facts resist decay.
        for r in results[:10]:
            self._conn.execute(
                "UPDATE memories SET retrieval_count = retrieval_count + 1 WHERE id = ?",
                (r["id"],),
            )
        self._conn.commit()
        return json.dumps({"results": results[:10], "count": len(results[:10])})

    def _forget(self, args: dict) -> str:
        """Soft-delete a memory by ID (row kept, flagged deleted)."""
        memory_id = args.get("memory_id")
        if memory_id is None:
            return json.dumps({"error": "memory_id is required"})
        self._conn.execute("UPDATE memories SET deleted = 1 WHERE id = ?", (int(memory_id),))
        self._conn.commit()
        return json.dumps({"forgotten": True, "id": memory_id})

    def _status(self) -> str:
        """Summarize live memory counts, broken down by category."""
        total = self._conn.execute("SELECT COUNT(*) FROM memories WHERE deleted = 0").fetchone()[0]
        by_cat = self._conn.execute(
            "SELECT category, COUNT(*) FROM memories WHERE deleted = 0 GROUP BY category"
        ).fetchall()
        return json.dumps({
            "total": total,
            "by_category": {row[0]: row[1] for row in by_cat},
            "decay_half_life_days": self._decay_half_life,
        })

    def _maybe_decay(self) -> None:
        """Run importance decay at most once per hour.

        Applies a per-day exponential decay factor and soft-deletes
        memories whose importance falls below 0.05. Best-effort: failures
        are logged at debug level only.
        """
        now = time.time()
        if now - self._last_decay < 3600:
            return
        self._last_decay = now
        if not self._conn or self._decay_half_life <= 0:
            return
        try:
            factor = 0.5 ** (1.0 / self._decay_half_life)
            self._conn.execute(
                "UPDATE memories SET importance = importance * ? WHERE deleted = 0",
                (factor,),
            )
            # Prune very low importance memories (soft delete).
            self._conn.execute(
                "UPDATE memories SET deleted = 1 WHERE deleted = 0 AND importance < 0.05"
            )
            self._conn.commit()
        except Exception as e:
            logger.debug("Cognitive decay failed: %s", e)
|
||||
|
||||
|
||||
def register(ctx) -> None:
    """Plugin entry point: add cognitive memory as a MemoryProvider."""
    provider = CognitiveMemoryProvider()
    ctx.register_memory_provider(provider)
|
||||
6
plugins/cognitive-memory/plugin.yaml
Normal file
6
plugins/cognitive-memory/plugin.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
name: cognitive-memory
|
||||
version: 1.0.0
|
||||
description: >
|
||||
Semantic memory with vector embeddings, auto-classification, contradiction
|
||||
detection, importance decay, and time-based forgetting. Local SQLite storage,
|
||||
requires litellm for embeddings.
|
||||
373
plugins/hermes-memory-store/__init__.py
Normal file
373
plugins/hermes-memory-store/__init__.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""hermes-memory-store — holographic memory plugin using MemoryProvider interface.
|
||||
|
||||
Registers as a MemoryProvider plugin, giving the agent structured fact storage
|
||||
with entity resolution, trust scoring, and HRR-based compositional retrieval.
|
||||
|
||||
Original plugin by dusterbloom (PR #2351), adapted to the MemoryProvider ABC.
|
||||
|
||||
Config in ~/.hermes/config.yaml:
|
||||
plugins:
|
||||
hermes-memory-store:
|
||||
db_path: ~/.hermes/memory_store.db
|
||||
auto_extract: false
|
||||
default_trust: 0.5
|
||||
min_trust_threshold: 0.3
|
||||
temporal_decay_half_life: 0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from .store import MemoryStore
|
||||
from .retrieval import FactRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas (unchanged from original PR)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# OpenAI-style function schema for the fact_store tool exposed by the
# hermes-memory-store plugin. Covers CRUD (add/update/remove/list) plus the
# retrieval actions (search/probe/related/reason/contradict); dispatched by
# the provider's tool-call handler.
FACT_STORE_SCHEMA = {
    "name": "fact_store",
    "description": (
        "Deep structured memory with algebraic reasoning. "
        "Use alongside the memory tool — memory for always-on context, "
        "fact_store for deep recall and compositional queries.\n\n"
        "ACTIONS (simple → powerful):\n"
        "• add — Store a fact the user would expect you to remember.\n"
        "• search — Keyword lookup ('editor config', 'deploy process').\n"
        "• probe — Entity recall: ALL facts about a person/thing.\n"
        "• related — What connects to an entity? Structural adjacency.\n"
        "• reason — Compositional: facts connected to MULTIPLE entities simultaneously.\n"
        "• contradict — Memory hygiene: find facts making conflicting claims.\n"
        "• update/remove/list — CRUD operations.\n\n"
        "IMPORTANT: Before answering questions about the user, ALWAYS probe or reason first."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": ["add", "search", "probe", "related", "reason", "contradict", "update", "remove", "list"],
            },
            "content": {"type": "string", "description": "Fact content (required for 'add')."},
            "query": {"type": "string", "description": "Search query (required for 'search')."},
            "entity": {"type": "string", "description": "Entity name for 'probe'/'related'."},
            "entities": {"type": "array", "items": {"type": "string"}, "description": "Entity names for 'reason'."},
            "fact_id": {"type": "integer", "description": "Fact ID for 'update'/'remove'."},
            "category": {"type": "string", "enum": ["user_pref", "project", "tool", "general"]},
            "tags": {"type": "string", "description": "Comma-separated tags."},
            "trust_delta": {"type": "number", "description": "Trust adjustment for 'update'."},
            "min_trust": {"type": "number", "description": "Minimum trust filter (default: 0.3)."},
            "limit": {"type": "integer", "description": "Max results (default: 10)."},
        },
        "required": ["action"],
    },
}
|
||||
|
||||
# Companion schema to FACT_STORE_SCHEMA: lets the model rate a previously
# retrieved fact, which adjusts that fact's trust score in the store.
FACT_FEEDBACK_SCHEMA = {
    "name": "fact_feedback",
    "description": (
        "Rate a fact after using it. Mark 'helpful' if accurate, 'unhelpful' if outdated. "
        "This trains the memory — good facts rise, bad facts sink."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "action": {"type": "string", "enum": ["helpful", "unhelpful"]},
            "fact_id": {"type": "integer", "description": "The fact ID to rate."},
        },
        "required": ["action", "fact_id"],
    },
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_plugin_config() -> dict:
|
||||
config_path = Path("~/.hermes/config.yaml").expanduser()
|
||||
if not config_path.exists():
|
||||
return {}
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
all_config = yaml.safe_load(f) or {}
|
||||
return all_config.get("plugins", {}).get("hermes-memory-store", {}) or {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HolographicMemoryProvider(MemoryProvider):
    """Holographic memory with structured facts, entity resolution, and HRR retrieval.

    Wraps a SQLite-backed MemoryStore plus a FactRetriever and exposes them
    through the MemoryProvider interface: a system-prompt status block,
    per-query prefetch, and two tools (fact_store / fact_feedback).
    """

    def __init__(self, config: dict | None = None):
        """Create the provider; storage objects are deferred to initialize()."""
        self._config = config or _load_plugin_config()
        self._store = None      # MemoryStore, created in initialize()
        self._retriever = None  # FactRetriever, created in initialize()
        self._min_trust = float(self._config.get("min_trust_threshold", 0.3))

    def _config_flag(self, key: str, default: bool = False) -> bool:
        """Read a boolean config value.

        Accepts real YAML booleans as well as the string forms declared in
        get_config_schema() ("true"/"false"). A raw truthiness test would
        treat the string "false" as enabled.
        """
        raw = self._config.get(key, default)
        if isinstance(raw, bool):
            return raw
        return str(raw).strip().lower() in ("true", "1", "yes", "on")

    @property
    def name(self) -> str:
        return "holographic"

    def is_available(self) -> bool:
        return True  # SQLite is always available, numpy is optional

    def get_config_schema(self):
        """Describe user-tunable settings (defaults/choices given as strings)."""
        return [
            {"key": "db_path", "description": "SQLite database path", "default": "~/.hermes/memory_store.db"},
            {"key": "auto_extract", "description": "Auto-extract facts at session end", "default": "false", "choices": ["true", "false"]},
            {"key": "default_trust", "description": "Default trust score for new facts", "default": "0.5"},
            {"key": "hrr_dim", "description": "HRR vector dimensions", "default": "1024"},
        ]

    def initialize(self, session_id: str, **kwargs) -> None:
        """Open the store and build the retriever for this session."""
        db_path = self._config.get("db_path", "~/.hermes/memory_store.db")
        default_trust = float(self._config.get("default_trust", 0.5))
        hrr_dim = int(self._config.get("hrr_dim", 1024))
        hrr_weight = float(self._config.get("hrr_weight", 0.3))
        temporal_decay = int(self._config.get("temporal_decay_half_life", 0))

        self._store = MemoryStore(db_path=db_path, default_trust=default_trust, hrr_dim=hrr_dim)
        self._retriever = FactRetriever(
            store=self._store,
            temporal_decay_half_life=temporal_decay,
            hrr_weight=hrr_weight,
            hrr_dim=hrr_dim,
        )
        self._session_id = session_id

    def system_prompt_block(self) -> str:
        """Short status block for the system prompt; empty when no facts exist."""
        if not self._store:
            return ""
        try:
            total = self._store._conn.execute(
                "SELECT COUNT(*) FROM facts"
            ).fetchone()[0]
        except Exception:
            # Count failure is non-fatal; behave as an empty store.
            total = 0
        if total == 0:
            return ""
        return (
            f"# Holographic Memory\n"
            f"Active. {total} facts stored with entity resolution and trust scoring.\n"
            f"Use fact_store to search, probe entities, reason across entities, or add facts.\n"
            f"Use fact_feedback to rate facts after using them (trains trust scores)."
        )

    def prefetch(self, query: str) -> str:
        """Return up to 5 trusted facts relevant to *query* (best-effort)."""
        if not self._retriever or not query:
            return ""
        try:
            results = self._retriever.search(query, min_trust=self._min_trust, limit=5)
            if not results:
                return ""
            lines = []
            for r in results:
                # Bug fix: the retriever returns facts keyed "trust_score"
                # (see FactRetriever's SELECT columns); the old lookup of
                # "trust" always fell back to 0 and rendered [0.0].
                trust = r.get("trust_score", 0)
                lines.append(f"- [{trust:.1f}] {r.get('content', '')}")
            return "## Holographic Memory\n" + "\n".join(lines)
        except Exception as e:
            logger.debug("Holographic prefetch failed: %s", e)
            return ""

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        # Holographic memory stores explicit facts via tools, not auto-sync.
        # The on_session_end hook handles auto-extraction if configured.
        pass

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        return [FACT_STORE_SCHEMA, FACT_FEEDBACK_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str:
        """Dispatch a tool invocation; always returns a JSON string."""
        if tool_name == "fact_store":
            return self._handle_fact_store(args)
        if tool_name == "fact_feedback":
            return self._handle_fact_feedback(args)
        return json.dumps({"error": f"Unknown tool: {tool_name}"})

    def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
        """Optionally mine the transcript for facts when auto_extract is on."""
        # Bug fix: auto_extract may arrive as the string "false" (the form
        # declared in get_config_schema); raw truthiness treated it as
        # enabled. _config_flag parses both booleans and string forms.
        if not self._config_flag("auto_extract", False):
            return
        if not self._store or not messages:
            return
        self._auto_extract_facts(messages)

    def on_memory_write(self, action: str, target: str, content: str) -> None:
        """Mirror built-in memory writes as facts."""
        if action == "add" and self._store and content:
            try:
                category = "user_pref" if target == "user" else "general"
                self._store.add_fact(content, category=category)
            except Exception as e:
                logger.debug("Holographic memory_write mirror failed: %s", e)

    def shutdown(self) -> None:
        """Release storage objects; subsequent calls behave as uninitialized."""
        self._store = None
        self._retriever = None

    # -- Tool handlers -------------------------------------------------------

    def _handle_fact_store(self, args: dict) -> str:
        """Execute one fact_store action; returns a JSON result or error.

        Missing arguments and storage failures are reported as JSON
        ``{"error": ...}`` payloads rather than raised.
        """
        try:
            action = args["action"]
            store = self._store
            retriever = self._retriever

            if action == "add":
                fact_id = store.add_fact(
                    args["content"],
                    category=args.get("category", "general"),
                    tags=args.get("tags", ""),
                )
                return json.dumps({"fact_id": fact_id, "status": "added"})

            elif action == "search":
                results = retriever.search(
                    args["query"],
                    category=args.get("category"),
                    min_trust=float(args.get("min_trust", self._min_trust)),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"results": results, "count": len(results)})

            elif action == "probe":
                results = retriever.probe(
                    args["entity"],
                    category=args.get("category"),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"results": results, "count": len(results)})

            elif action == "related":
                results = retriever.related(
                    args["entity"],
                    category=args.get("category"),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"results": results, "count": len(results)})

            elif action == "reason":
                entities = args.get("entities", [])
                if not entities:
                    return json.dumps({"error": "reason requires 'entities' list"})
                results = retriever.reason(
                    entities,
                    category=args.get("category"),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"results": results, "count": len(results)})

            elif action == "contradict":
                results = retriever.contradict(
                    category=args.get("category"),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"results": results, "count": len(results)})

            elif action == "update":
                updated = store.update_fact(
                    int(args["fact_id"]),
                    content=args.get("content"),
                    trust_delta=float(args["trust_delta"]) if "trust_delta" in args else None,
                    tags=args.get("tags"),
                    category=args.get("category"),
                )
                return json.dumps({"updated": updated})

            elif action == "remove":
                removed = store.remove_fact(int(args["fact_id"]))
                return json.dumps({"removed": removed})

            elif action == "list":
                facts = store.list_facts(
                    category=args.get("category"),
                    min_trust=float(args.get("min_trust", 0.0)),
                    limit=int(args.get("limit", 10)),
                )
                return json.dumps({"facts": facts, "count": len(facts)})

            else:
                return json.dumps({"error": f"Unknown action: {action}"})

        except KeyError as exc:
            return json.dumps({"error": f"Missing required argument: {exc}"})
        except Exception as exc:
            return json.dumps({"error": str(exc)})

    def _handle_fact_feedback(self, args: dict) -> str:
        """Record a helpful/unhelpful rating for a fact; returns JSON."""
        try:
            fact_id = int(args["fact_id"])
            helpful = args["action"] == "helpful"
            result = self._store.record_feedback(fact_id, helpful=helpful)
            return json.dumps(result)
        except KeyError as exc:
            return json.dumps({"error": f"Missing required argument: {exc}"})
        except Exception as exc:
            return json.dumps({"error": str(exc)})

    # -- Auto-extraction (on_session_end) ------------------------------------

    def _auto_extract_facts(self, messages: list) -> None:
        """Heuristically mine user messages for preferences and decisions.

        Matching messages are stored as facts truncated to 400 characters —
        at most one preference fact and one decision fact per message.
        Individual store failures are ignored (best-effort).
        """
        _PREF_PATTERNS = [
            re.compile(r'\bI\s+(?:prefer|like|love|use|want|need)\s+(.+)', re.IGNORECASE),
            re.compile(r'\bmy\s+(?:favorite|preferred|default)\s+\w+\s+is\s+(.+)', re.IGNORECASE),
            re.compile(r'\bI\s+(?:always|never|usually)\s+(.+)', re.IGNORECASE),
        ]
        _DECISION_PATTERNS = [
            re.compile(r'\bwe\s+(?:decided|agreed|chose)\s+(?:to\s+)?(.+)', re.IGNORECASE),
            re.compile(r'\bthe\s+project\s+(?:uses|needs|requires)\s+(.+)', re.IGNORECASE),
        ]

        extracted = 0
        for msg in messages:
            if msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            if not isinstance(content, str) or len(content) < 10:
                continue

            for pattern in _PREF_PATTERNS:
                if pattern.search(content):
                    try:
                        self._store.add_fact(content[:400], category="user_pref")
                        extracted += 1
                    except Exception:
                        pass
                    break

            for pattern in _DECISION_PATTERNS:
                if pattern.search(content):
                    try:
                        self._store.add_fact(content[:400], category="project")
                        extracted += 1
                    except Exception:
                        pass
                    break

        if extracted:
            logger.info("Auto-extracted %d facts from conversation", extracted)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def register(ctx) -> None:
    """Plugin entry point: build the provider and hand it to the host.

    Reads the plugin's config section and registers a
    HolographicMemoryProvider with the plugin context.
    """
    ctx.register_memory_provider(
        HolographicMemoryProvider(config=_load_plugin_config())
    )
|
||||
203
plugins/hermes-memory-store/holographic.py
Normal file
203
plugins/hermes-memory-store/holographic.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Holographic Reduced Representations (HRR) with phase encoding.
|
||||
|
||||
HRRs are a vector symbolic architecture for encoding compositional structure
|
||||
into fixed-width distributed representations. This module uses *phase vectors*:
|
||||
each concept is a vector of angles in [0, 2π). The algebraic operations are:
|
||||
|
||||
bind — circular convolution (phase addition) — associates two concepts
|
||||
unbind — circular correlation (phase subtraction) — retrieves a bound value
|
||||
bundle — superposition (circular mean) — merges multiple concepts
|
||||
|
||||
Phase encoding is numerically stable, avoids the magnitude collapse of
|
||||
traditional complex-number HRRs, and maps cleanly to cosine similarity.
|
||||
|
||||
Atoms are generated deterministically from SHA-256 so representations are
|
||||
identical across processes, machines, and language versions.
|
||||
|
||||
References:
|
||||
Plate (1995) — Holographic Reduced Representations
|
||||
Gayler (2004) — Vector Symbolic Architectures answer Jackendoff's challenges
|
||||
"""
|
||||
|
||||
import hashlib
import logging
import struct
import math

try:
    import numpy as np
    _HAS_NUMPY = True
except ImportError:
    _HAS_NUMPY = False

logger = logging.getLogger(__name__)

_TWO_PI = 2.0 * math.pi


def _require_numpy() -> None:
    """Raise RuntimeError unless numpy was importable at module load."""
    if not _HAS_NUMPY:
        raise RuntimeError("numpy is required for holographic operations")


def encode_atom(word: str, dim: int = 1024) -> "np.ndarray":
    """Deterministic phase vector via SHA-256 counter blocks.

    Hashes f"{word}:{i}" for i = 0, 1, ... until at least *dim*
    little-endian uint16 values are available, then scales them into
    [0, 2π). hashlib (rather than a numpy RNG) keeps results identical
    across processes, machines, and library versions.

    Returns a float64 array of shape (dim,).
    """
    _require_numpy()

    # One SHA-256 digest is 32 bytes = 16 uint16 values.
    values_per_digest = 16
    digest_count = math.ceil(dim / values_per_digest)

    blob = b"".join(
        hashlib.sha256(f"{word}:{counter}".encode()).digest()
        for counter in range(digest_count)
    )
    # "<u2" matches the original struct.unpack("<16H") byte interpretation.
    raw = np.frombuffer(blob, dtype="<u2")[:dim].astype(np.float64)
    return raw * (_TWO_PI / 65536.0)


def bind(a: "np.ndarray", b: "np.ndarray") -> "np.ndarray":
    """Circular convolution = element-wise phase addition.

    Binding associates two concepts; the composite is quasi-orthogonal
    (dissimilar) to both inputs.
    """
    _require_numpy()
    return np.mod(a + b, _TWO_PI)


def unbind(memory: "np.ndarray", key: "np.ndarray") -> "np.ndarray":
    """Circular correlation = element-wise phase subtraction.

    Retrieves the value bound with *key*: unbind(bind(a, b), a) ≈ b
    (up to superposition noise).
    """
    _require_numpy()
    return np.mod(memory - key, _TWO_PI)


def bundle(*vectors: "np.ndarray") -> "np.ndarray":
    """Superposition via circular mean of unit phasors.

    The bundle stays similar to every input; capacity is about
    O(sqrt(dim)) items before similarity degrades.
    """
    _require_numpy()
    phasors = np.exp(1j * np.stack(vectors))
    return np.angle(phasors.sum(axis=0)) % _TWO_PI


def similarity(a: "np.ndarray", b: "np.ndarray") -> float:
    """Phase cosine similarity in [-1, 1].

    1.0 for identical vectors, near 0.0 for unrelated random vectors,
    -1.0 for perfectly anti-correlated phases.
    """
    _require_numpy()
    return float(np.cos(a - b).mean())


def encode_text(text: str, dim: int = 1024) -> "np.ndarray":
    """Bag-of-words encoding: bundle of per-token atom vectors.

    Tokens are lowercased whitespace splits with surrounding punctuation
    stripped. Empty input maps to the reserved "__hrr_empty__" atom so
    the return value is always a valid phase vector.
    """
    _require_numpy()

    words = []
    for raw_token in text.lower().split():
        cleaned = raw_token.strip(".,!?;:\"'()[]{}")
        if cleaned:
            words.append(cleaned)

    if not words:
        return encode_atom("__hrr_empty__", dim)
    return bundle(*(encode_atom(word, dim) for word in words))


def encode_fact(content: str, entities: list[str], dim: int = 1024) -> "np.ndarray":
    """Structured fact encoding with role-bound components.

    The content vector is bound to the reserved content-role atom and
    every entity atom is bound to the entity-role atom; all components
    are bundled. This enables algebraic extraction:

        unbind(fact, bind(entity, ROLE_ENTITY)) ≈ content_vector
    """
    _require_numpy()

    role_content = encode_atom("__hrr_role_content__", dim)
    role_entity = encode_atom("__hrr_role_entity__", dim)

    parts = [bind(encode_text(content, dim), role_content)]
    parts.extend(
        bind(encode_atom(name.lower(), dim), role_entity) for name in entities
    )
    return bundle(*parts)


def phases_to_bytes(phases: "np.ndarray") -> bytes:
    """Serialize a phase vector as raw float64 — 8 KB at dim=1024."""
    _require_numpy()
    return phases.tobytes()


def bytes_to_phases(data: bytes) -> "np.ndarray":
    """Deserialize bytes back to a phase vector (inverse of phases_to_bytes).

    The .copy() is required: frombuffer returns a read-only view backed
    by the bytes object, and callers expect a mutable array.
    """
    _require_numpy()
    return np.frombuffer(data, dtype=np.float64).copy()


def snr_estimate(dim: int, n_items: int) -> float:
    """Estimate retrieval signal-to-noise: sqrt(dim / n_items).

    Returns inf for an empty store. Emits a warning below SNR 2.0
    (reached once n_items > dim / 4), where retrieval errors become
    likely.
    """
    _require_numpy()

    if n_items <= 0:
        return float("inf")

    ratio = math.sqrt(dim / n_items)
    if ratio < 2.0:
        logger.warning(
            "HRR storage near capacity: SNR=%.2f (dim=%d, n_items=%d). "
            "Retrieval accuracy may degrade. Consider increasing dim or reducing stored items.",
            ratio,
            dim,
            n_items,
        )
    return ratio
|
||||
6
plugins/hermes-memory-store/plugin.yaml
Normal file
6
plugins/hermes-memory-store/plugin.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
# Plugin manifest for hermes-memory-store (holographic memory backend).
# The on_session_end hook drives optional fact auto-extraction.
name: hermes-memory-store
version: 0.1.0
description: Structured memory backend with SQLite storage, trust scoring, entity resolution, and hybrid keyword/BM25 retrieval.
author: peppi
hooks:
  - on_session_end
||||
597
plugins/hermes-memory-store/retrieval.py
Normal file
597
plugins/hermes-memory-store/retrieval.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Hybrid keyword/BM25 retrieval for the memory store.
|
||||
|
||||
Ported from KIK memory_agent.py — combines FTS5 full-text search with
|
||||
Jaccard similarity reranking and trust-weighted scoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .store import MemoryStore
|
||||
|
||||
try:
|
||||
from . import holographic as hrr
|
||||
except ImportError:
|
||||
import holographic as hrr # type: ignore[no-redef]
|
||||
|
||||
|
||||
class FactRetriever:
|
||||
"""Multi-strategy fact retrieval with trust-weighted scoring."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: MemoryStore,
|
||||
temporal_decay_half_life: int = 0, # days, 0 = disabled
|
||||
fts_weight: float = 0.4,
|
||||
jaccard_weight: float = 0.3,
|
||||
hrr_weight: float = 0.3,
|
||||
hrr_dim: int = 1024,
|
||||
):
|
||||
self.store = store
|
||||
self.half_life = temporal_decay_half_life
|
||||
self.hrr_dim = hrr_dim
|
||||
|
||||
# Auto-redistribute weights if numpy unavailable
|
||||
if hrr_weight > 0 and not hrr._HAS_NUMPY:
|
||||
fts_weight = 0.6
|
||||
jaccard_weight = 0.4
|
||||
hrr_weight = 0.0
|
||||
|
||||
self.fts_weight = fts_weight
|
||||
self.jaccard_weight = jaccard_weight
|
||||
self.hrr_weight = hrr_weight
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
category: str | None = None,
|
||||
min_trust: float = 0.3,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Hybrid search: FTS5 candidates → Jaccard rerank → trust weighting.
|
||||
|
||||
Pipeline:
|
||||
1. FTS5 search: Get limit*3 candidates from SQLite full-text search
|
||||
2. Jaccard boost: Token overlap between query and fact content
|
||||
3. Trust weighting: final_score = relevance * trust_score
|
||||
4. Temporal decay (optional): decay = 0.5^(age_days / half_life)
|
||||
|
||||
Returns list of dicts with fact data + 'score' field, sorted by score desc.
|
||||
"""
|
||||
# Stage 1: Get FTS5 candidates (more than limit for reranking headroom)
|
||||
candidates = self._fts_candidates(query, category, min_trust, limit * 3)
|
||||
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
# Stage 2: Rerank with Jaccard + trust + optional decay
|
||||
query_tokens = self._tokenize(query)
|
||||
scored = []
|
||||
|
||||
for fact in candidates:
|
||||
content_tokens = self._tokenize(fact["content"])
|
||||
tag_tokens = self._tokenize(fact.get("tags", ""))
|
||||
all_tokens = content_tokens | tag_tokens
|
||||
|
||||
jaccard = self._jaccard_similarity(query_tokens, all_tokens)
|
||||
fts_score = fact.get("fts_rank", 0.0)
|
||||
|
||||
# HRR similarity
|
||||
if self.hrr_weight > 0 and fact.get("hrr_vector"):
|
||||
fact_vec = hrr.bytes_to_phases(fact["hrr_vector"])
|
||||
query_vec = hrr.encode_text(query, self.hrr_dim)
|
||||
hrr_sim = (hrr.similarity(query_vec, fact_vec) + 1.0) / 2.0 # shift to [0,1]
|
||||
else:
|
||||
hrr_sim = 0.5 # neutral
|
||||
|
||||
# Combine FTS5 + Jaccard + HRR
|
||||
relevance = (self.fts_weight * fts_score
|
||||
+ self.jaccard_weight * jaccard
|
||||
+ self.hrr_weight * hrr_sim)
|
||||
|
||||
# Trust weighting
|
||||
score = relevance * fact["trust_score"]
|
||||
|
||||
# Optional temporal decay
|
||||
if self.half_life > 0:
|
||||
score *= self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
|
||||
|
||||
fact["score"] = score
|
||||
scored.append(fact)
|
||||
|
||||
# Sort by score descending, return top limit
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
results = scored[:limit]
|
||||
# Strip raw HRR bytes — callers expect JSON-serializable dicts
|
||||
for fact in results:
|
||||
fact.pop("hrr_vector", None)
|
||||
return results
|
||||
|
||||
def probe(
|
||||
self,
|
||||
entity: str,
|
||||
category: str | None = None,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Compositional entity query using HRR algebra.
|
||||
|
||||
Unbinds entity from memory bank to extract associated content.
|
||||
This is NOT keyword search — it uses algebraic structure to find facts
|
||||
where the entity plays a structural role.
|
||||
|
||||
Falls back to FTS5 search if numpy unavailable.
|
||||
"""
|
||||
if not hrr._HAS_NUMPY:
|
||||
# Fallback to keyword search on entity name
|
||||
return self.search(entity, category=category, limit=limit)
|
||||
|
||||
conn = self.store._conn
|
||||
|
||||
# Encode entity as role-bound vector
|
||||
role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
|
||||
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
||||
probe_key = hrr.bind(entity_vec, role_entity)
|
||||
|
||||
# Try category-specific bank first, then all facts
|
||||
if category:
|
||||
bank_name = f"cat:{category}"
|
||||
bank_row = conn.execute(
|
||||
"SELECT vector FROM memory_banks WHERE bank_name = ?",
|
||||
(bank_name,),
|
||||
).fetchone()
|
||||
if bank_row:
|
||||
bank_vec = hrr.bytes_to_phases(bank_row["vector"])
|
||||
extracted = hrr.unbind(bank_vec, probe_key)
|
||||
# Use extracted signal to score individual facts
|
||||
return self._score_facts_by_vector(
|
||||
extracted, category=category, limit=limit
|
||||
)
|
||||
|
||||
# Score against individual fact vectors directly
|
||||
where = "WHERE hrr_vector IS NOT NULL"
|
||||
params: list = []
|
||||
if category:
|
||||
where += " AND category = ?"
|
||||
params.append(category)
|
||||
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT fact_id, content, category, tags, trust_score,
|
||||
retrieval_count, helpful_count, created_at, updated_at,
|
||||
hrr_vector
|
||||
FROM facts
|
||||
{where}
|
||||
""",
|
||||
params,
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
# Final fallback: keyword search
|
||||
return self.search(entity, category=category, limit=limit)
|
||||
|
||||
scored = []
|
||||
for row in rows:
|
||||
fact = dict(row)
|
||||
fact_vec = hrr.bytes_to_phases(fact.pop("hrr_vector"))
|
||||
# Unbind probe key from fact to see if entity is structurally present
|
||||
residual = hrr.unbind(fact_vec, probe_key)
|
||||
# Compare residual against content signal
|
||||
role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
|
||||
content_vec = hrr.bind(hrr.encode_text(fact["content"], self.hrr_dim), role_content)
|
||||
sim = hrr.similarity(residual, content_vec)
|
||||
fact["score"] = (sim + 1.0) / 2.0 * fact["trust_score"]
|
||||
scored.append(fact)
|
||||
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored[:limit]
|
||||
|
||||
def related(
|
||||
self,
|
||||
entity: str,
|
||||
category: str | None = None,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Discover facts that share structural connections with an entity.
|
||||
|
||||
Unlike probe (which finds facts *about* an entity), related finds
|
||||
facts that are connected through shared context — e.g., other entities
|
||||
mentioned alongside this one, or content that overlaps structurally.
|
||||
|
||||
Falls back to FTS5 search if numpy unavailable.
|
||||
"""
|
||||
if not hrr._HAS_NUMPY:
|
||||
return self.search(entity, category=category, limit=limit)
|
||||
|
||||
conn = self.store._conn
|
||||
|
||||
# Encode entity as a bare atom (not role-bound — we want ANY structural match)
|
||||
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
||||
|
||||
# Get all facts with vectors
|
||||
where = "WHERE hrr_vector IS NOT NULL"
|
||||
params: list = []
|
||||
if category:
|
||||
where += " AND category = ?"
|
||||
params.append(category)
|
||||
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT fact_id, content, category, tags, trust_score,
|
||||
retrieval_count, helpful_count, created_at, updated_at,
|
||||
hrr_vector
|
||||
FROM facts
|
||||
{where}
|
||||
""",
|
||||
params,
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
return self.search(entity, category=category, limit=limit)
|
||||
|
||||
# Score each fact by how much the entity's atom appears in its vector
|
||||
# This catches both role-bound entity matches AND content word matches
|
||||
scored = []
|
||||
for row in rows:
|
||||
fact = dict(row)
|
||||
fact_vec = hrr.bytes_to_phases(fact.pop("hrr_vector"))
|
||||
|
||||
# Check structural similarity: unbind entity from fact
|
||||
residual = hrr.unbind(fact_vec, entity_vec)
|
||||
# A high-similarity residual to ANY known role vector means this entity
|
||||
# plays a structural role in the fact
|
||||
role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
|
||||
role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
|
||||
|
||||
entity_role_sim = hrr.similarity(residual, role_entity)
|
||||
content_role_sim = hrr.similarity(residual, role_content)
|
||||
# Take the max — entity could appear in either role
|
||||
best_sim = max(entity_role_sim, content_role_sim)
|
||||
|
||||
fact["score"] = (best_sim + 1.0) / 2.0 * fact["trust_score"]
|
||||
scored.append(fact)
|
||||
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored[:limit]
|
||||
|
||||
def reason(
|
||||
self,
|
||||
entities: list[str],
|
||||
category: str | None = None,
|
||||
limit: int = 10,
|
||||
) -> list[dict]:
|
||||
"""Multi-entity compositional query — vector-space JOIN.
|
||||
|
||||
Given multiple entities, algebraically intersects their structural
|
||||
connections to find facts related to ALL of them simultaneously.
|
||||
This is compositional reasoning that no embedding DB can do.
|
||||
|
||||
Example: reason(["peppi", "backend"]) finds facts where peppi AND
|
||||
backend both play structural roles — without keyword matching.
|
||||
|
||||
Falls back to FTS5 search if numpy unavailable.
|
||||
"""
|
||||
if not hrr._HAS_NUMPY or not entities:
|
||||
# Fallback: search with all entities as keywords
|
||||
query = " ".join(entities)
|
||||
return self.search(query, category=category, limit=limit)
|
||||
|
||||
conn = self.store._conn
|
||||
role_entity = hrr.encode_atom("__hrr_role_entity__", self.hrr_dim)
|
||||
|
||||
# For each entity, compute what the bank "remembers" about it
|
||||
# by unbinding entity+role from each fact vector
|
||||
entity_residuals = []
|
||||
for entity in entities:
|
||||
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
||||
probe_key = hrr.bind(entity_vec, role_entity)
|
||||
entity_residuals.append(probe_key)
|
||||
|
||||
# The intersection key: bundle all probe keys, then use it to find
|
||||
# facts that are structurally connected to ALL entities
|
||||
intersection_key = hrr.bundle(*entity_residuals) if len(entity_residuals) > 1 else entity_residuals[0]
|
||||
|
||||
# Get all facts with vectors
|
||||
where = "WHERE hrr_vector IS NOT NULL"
|
||||
params: list = []
|
||||
if category:
|
||||
where += " AND category = ?"
|
||||
params.append(category)
|
||||
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT fact_id, content, category, tags, trust_score,
|
||||
retrieval_count, helpful_count, created_at, updated_at,
|
||||
hrr_vector
|
||||
FROM facts
|
||||
{where}
|
||||
""",
|
||||
params,
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
query = " ".join(entities)
|
||||
return self.search(query, category=category, limit=limit)
|
||||
|
||||
# Score each fact: unbind the intersection key and check if the
|
||||
# residual is coherent (high self-similarity = structured match)
|
||||
scored = []
|
||||
for row in rows:
|
||||
fact = dict(row)
|
||||
fact_vec = hrr.bytes_to_phases(fact.pop("hrr_vector"))
|
||||
|
||||
# Unbind intersection key from fact
|
||||
residual = hrr.unbind(fact_vec, intersection_key)
|
||||
|
||||
# Score by how much EACH entity is present in this fact
|
||||
# A fact scores high only if ALL entities have structural presence
|
||||
entity_scores = []
|
||||
for entity in entities:
|
||||
entity_vec = hrr.encode_atom(entity.lower(), self.hrr_dim)
|
||||
probe_key = hrr.bind(entity_vec, role_entity)
|
||||
single_residual = hrr.unbind(fact_vec, probe_key)
|
||||
# Check residual against content role (does this entity participate?)
|
||||
role_content = hrr.encode_atom("__hrr_role_content__", self.hrr_dim)
|
||||
sim = hrr.similarity(single_residual, role_content)
|
||||
entity_scores.append(sim)
|
||||
|
||||
# Use minimum score — fact must match ALL entities, not just some
|
||||
# This is the AND semantics (vs OR which would use mean/max)
|
||||
min_sim = min(entity_scores)
|
||||
fact["score"] = (min_sim + 1.0) / 2.0 * fact["trust_score"]
|
||||
scored.append(fact)
|
||||
|
||||
scored.sort(key=lambda x: x["score"], reverse=True)
|
||||
return scored[:limit]
|
||||
|
||||
    def contradict(
        self,
        category: str | None = None,
        threshold: float = 0.3,
        limit: int = 10,
    ) -> list[dict]:
        """Find potentially contradictory facts via entity overlap + content divergence.

        Two facts contradict when they share entities (same subject) but have
        low content-vector similarity (different claims). This is automated
        memory hygiene — no other memory system does this.

        Returns pairs of facts with a contradiction score, sorted most
        contradictory first, truncated to *limit*.
        Falls back to empty list if numpy unavailable.
        """
        if not hrr._HAS_NUMPY:
            return []

        conn = self.store._conn

        # Get all facts with vectors and their linked entities
        where = "WHERE f.hrr_vector IS NOT NULL"
        params: list = []
        if category:
            where += " AND f.category = ?"
            params.append(category)

        rows = conn.execute(
            f"""
            SELECT f.fact_id, f.content, f.category, f.tags, f.trust_score,
                   f.created_at, f.updated_at, f.hrr_vector
            FROM facts f
            {where}
            """,
            params,
        ).fetchall()

        if len(rows) < 2:
            return []

        # Build entity sets per fact.
        # NOTE(review): one entity query per fact (N+1 pattern) — fine at
        # small fact counts; a single JOIN would scale better.
        fact_entities: dict[int, set[str]] = {}
        for row in rows:
            fid = row["fact_id"]
            entity_rows = conn.execute(
                """
                SELECT e.name FROM entities e
                JOIN fact_entities fe ON fe.entity_id = e.entity_id
                WHERE fe.fact_id = ?
                """,
                (fid,),
            ).fetchall()
            fact_entities[fid] = {r["name"].lower() for r in entity_rows}

        # Compare all pairs (O(n^2)): high entity overlap + low content
        # similarity = contradiction
        facts = [dict(r) for r in rows]
        contradictions = []

        for i in range(len(facts)):
            for j in range(i + 1, len(facts)):
                f1, f2 = facts[i], facts[j]
                ents1 = fact_entities.get(f1["fact_id"], set())
                ents2 = fact_entities.get(f2["fact_id"], set())

                if not ents1 or not ents2:
                    continue

                # Entity overlap (Jaccard)
                entity_overlap = len(ents1 & ents2) / len(ents1 | ents2) if (ents1 | ents2) else 0.0

                # NOTE(review): this 0.3 overlap gate is hard-coded and
                # independent of the `threshold` parameter, which only
                # gates the final contradiction score.
                if entity_overlap < 0.3:
                    continue  # Not enough entity overlap to be contradictory

                # Content similarity via HRR vectors
                v1 = hrr.bytes_to_phases(f1["hrr_vector"])
                v2 = hrr.bytes_to_phases(f2["hrr_vector"])
                content_sim = hrr.similarity(v1, v2)

                # High entity overlap + low content similarity = potential contradiction
                # contradiction_score: higher = more contradictory
                contradiction_score = entity_overlap * (1.0 - (content_sim + 1.0) / 2.0)

                if contradiction_score >= threshold:
                    # Strip hrr_vector from output (not JSON serializable)
                    f1_clean = {k: v for k, v in f1.items() if k != "hrr_vector"}
                    f2_clean = {k: v for k, v in f2.items() if k != "hrr_vector"}
                    contradictions.append({
                        "fact_a": f1_clean,
                        "fact_b": f2_clean,
                        "entity_overlap": round(entity_overlap, 3),
                        "content_similarity": round(content_sim, 3),
                        "contradiction_score": round(contradiction_score, 3),
                        "shared_entities": sorted(ents1 & ents2),
                    })

        contradictions.sort(key=lambda x: x["contradiction_score"], reverse=True)
        return contradictions[:limit]
|
||||
|
||||
def _score_facts_by_vector(
    self,
    target_vec: "np.ndarray",
    category: str | None = None,
    limit: int = 10,
) -> list[dict]:
    """Rank stored facts by HRR similarity to *target_vec*.

    Loads every fact that has an HRR vector (optionally restricted to
    *category*), maps cosine-style similarity from [-1, 1] into [0, 1],
    weights it by the fact's trust_score, and returns the top *limit*
    fact dicts, best first. The raw vector bytes are stripped from the
    returned dicts.
    """
    where = "WHERE hrr_vector IS NOT NULL"
    query_args: list = []
    if category:
        where += " AND category = ?"
        query_args.append(category)

    fetched = self.store._conn.execute(
        f"""
        SELECT fact_id, content, category, tags, trust_score,
               retrieval_count, helpful_count, created_at, updated_at,
               hrr_vector
        FROM facts
        {where}
        """,
        query_args,
    ).fetchall()

    ranked: list[dict] = []
    for record in fetched:
        fact = dict(record)
        # pop() both decodes and removes the non-serializable blob.
        candidate_vec = hrr.bytes_to_phases(fact.pop("hrr_vector"))
        similarity = hrr.similarity(target_vec, candidate_vec)
        fact["score"] = (similarity + 1.0) / 2.0 * fact["trust_score"]
        ranked.append(fact)

    ranked.sort(key=lambda entry: entry["score"], reverse=True)
    return ranked[:limit]
|
||||
|
||||
def _fts_candidates(
|
||||
self,
|
||||
query: str,
|
||||
category: str | None,
|
||||
min_trust: float,
|
||||
limit: int,
|
||||
) -> list[dict]:
|
||||
"""Get raw FTS5 candidates from the store.
|
||||
|
||||
Uses the store's database connection directly for FTS5 MATCH
|
||||
with rank scoring. Normalizes FTS5 rank to [0, 1] range.
|
||||
"""
|
||||
conn = self.store._conn
|
||||
|
||||
# Build query - FTS5 rank is negative (lower = better match)
|
||||
# We need to join facts_fts with facts to get all columns
|
||||
params: list = []
|
||||
where_clauses = ["facts_fts MATCH ?"]
|
||||
params.append(query)
|
||||
|
||||
if category:
|
||||
where_clauses.append("f.category = ?")
|
||||
params.append(category)
|
||||
|
||||
where_clauses.append("f.trust_score >= ?")
|
||||
params.append(min_trust)
|
||||
|
||||
where_sql = " AND ".join(where_clauses)
|
||||
|
||||
sql = f"""
|
||||
SELECT f.*, facts_fts.rank as fts_rank_raw
|
||||
FROM facts_fts
|
||||
JOIN facts f ON f.fact_id = facts_fts.rowid
|
||||
WHERE {where_sql}
|
||||
ORDER BY facts_fts.rank
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
try:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
except Exception:
|
||||
# FTS5 MATCH can fail on malformed queries — fall back to empty
|
||||
return []
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Normalize FTS5 rank: rank is negative, lower = better
|
||||
# Convert to positive score in [0, 1] range
|
||||
raw_ranks = [abs(row["fts_rank_raw"]) for row in rows]
|
||||
max_rank = max(raw_ranks) if raw_ranks else 1.0
|
||||
max_rank = max(max_rank, 1e-6) # avoid div by zero
|
||||
|
||||
results = []
|
||||
for row, raw_rank in zip(rows, raw_ranks):
|
||||
fact = dict(row)
|
||||
fact.pop("fts_rank_raw", None)
|
||||
fact["fts_rank"] = raw_rank / max_rank # normalize to [0, 1]
|
||||
results.append(fact)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _tokenize(text: str) -> set[str]:
|
||||
"""Simple whitespace tokenization with lowercasing.
|
||||
|
||||
Strips common punctuation. No stemming/lemmatization (Phase 1).
|
||||
"""
|
||||
if not text:
|
||||
return set()
|
||||
# Split on whitespace, lowercase, strip punctuation
|
||||
tokens = set()
|
||||
for word in text.lower().split():
|
||||
cleaned = word.strip(".,;:!?\"'()[]{}#@<>")
|
||||
if cleaned:
|
||||
tokens.add(cleaned)
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def _jaccard_similarity(set_a: set, set_b: set) -> float:
|
||||
"""Jaccard similarity coefficient: |A ∩ B| / |A ∪ B|."""
|
||||
if not set_a or not set_b:
|
||||
return 0.0
|
||||
intersection = len(set_a & set_b)
|
||||
union = len(set_a | set_b)
|
||||
return intersection / union if union > 0 else 0.0
|
||||
|
||||
def _temporal_decay(self, timestamp_str: str | None) -> float:
|
||||
"""Exponential decay: 0.5^(age_days / half_life_days).
|
||||
|
||||
Returns 1.0 if decay is disabled or timestamp is missing.
|
||||
"""
|
||||
if not self.half_life or not timestamp_str:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
if isinstance(timestamp_str, str):
|
||||
# Parse ISO format timestamp from SQLite
|
||||
ts = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
|
||||
else:
|
||||
ts = timestamp_str
|
||||
|
||||
if ts.tzinfo is None:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
|
||||
age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400
|
||||
if age_days < 0:
|
||||
return 1.0
|
||||
|
||||
return math.pow(0.5, age_days / self.half_life)
|
||||
except (ValueError, TypeError):
|
||||
return 1.0
|
||||
572
plugins/hermes-memory-store/store.py
Normal file
572
plugins/hermes-memory-store/store.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
SQLite-backed fact store with entity resolution and trust scoring.
|
||||
Single-user Hermes memory store plugin.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from . import holographic as hrr
|
||||
except ImportError:
|
||||
import holographic as hrr # type: ignore[no-redef]
|
||||
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS facts (
|
||||
fact_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
content TEXT NOT NULL UNIQUE,
|
||||
category TEXT DEFAULT 'general',
|
||||
tags TEXT DEFAULT '',
|
||||
trust_score REAL DEFAULT 0.5,
|
||||
retrieval_count INTEGER DEFAULT 0,
|
||||
helpful_count INTEGER DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
hrr_vector BLOB
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS entities (
|
||||
entity_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
entity_type TEXT DEFAULT 'unknown',
|
||||
aliases TEXT DEFAULT '',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS fact_entities (
|
||||
fact_id INTEGER REFERENCES facts(fact_id),
|
||||
entity_id INTEGER REFERENCES entities(entity_id),
|
||||
PRIMARY KEY (fact_id, entity_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_trust ON facts(trust_score DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_entities_name ON entities(name);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS facts_fts
|
||||
USING fts5(content, tags, content=facts, content_rowid=fact_id);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS facts_ai AFTER INSERT ON facts BEGIN
|
||||
INSERT INTO facts_fts(rowid, content, tags)
|
||||
VALUES (new.fact_id, new.content, new.tags);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS facts_ad AFTER DELETE ON facts BEGIN
|
||||
INSERT INTO facts_fts(facts_fts, rowid, content, tags)
|
||||
VALUES ('delete', old.fact_id, old.content, old.tags);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS facts_au AFTER UPDATE ON facts BEGIN
|
||||
INSERT INTO facts_fts(facts_fts, rowid, content, tags)
|
||||
VALUES ('delete', old.fact_id, old.content, old.tags);
|
||||
INSERT INTO facts_fts(rowid, content, tags)
|
||||
VALUES (new.fact_id, new.content, new.tags);
|
||||
END;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memory_banks (
|
||||
bank_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
bank_name TEXT NOT NULL UNIQUE,
|
||||
vector BLOB NOT NULL,
|
||||
dim INTEGER NOT NULL,
|
||||
fact_count INTEGER DEFAULT 0,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
"""
|
||||
|
||||
# Trust adjustment constants
|
||||
_HELPFUL_DELTA = 0.05
|
||||
_UNHELPFUL_DELTA = -0.10
|
||||
_TRUST_MIN = 0.0
|
||||
_TRUST_MAX = 1.0
|
||||
|
||||
# Entity extraction patterns
|
||||
_RE_CAPITALIZED = re.compile(r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b')
|
||||
_RE_DOUBLE_QUOTE = re.compile(r'"([^"]+)"')
|
||||
_RE_SINGLE_QUOTE = re.compile(r"'([^']+)'")
|
||||
_RE_AKA = re.compile(
|
||||
r'(\w+(?:\s+\w+)*)\s+(?:aka|also known as)\s+(\w+(?:\s+\w+)*)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _clamp_trust(value: float) -> float:
|
||||
return max(_TRUST_MIN, min(_TRUST_MAX, value))
|
||||
|
||||
|
||||
class MemoryStore:
    """SQLite-backed fact store with entity resolution and trust scoring.

    Facts are deduplicated by content, linked to regex-extracted entities,
    full-text indexed via FTS5, and (when numpy is available) encoded as
    holographic (HRR) phase vectors bundled into per-category memory banks.
    Public methods serialize database access through a single re-entrant
    lock, so one instance may be shared across threads.
    """

    def __init__(
        self,
        db_path: "str | Path" = "~/.hermes/memory_store.db",
        default_trust: float = 0.5,
        hrr_dim: int = 1024,
    ) -> None:
        """Open (creating if needed) the SQLite database at *db_path*.

        Args:
            db_path: Database file location; ``~`` is expanded and parent
                directories are created as a side effect.
            default_trust: Trust assigned to new facts (clamped to [0, 1]).
            hrr_dim: Dimensionality of the HRR phase vectors.
        """
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.default_trust = _clamp_trust(default_trust)
        self.hrr_dim = hrr_dim
        # HRR features degrade gracefully to no-ops when numpy is absent.
        self._hrr_available = hrr._HAS_NUMPY
        # check_same_thread=False is safe because all access goes through
        # self._lock (an RLock, so nested public->private calls don't deadlock).
        self._conn: sqlite3.Connection = sqlite3.connect(
            str(self.db_path),
            check_same_thread=False,
            timeout=10.0,
        )
        self._lock = threading.RLock()
        self._conn.row_factory = sqlite3.Row
        self._init_db()

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def _init_db(self) -> None:
        """Create tables, indexes, and triggers if they do not exist. Enable WAL mode."""
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.executescript(_SCHEMA)
        # Migrate: add hrr_vector column if missing (safe for existing databases)
        columns = {row[1] for row in self._conn.execute("PRAGMA table_info(facts)").fetchall()}
        if "hrr_vector" not in columns:
            self._conn.execute("ALTER TABLE facts ADD COLUMN hrr_vector BLOB")
        self._conn.commit()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_fact(
        self,
        content: str,
        category: str = "general",
        tags: str = "",
    ) -> int:
        """Insert a fact and return its fact_id.

        Deduplicates by content (UNIQUE constraint). On duplicate, returns
        the existing fact_id without modifying the row. Extracts entities from
        the content and links them to the fact.

        Raises:
            ValueError: If *content* is empty after stripping whitespace.
        """
        with self._lock:
            content = content.strip()
            if not content:
                raise ValueError("content must not be empty")

            try:
                cur = self._conn.execute(
                    """
                    INSERT INTO facts (content, category, tags, trust_score)
                    VALUES (?, ?, ?, ?)
                    """,
                    (content, category, tags, self.default_trust),
                )
                self._conn.commit()
                fact_id: int = cur.lastrowid  # type: ignore[assignment]
            except sqlite3.IntegrityError:
                # Duplicate content — return existing id.
                # NOTE: the duplicate path skips entity linking and HRR
                # recomputation entirely; the existing row is left untouched.
                row = self._conn.execute(
                    "SELECT fact_id FROM facts WHERE content = ?", (content,)
                ).fetchone()
                return int(row["fact_id"])

            # Entity extraction and linking
            for name in self._extract_entities(content):
                entity_id = self._resolve_entity(name)
                self._link_fact_entity(fact_id, entity_id)

            # Compute HRR vector after entity linking (the vector encodes
            # the linked entity names as well as the content text).
            self._compute_hrr_vector(fact_id, content)
            self._rebuild_bank(category)

            return fact_id

    def search_facts(
        self,
        query: str,
        category: str | None = None,
        min_trust: float = 0.3,
        limit: int = 10,
    ) -> list[dict]:
        """Full-text search over facts using FTS5.

        Returns a list of fact dicts ordered by FTS5 rank, then trust_score
        descending. Also increments retrieval_count for matched facts.
        """
        with self._lock:
            query = query.strip()
            if not query:
                return []

            params: list = [query, min_trust]
            category_clause = ""
            if category is not None:
                category_clause = "AND f.category = ?"
                params.append(category)
            params.append(limit)

            # "facts_fts MATCH ?" works despite the `fts` alias because FTS5
            # exposes a hidden column named after the table itself.
            sql = f"""
                SELECT f.fact_id, f.content, f.category, f.tags,
                       f.trust_score, f.retrieval_count, f.helpful_count,
                       f.created_at, f.updated_at
                FROM facts f
                JOIN facts_fts fts ON fts.rowid = f.fact_id
                WHERE facts_fts MATCH ?
                  AND f.trust_score >= ?
                  {category_clause}
                ORDER BY fts.rank, f.trust_score DESC
                LIMIT ?
            """

            rows = self._conn.execute(sql, params).fetchall()
            results = [self._row_to_dict(r) for r in rows]

            # Side effect: every returned fact counts as one retrieval.
            if results:
                ids = [r["fact_id"] for r in results]
                placeholders = ",".join("?" * len(ids))
                self._conn.execute(
                    f"UPDATE facts SET retrieval_count = retrieval_count + 1 WHERE fact_id IN ({placeholders})",
                    ids,
                )
                self._conn.commit()

            return results

    def update_fact(
        self,
        fact_id: int,
        content: str | None = None,
        trust_delta: float | None = None,
        tags: str | None = None,
        category: str | None = None,
    ) -> bool:
        """Partially update a fact. Trust is clamped to [0, 1].

        Only the supplied (non-None) fields are changed; updated_at is
        always refreshed. Returns True if the row existed, False otherwise.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT fact_id, trust_score FROM facts WHERE fact_id = ?", (fact_id,)
            ).fetchone()
            if row is None:
                return False

            assignments: list[str] = ["updated_at = CURRENT_TIMESTAMP"]
            params: list = []

            if content is not None:
                assignments.append("content = ?")
                params.append(content.strip())
            if tags is not None:
                assignments.append("tags = ?")
                params.append(tags)
            if category is not None:
                assignments.append("category = ?")
                params.append(category)
            if trust_delta is not None:
                # Clamp relative to the *current* stored trust.
                new_trust = _clamp_trust(row["trust_score"] + trust_delta)
                assignments.append("trust_score = ?")
                params.append(new_trust)

            params.append(fact_id)
            self._conn.execute(
                f"UPDATE facts SET {', '.join(assignments)} WHERE fact_id = ?",
                params,
            )
            self._conn.commit()

            # If content changed, re-extract entities
            if content is not None:
                self._conn.execute(
                    "DELETE FROM fact_entities WHERE fact_id = ?", (fact_id,)
                )
                for name in self._extract_entities(content):
                    entity_id = self._resolve_entity(name)
                    self._link_fact_entity(fact_id, entity_id)
                self._conn.commit()

            # Recompute HRR vector if content changed
            # NOTE(review): banks are only rebuilt when *content* changes; a
            # category-only change moves the fact between categories without
            # refreshing either category's bank vector — confirm intended.
            if content is not None:
                self._compute_hrr_vector(fact_id, content)
                # Rebuild bank for relevant category
                cat = category or self._conn.execute(
                    "SELECT category FROM facts WHERE fact_id = ?", (fact_id,)
                ).fetchone()["category"]
                self._rebuild_bank(cat)

            return True

    def remove_fact(self, fact_id: int) -> bool:
        """Delete a fact and its entity links. Returns True if the row existed."""
        with self._lock:
            row = self._conn.execute(
                "SELECT fact_id, category FROM facts WHERE fact_id = ?", (fact_id,)
            ).fetchone()
            if row is None:
                return False

            self._conn.execute(
                "DELETE FROM fact_entities WHERE fact_id = ?", (fact_id,)
            )
            self._conn.execute("DELETE FROM facts WHERE fact_id = ?", (fact_id,))
            self._conn.commit()
            # Keep the category bank consistent with the removal.
            self._rebuild_bank(row["category"])
            return True

    def list_facts(
        self,
        category: str | None = None,
        min_trust: float = 0.0,
        limit: int = 50,
    ) -> list[dict]:
        """Browse facts ordered by trust_score descending.

        Optionally filter by category and minimum trust score.
        """
        with self._lock:
            params: list = [min_trust]
            category_clause = ""
            if category is not None:
                category_clause = "AND category = ?"
                params.append(category)
            params.append(limit)

            sql = f"""
                SELECT fact_id, content, category, tags, trust_score,
                       retrieval_count, helpful_count, created_at, updated_at
                FROM facts
                WHERE trust_score >= ?
                  {category_clause}
                ORDER BY trust_score DESC
                LIMIT ?
            """
            rows = self._conn.execute(sql, params).fetchall()
            return [self._row_to_dict(r) for r in rows]

    def record_feedback(self, fact_id: int, helpful: bool) -> dict:
        """Record user feedback and adjust trust asymmetrically.

        helpful=True  -> trust += 0.05, helpful_count += 1
        helpful=False -> trust -= 0.10

        Returns a dict with fact_id, old_trust, new_trust, helpful_count.
        Raises KeyError if fact_id does not exist.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT fact_id, trust_score, helpful_count FROM facts WHERE fact_id = ?",
                (fact_id,),
            ).fetchone()
            if row is None:
                raise KeyError(f"fact_id {fact_id} not found")

            old_trust: float = row["trust_score"]
            delta = _HELPFUL_DELTA if helpful else _UNHELPFUL_DELTA
            new_trust = _clamp_trust(old_trust + delta)

            helpful_increment = 1 if helpful else 0
            self._conn.execute(
                """
                UPDATE facts
                SET trust_score = ?,
                    helpful_count = helpful_count + ?,
                    updated_at = CURRENT_TIMESTAMP
                WHERE fact_id = ?
                """,
                (new_trust, helpful_increment, fact_id),
            )
            self._conn.commit()

            return {
                "fact_id": fact_id,
                "old_trust": old_trust,
                "new_trust": new_trust,
                "helpful_count": row["helpful_count"] + helpful_increment,
            }

    # ------------------------------------------------------------------
    # Entity helpers
    # ------------------------------------------------------------------

    def _extract_entities(self, text: str) -> list[str]:
        """Extract entity candidates from text using simple regex rules.

        Rules applied (in order):
        1. Capitalized multi-word phrases e.g. "John Doe"
        2. Double-quoted terms e.g. "Python"
        3. Single-quoted terms e.g. 'pytest'
        4. AKA patterns e.g. "Guido aka BDFL" -> two entities

        Returns a deduplicated list preserving first-seen order.
        """
        seen: set[str] = set()
        candidates: list[str] = []

        def _add(name: str) -> None:
            # Dedup is case-insensitive but the original casing is kept.
            stripped = name.strip()
            if stripped and stripped.lower() not in seen:
                seen.add(stripped.lower())
                candidates.append(stripped)

        for m in _RE_CAPITALIZED.finditer(text):
            _add(m.group(1))

        for m in _RE_DOUBLE_QUOTE.finditer(text):
            _add(m.group(1))

        for m in _RE_SINGLE_QUOTE.finditer(text):
            _add(m.group(1))

        for m in _RE_AKA.finditer(text):
            _add(m.group(1))
            _add(m.group(2))

        return candidates

    def _resolve_entity(self, name: str) -> int:
        """Find an existing entity by name or alias (case-insensitive) or create one.

        Returns the entity_id.
        """
        # Exact name match.
        # NOTE(review): LIKE without wildcards gives ASCII-case-insensitive
        # equality, but a name containing '%' or '_' is interpreted as a
        # pattern — confirm entity names can never contain those characters.
        row = self._conn.execute(
            "SELECT entity_id FROM entities WHERE name LIKE ?", (name,)
        ).fetchone()
        if row is not None:
            return int(row["entity_id"])

        # Search aliases — aliases stored as comma-separated; use LIKE with % boundaries
        alias_row = self._conn.execute(
            """
            SELECT entity_id FROM entities
            WHERE ',' || aliases || ',' LIKE '%,' || ? || ',%'
            """,
            (name,),
        ).fetchone()
        if alias_row is not None:
            return int(alias_row["entity_id"])

        # Create new entity
        cur = self._conn.execute(
            "INSERT INTO entities (name) VALUES (?)", (name,)
        )
        self._conn.commit()
        return int(cur.lastrowid)  # type: ignore[return-value]

    def _link_fact_entity(self, fact_id: int, entity_id: int) -> None:
        """Insert into fact_entities, silently ignore if the link already exists."""
        self._conn.execute(
            """
            INSERT OR IGNORE INTO fact_entities (fact_id, entity_id)
            VALUES (?, ?)
            """,
            (fact_id, entity_id),
        )
        self._conn.commit()

    def _compute_hrr_vector(self, fact_id: int, content: str) -> None:
        """Compute and store HRR vector for a fact. No-op if numpy unavailable."""
        with self._lock:
            if not self._hrr_available:
                return

            # Get entities linked to this fact — they are folded into the
            # encoding alongside the raw content.
            rows = self._conn.execute(
                """
                SELECT e.name FROM entities e
                JOIN fact_entities fe ON fe.entity_id = e.entity_id
                WHERE fe.fact_id = ?
                """,
                (fact_id,),
            ).fetchall()
            entities = [row["name"] for row in rows]

            vector = hrr.encode_fact(content, entities, self.hrr_dim)
            self._conn.execute(
                "UPDATE facts SET hrr_vector = ? WHERE fact_id = ?",
                (hrr.phases_to_bytes(vector), fact_id),
            )
            self._conn.commit()

    def _rebuild_bank(self, category: str) -> None:
        """Full rebuild of a category's memory bank from all its fact vectors."""
        with self._lock:
            if not self._hrr_available:
                return

            bank_name = f"cat:{category}"
            rows = self._conn.execute(
                "SELECT hrr_vector FROM facts WHERE category = ? AND hrr_vector IS NOT NULL",
                (category,),
            ).fetchall()

            # An empty category drops its bank row entirely.
            if not rows:
                self._conn.execute("DELETE FROM memory_banks WHERE bank_name = ?", (bank_name,))
                self._conn.commit()
                return

            vectors = [hrr.bytes_to_phases(row["hrr_vector"]) for row in rows]
            bank_vector = hrr.bundle(*vectors)
            fact_count = len(vectors)

            # Check SNR
            # NOTE(review): the return value is discarded here — confirm
            # snr_estimate logs/warns internally, otherwise this call has
            # no observable effect.
            hrr.snr_estimate(self.hrr_dim, fact_count)

            self._conn.execute(
                """
                INSERT INTO memory_banks (bank_name, vector, dim, fact_count, updated_at)
                VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
                ON CONFLICT(bank_name) DO UPDATE SET
                    vector = excluded.vector,
                    dim = excluded.dim,
                    fact_count = excluded.fact_count,
                    updated_at = excluded.updated_at
                """,
                (bank_name, hrr.phases_to_bytes(bank_vector), self.hrr_dim, fact_count),
            )
            self._conn.commit()

    def rebuild_all_vectors(self, dim: int | None = None) -> int:
        """Recompute all HRR vectors + banks from text. For recovery/migration.

        Optionally changes the vector dimensionality first. Returns the
        number of facts processed (0 when numpy is unavailable).
        """
        with self._lock:
            if not self._hrr_available:
                return 0

            if dim is not None:
                self.hrr_dim = dim

            rows = self._conn.execute(
                "SELECT fact_id, content, category FROM facts"
            ).fetchall()

            categories: set[str] = set()
            for row in rows:
                self._compute_hrr_vector(row["fact_id"], row["content"])
                categories.add(row["category"])

            # Rebuild each affected bank once, after all vectors are fresh.
            for category in categories:
                self._rebuild_bank(category)

            return len(rows)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def _row_to_dict(self, row: sqlite3.Row) -> dict:
        """Convert a sqlite3.Row to a plain dict."""
        return dict(row)

    def close(self) -> None:
        """Close the database connection."""
        self._conn.close()

    def __enter__(self) -> "MemoryStore":
        return self

    def __exit__(self, *_: object) -> None:
        self.close()
|
||||
315
plugins/hindsight-memory/__init__.py
Normal file
315
plugins/hindsight-memory/__init__.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Hindsight memory plugin — MemoryProvider interface.
|
||||
|
||||
Long-term memory with knowledge graph, entity resolution, and multi-strategy
|
||||
retrieval. Supports cloud (API key) and local (embedded PostgreSQL) modes.
|
||||
|
||||
Original PR #1811 by benfrank241, adapted to MemoryProvider ABC.
|
||||
|
||||
Config via environment variables:
|
||||
HINDSIGHT_API_KEY — API key for Hindsight Cloud
|
||||
HINDSIGHT_BANK_ID — memory bank identifier (default: hermes)
|
||||
HINDSIGHT_BUDGET — recall budget: low/mid/high (default: mid)
|
||||
HINDSIGHT_API_URL — API endpoint
|
||||
HINDSIGHT_MODE — cloud or local (default: cloud)
|
||||
|
||||
Or via ~/.hindsight/config.json (written by the original setup wizard).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default Hindsight Cloud API endpoint (see module docstring: HINDSIGHT_API_URL).
_DEFAULT_API_URL = "https://api.hindsight.vectorize.io"
# Recall budgets the provider accepts; anything else falls back to "mid".
_VALID_BUDGETS = {"low", "mid", "high"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread helper (from original PR — avoids aiohttp event loop conflicts)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_in_thread(fn, timeout: float = 30.0):
|
||||
result_q: queue.Queue = queue.Queue(maxsize=1)
|
||||
|
||||
def _run():
|
||||
import asyncio
|
||||
asyncio.set_event_loop(None)
|
||||
try:
|
||||
result_q.put(("ok", fn()))
|
||||
except Exception as exc:
|
||||
result_q.put(("err", exc))
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True, name="hindsight-call")
|
||||
t.start()
|
||||
kind, value = result_q.get(timeout=timeout)
|
||||
if kind == "err":
|
||||
raise value
|
||||
return value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tool schema: write path — store new information into the memory bank.
RETAIN_SCHEMA = {
    "name": "hindsight_retain",
    "description": (
        "Store information to long-term memory. Hindsight automatically "
        "extracts structured facts, resolves entities, and indexes for retrieval."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "content": {"type": "string", "description": "The information to store."},
            "context": {"type": "string", "description": "Short label (e.g. 'user preference', 'project decision')."},
        },
        "required": ["content"],
    },
}

# Tool schema: read path — ranked retrieval of individual memories.
RECALL_SCHEMA = {
    "name": "hindsight_recall",
    "description": (
        "Search long-term memory. Returns memories ranked by relevance using "
        "semantic search, keyword matching, entity graph traversal, and reranking."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "What to search for."},
        },
        "required": ["query"],
    },
}

# Tool schema: synthesis path — one reasoned answer across all memories,
# as opposed to recall's list of individual hits.
REFLECT_SCHEMA = {
    "name": "hindsight_reflect",
    "description": (
        "Synthesize a reasoned answer from long-term memories. Unlike recall, "
        "this reasons across all stored memories to produce a coherent response."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "The question to reflect on."},
        },
        "required": ["query"],
    },
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load config from ~/.hindsight/config.json, falling back to env vars."""
|
||||
from pathlib import Path
|
||||
config_path = Path.home() / ".hindsight" / "config.json"
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"mode": os.environ.get("HINDSIGHT_MODE", "cloud"),
|
||||
"apiKey": os.environ.get("HINDSIGHT_API_KEY", ""),
|
||||
"banks": {
|
||||
"hermes": {
|
||||
"bankId": os.environ.get("HINDSIGHT_BANK_ID", "hermes"),
|
||||
"budget": os.environ.get("HINDSIGHT_BUDGET", "mid"),
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class HindsightMemoryProvider(MemoryProvider):
|
||||
"""Hindsight long-term memory with knowledge graph and multi-strategy retrieval."""
|
||||
|
||||
def __init__(self):
    """Set inert defaults; real configuration happens in initialize()."""
    # Raw config dict from _load_config(); None until initialize() runs.
    self._config = None
    # Hindsight Cloud API key (cloud mode only).
    self._api_key = None
    # Target memory bank identifier.
    self._bank_id = "hermes"
    # Recall thoroughness: one of "low" / "mid" / "high".
    self._budget = "mid"
    # "cloud" (remote API) or "local" (embedded).
    self._mode = "cloud"
    # Text produced by the last background recall, consumed by prefetch().
    self._prefetch_result = ""
    # Guards _prefetch_result between the worker thread and prefetch().
    self._prefetch_lock = threading.Lock()
    # Background recall worker started by queue_prefetch(), if any.
    self._prefetch_thread = None
|
||||
|
||||
@property
def name(self) -> str:
    """Stable provider identifier used for registration and lookup."""
    return "hindsight"
|
||||
|
||||
def is_available(self) -> bool:
    """Report whether the provider looks usable with the current config/env.

    Local mode requires an LLM API key (config or HINDSIGHT_LLM_API_KEY);
    cloud mode requires a Hindsight API key (config or HINDSIGHT_API_KEY).
    Any failure while reading config counts as "not available".
    """
    try:
        cfg = _load_config()
        if cfg.get("mode", "cloud") == "local":
            embed_cfg = cfg.get("embed", {})
            llm_key = embed_cfg.get("llmApiKey") or os.environ.get("HINDSIGHT_LLM_API_KEY")
            return bool(llm_key)
        cloud_key = cfg.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")
        return bool(cloud_key)
    except Exception:
        return False
|
||||
|
||||
def get_config_schema(self):
    """Describe the provider's configurable keys for the setup/config UI.

    Each entry carries a key plus optional default, choices, secret flag,
    env_var fallback, and a help URL. The llm_* keys apply to local
    (embedded) mode only.
    """
    return [
        {"key": "mode", "description": "Cloud API or local embedded mode", "default": "cloud", "choices": ["cloud", "local"]},
        {"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://app.hindsight.vectorize.io"},
        {"key": "bank_id", "description": "Memory bank identifier", "default": "hermes"},
        {"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
        {"key": "llm_provider", "description": "LLM provider for local mode", "default": "anthropic", "choices": ["anthropic", "openai", "groq", "ollama"]},
        {"key": "llm_api_key", "description": "LLM API key for local mode", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY"},
        {"key": "llm_model", "description": "LLM model for local mode", "default": "claude-haiku-4-5-20251001"},
    ]
|
||||
|
||||
def _make_client(self):
    """Create a fresh Hindsight client (thread-safe).

    A new client is built per call rather than cached, so each worker
    thread gets its own instance. Imports are deferred so the plugin
    loads even when the hindsight packages are not installed.
    """
    if self._mode == "local":
        from hindsight import HindsightEmbedded
        embed = self._config.get("embed", {})
        return HindsightEmbedded(
            profile=embed.get("profile", "hermes"),
            llm_provider=embed.get("llmProvider", ""),
            llm_api_key=embed.get("llmApiKey", ""),
            llm_model=embed.get("llmModel", ""),
        )
    from hindsight_client import Hindsight
    return Hindsight(api_key=self._api_key, timeout=30.0)
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
    """Load config, resolve bank/budget settings, and ensure the bank exists.

    Invalid budgets silently fall back to "mid". Bank creation is
    best-effort: failures (including "already exists") are swallowed.
    """
    self._config = _load_config()
    self._mode = self._config.get("mode", "cloud")
    self._api_key = self._config.get("apiKey") or os.environ.get("HINDSIGHT_API_KEY", "")

    banks = self._config.get("banks", {}).get("hermes", {})
    self._bank_id = banks.get("bankId", "hermes")
    budget = banks.get("budget", "mid")
    self._budget = budget if budget in _VALID_BUDGETS else "mid"

    # Ensure bank exists
    try:
        client = _run_in_thread(self._make_client)
        _run_in_thread(lambda: client.create_bank(bank_id=self._bank_id, name=self._bank_id))
    except Exception:
        pass  # Already exists
|
||||
|
||||
def system_prompt_block(self) -> str:
    """Describe the active Hindsight bank and tools for the system prompt."""
    return (
        "# Hindsight Memory\n"
        f"Active. Bank: {self._bank_id}, budget: {self._budget}.\n"
        "Use hindsight_recall to search, hindsight_reflect for synthesis, "
        "hindsight_retain to store facts."
    )
def prefetch(self, query: str) -> str:
    """Return (and clear) the buffered background-recall result.

    Waits briefly for an in-flight queue_prefetch() worker, then pops the
    buffered text. `query` is unused here — the actual search happened in
    queue_prefetch(). Returns "" when nothing was recalled.
    """
    worker = self._prefetch_thread
    if worker and worker.is_alive():
        worker.join(timeout=3.0)
    with self._prefetch_lock:
        buffered, self._prefetch_result = self._prefetch_result, ""
    return f"## Hindsight Memory\n{buffered}" if buffered else ""
def queue_prefetch(self, query: str) -> None:
    """Kick off a background recall; prefetch() collects the result later."""
    def _worker():
        try:
            resp = self._make_client().recall(
                bank_id=self._bank_id, query=query, budget=self._budget
            )
            if resp.results:
                joined = "\n".join(r.text for r in resp.results if r.text)
                with self._prefetch_lock:
                    self._prefetch_result = joined
        except Exception as e:
            # Prefetch is opportunistic — failures are only debug-logged.
            logger.debug("Hindsight prefetch failed: %s", e)

    thread = threading.Thread(target=_worker, daemon=True, name="hindsight-prefetch")
    self._prefetch_thread = thread
    thread.start()
def sync_turn(self, user_content: str, assistant_content: str) -> None:
    """Persist one user/assistant exchange to the bank (best effort)."""
    transcript = f"User: {user_content}\nAssistant: {assistant_content}"
    try:
        _run_in_thread(
            lambda: self._make_client().retain(
                bank_id=self._bank_id, content=transcript, context="conversation"
            )
        )
    except Exception as e:
        logger.warning("Hindsight sync failed: %s", e)
def get_tool_schemas(self) -> List[Dict[str, Any]]:
    """Expose the three Hindsight tools (retain / recall / reflect)."""
    return [RETAIN_SCHEMA, RECALL_SCHEMA, REFLECT_SCHEMA]
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
    """Dispatch a hindsight_* tool invocation.

    Always returns a JSON string: {"result": ...} on success or
    {"error": ...} on missing parameters / API failures.
    """
    if tool_name == "hindsight_retain":
        content = args.get("content", "")
        if not content:
            return json.dumps({"error": "Missing required parameter: content"})
        context = args.get("context")
        try:
            _run_in_thread(
                lambda: self._make_client().retain(
                    bank_id=self._bank_id, content=content, context=context
                )
            )
            return json.dumps({"result": "Memory stored successfully."})
        except Exception as e:
            return json.dumps({"error": f"Failed to store memory: {e}"})

    if tool_name == "hindsight_recall":
        query = args.get("query", "")
        if not query:
            return json.dumps({"error": "Missing required parameter: query"})
        try:
            resp = _run_in_thread(
                lambda: self._make_client().recall(
                    bank_id=self._bank_id, query=query, budget=self._budget
                )
            )
            if not resp.results:
                return json.dumps({"result": "No relevant memories found."})
            numbered = [f"{idx}. {hit.text}" for idx, hit in enumerate(resp.results, 1)]
            return json.dumps({"result": "\n".join(numbered)})
        except Exception as e:
            return json.dumps({"error": f"Failed to search memory: {e}"})

    if tool_name == "hindsight_reflect":
        query = args.get("query", "")
        if not query:
            return json.dumps({"error": "Missing required parameter: query"})
        try:
            resp = _run_in_thread(
                lambda: self._make_client().reflect(
                    bank_id=self._bank_id, query=query, budget=self._budget
                )
            )
            return json.dumps({"result": resp.text or "No relevant memories found."})
        except Exception as e:
            return json.dumps({"error": f"Failed to reflect: {e}"})

    return json.dumps({"error": f"Unknown tool: {tool_name}"})
def shutdown(self) -> None:
    """Give any in-flight prefetch worker a chance to finish."""
    worker = self._prefetch_thread
    if worker and worker.is_alive():
        worker.join(timeout=5.0)
def register(ctx) -> None:
    """Register Hindsight as a memory provider plugin.

    Plugin entry point invoked by the loader; `ctx` is the plugin context.
    """
    ctx.register_memory_provider(HindsightMemoryProvider())
8
plugins/hindsight-memory/plugin.yaml
Normal file
8
plugins/hindsight-memory/plugin.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
name: hindsight-memory
|
||||
version: 1.0.0
|
||||
description: >
|
||||
Long-term memory via Hindsight — knowledge graph with entity resolution,
|
||||
multi-strategy retrieval (semantic + BM25 + graph + temporal), and
|
||||
cross-encoder reranking. Cloud or local mode.
|
||||
requires_env:
|
||||
- HINDSIGHT_API_KEY
|
||||
294
plugins/mem0-memory/__init__.py
Normal file
294
plugins/mem0-memory/__init__.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Mem0 memory plugin — MemoryProvider interface.
|
||||
|
||||
Server-side LLM fact extraction, semantic search with reranking, and
|
||||
automatic deduplication via the Mem0 Platform API.
|
||||
|
||||
Original PR #2933 by kartik-mem0, adapted to MemoryProvider ABC.
|
||||
|
||||
Config via environment variables:
|
||||
MEM0_API_KEY — Mem0 Platform API key (required)
|
||||
MEM0_USER_ID — User identifier (default: hermes-user)
|
||||
MEM0_AGENT_ID — Agent identifier (default: hermes)
|
||||
|
||||
Or via $HERMES_HOME/mem0.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load config from $HERMES_HOME/mem0.json or env vars."""
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
config_path = Path(hermes_home) / "mem0.json"
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
return json.loads(config_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"api_key": os.environ.get("MEM0_API_KEY", ""),
|
||||
"user_id": os.environ.get("MEM0_USER_ID", "hermes-user"),
|
||||
"agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
|
||||
"rerank": True,
|
||||
"keyword_search": False,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# mem0_profile: full dump of everything known about the user. Cheapest
# call — no query and no reranking.
PROFILE_SCHEMA = {
    "name": "mem0_profile",
    "description": (
        "Retrieve all stored memories about the user — preferences, facts, "
        "project context. Fast, no reranking. Use at conversation start."
    ),
    "parameters": {"type": "object", "properties": {}, "required": []},
}

# mem0_search: semantic search; rerank is opt-in per call (handle_tool_call
# defaults it to False regardless of the provider-level rerank setting).
SEARCH_SCHEMA = {
    "name": "mem0_search",
    "description": (
        "Search memories by meaning. Returns relevant facts ranked by similarity. "
        "Set rerank=true for higher accuracy (+150ms)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "What to search for."},
            "rerank": {"type": "boolean", "description": "Enable reranking for precision (default: false)."},
            "top_k": {"type": "integer", "description": "Max results (default: 10, max: 50)."},
        },
        "required": ["query"],
    },
}

# mem0_context: like mem0_search but reranking is always forced on and
# top_k is fixed at 5 (see handle_tool_call).
CONTEXT_SCHEMA = {
    "name": "mem0_context",
    "description": (
        "Deep retrieval with forced reranking. Use when you need the most "
        "relevant memories for a specific topic."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "What to search for."},
        },
        "required": ["query"],
    },
}

# mem0_conclude: verbatim storage — handle_tool_call passes infer=False so
# the platform skips LLM extraction.
CONCLUDE_SCHEMA = {
    "name": "mem0_conclude",
    "description": (
        "Store a durable fact about the user. Stored verbatim (no LLM extraction). "
        "Use for explicit preferences, corrections, or decisions."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "conclusion": {"type": "string", "description": "The fact to store."},
        },
        "required": ["conclusion"],
    },
}
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Mem0MemoryProvider(MemoryProvider):
    """Mem0 Platform memory with server-side extraction and semantic search.

    Configuration is resolved in initialize(); the Mem0 client is created
    lazily on first use and cached on the instance.
    """

    def __init__(self):
        # Defaults mirror the _load_config() env fallbacks so methods are
        # safe to call even before initialize().
        self._config = None
        self._client = None
        self._api_key = ""
        self._user_id = "hermes-user"
        self._agent_id = "hermes"
        self._rerank = True
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread = None

    @property
    def name(self) -> str:
        return "mem0"

    def is_available(self) -> bool:
        # Available whenever an API key can be resolved from file or env.
        cfg = _load_config()
        return bool(cfg.get("api_key"))

    def get_config_schema(self):
        return [
            {"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
            {"key": "user_id", "description": "User identifier", "default": "hermes-user"},
            {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
            {"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]},
        ]

    def _get_client(self):
        """Create (and cache) the Mem0 client.

        Raises:
            RuntimeError: if the mem0 package is not installed.
        """
        if self._client is not None:
            return self._client
        try:
            from mem0 import MemoryClient
        except ImportError as err:
            # Chain the cause so the underlying import failure stays visible.
            raise RuntimeError("mem0 package not installed. Run: pip install mem0ai") from err
        self._client = MemoryClient(api_key=self._api_key)
        return self._client

    def initialize(self, session_id: str, **kwargs) -> None:
        """Load configuration from mem0.json / environment variables."""
        self._config = _load_config()
        self._api_key = self._config.get("api_key", "")
        self._user_id = self._config.get("user_id", "hermes-user")
        self._agent_id = self._config.get("agent_id", "hermes")
        # Fix: a mem0.json written from get_config_schema() stores rerank as
        # the strings "true"/"false", and the string "false" is truthy —
        # previously reranking could never be disabled via the config file.
        # Normalize string values to a real bool.
        rerank = self._config.get("rerank", True)
        if isinstance(rerank, str):
            rerank = rerank.strip().lower() not in ("false", "0", "no", "")
        self._rerank = bool(rerank)

    def system_prompt_block(self) -> str:
        """Describe the active Mem0 tools for the system prompt."""
        return (
            "# Mem0 Memory\n"
            f"Active. User: {self._user_id}.\n"
            "Use mem0_search to find memories, mem0_conclude to store facts, "
            "mem0_profile for a full overview."
        )

    def prefetch(self, query: str) -> str:
        """Pop the buffered background-search result (empty if none).

        `query` is unused here — the search ran in queue_prefetch().
        """
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
        with self._prefetch_lock:
            result = self._prefetch_result
            self._prefetch_result = ""
        if not result:
            return ""
        return f"## Mem0 Memory\n{result}"

    def queue_prefetch(self, query: str) -> None:
        """Start a daemon thread that searches Mem0 and buffers the result."""
        def _run():
            try:
                client = self._get_client()
                results = client.search(
                    query=query,
                    user_id=self._user_id,
                    rerank=self._rerank,
                    top_k=5,
                )
                if results:
                    lines = [r.get("memory", "") for r in results if r.get("memory")]
                    with self._prefetch_lock:
                        self._prefetch_result = "\n".join(f"- {l}" for l in lines)
            except Exception as e:
                # Prefetch is opportunistic — failures are only debug-logged.
                logger.debug("Mem0 prefetch failed: %s", e)

        self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
        self._prefetch_thread.start()

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        """Send the turn to Mem0 for server-side fact extraction."""
        try:
            client = self._get_client()
            messages = [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": assistant_content},
            ]
            client.add(messages, user_id=self._user_id, agent_id=self._agent_id)
        except Exception as e:
            logger.warning("Mem0 sync failed: %s", e)

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Expose the four mem0_* tools."""
        return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, CONCLUDE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Dispatch a mem0_* tool call; always returns a JSON string."""
        try:
            client = self._get_client()
        except Exception as e:
            return json.dumps({"error": str(e)})

        if tool_name == "mem0_profile":
            try:
                memories = client.get_all(user_id=self._user_id)
                if not memories:
                    return json.dumps({"result": "No memories stored yet."})
                lines = [m.get("memory", "") for m in memories if m.get("memory")]
                return json.dumps({"result": "\n".join(lines), "count": len(lines)})
            except Exception as e:
                return json.dumps({"error": f"Failed to fetch profile: {e}"})

        elif tool_name == "mem0_search":
            query = args.get("query", "")
            if not query:
                return json.dumps({"error": "Missing required parameter: query"})
            rerank = args.get("rerank", False)
            top_k = min(int(args.get("top_k", 10)), 50)  # cap per SEARCH_SCHEMA
            try:
                results = client.search(
                    query=query, user_id=self._user_id,
                    rerank=rerank, top_k=top_k,
                )
                if not results:
                    return json.dumps({"result": "No relevant memories found."})
                items = [{"memory": r.get("memory", ""), "score": r.get("score", 0)} for r in results]
                return json.dumps({"results": items, "count": len(items)})
            except Exception as e:
                return json.dumps({"error": f"Search failed: {e}"})

        elif tool_name == "mem0_context":
            query = args.get("query", "")
            if not query:
                return json.dumps({"error": "Missing required parameter: query"})
            try:
                # Forced reranking, fixed top_k=5 (see CONTEXT_SCHEMA).
                results = client.search(
                    query=query, user_id=self._user_id,
                    rerank=True, top_k=5,
                )
                if not results:
                    return json.dumps({"result": "No relevant memories found."})
                items = [{"memory": r.get("memory", ""), "score": r.get("score", 0)} for r in results]
                return json.dumps({"results": items, "count": len(items)})
            except Exception as e:
                return json.dumps({"error": f"Context retrieval failed: {e}"})

        elif tool_name == "mem0_conclude":
            conclusion = args.get("conclusion", "")
            if not conclusion:
                return json.dumps({"error": "Missing required parameter: conclusion"})
            try:
                # infer=False stores the text verbatim (no LLM extraction).
                client.add(
                    [{"role": "user", "content": conclusion}],
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                    infer=False,
                )
                return json.dumps({"result": "Fact stored."})
            except Exception as e:
                return json.dumps({"error": f"Failed to store: {e}"})

        return json.dumps({"error": f"Unknown tool: {tool_name}"})

    def shutdown(self) -> None:
        """Wait briefly for a pending prefetch, then drop the cached client."""
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=5.0)
        self._client = None
def register(ctx) -> None:
    """Register Mem0 as a memory provider plugin.

    Plugin entry point invoked by the loader; `ctx` is the plugin context.
    """
    ctx.register_memory_provider(Mem0MemoryProvider())
7
plugins/mem0-memory/plugin.yaml
Normal file
7
plugins/mem0-memory/plugin.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
name: mem0-memory
|
||||
version: 1.0.0
|
||||
description: >
|
||||
Long-term memory via Mem0 Platform — server-side LLM fact extraction,
|
||||
semantic search with reranking, and automatic deduplication.
|
||||
requires_env:
|
||||
- MEM0_API_KEY
|
||||
205
plugins/openviking-memory/__init__.py
Normal file
205
plugins/openviking-memory/__init__.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""OpenViking memory plugin — MemoryProvider interface.
|
||||
|
||||
Read-only semantic search over a self-hosted OpenViking knowledge server.
|
||||
Supports search (fast/deep/auto), URI-based content reading, and
|
||||
filesystem-style browsing.
|
||||
|
||||
Original PR #3369 by Mibayy, adapted to MemoryProvider ABC.
|
||||
|
||||
Config via environment variables:
|
||||
OPENVIKING_ENDPOINT — Server URL (default: http://127.0.0.1:1933)
|
||||
OPENVIKING_API_KEY — Optional API key
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# viking_search: semantic query over the knowledge base; results carry
# viking:// URIs that viking_read can follow.
SEARCH_SCHEMA = {
    "name": "viking_search",
    "description": (
        "Semantic search over OpenViking knowledge base. "
        "Returns ranked results with URIs for deeper reading."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query."},
            "mode": {
                "type": "string", "enum": ["auto", "fast", "deep"],
                "description": "Search depth (default: auto).",
            },
            "scope": {"type": "string", "description": "URI prefix to scope search."},
            "limit": {"type": "integer", "description": "Max results (default: 10)."},
        },
        "required": ["query"],
    },
}

# viking_read: fetch the content behind one URI at a chosen detail level.
READ_SCHEMA = {
    "name": "viking_read",
    "description": (
        "Read content at a viking:// URI. Supports three detail levels: "
        "abstract (summary), overview (key points), read (full content)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "uri": {"type": "string", "description": "viking:// URI to read."},
            "level": {
                "type": "string", "enum": ["abstract", "overview", "read"],
                "description": "Detail level (default: overview).",
            },
        },
        "required": ["uri"],
    },
}

# viking_browse: filesystem-style navigation of the knowledge store.
BROWSE_SCHEMA = {
    "name": "viking_browse",
    "description": (
        "Browse the OpenViking knowledge store like a filesystem. "
        "Supports tree (hierarchy), list (directory), and stat (metadata)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string", "enum": ["tree", "list", "stat"],
                "description": "Browse action.",
            },
            "path": {"type": "string", "description": "Path to browse (default: root)."},
        },
        "required": ["action"],
    },
}
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OpenVikingMemoryProvider(MemoryProvider):
    """Read-only memory via OpenViking self-hosted knowledge server.

    All three tools are thin HTTP wrappers over the server's /v1/* API;
    responses are passed through to the agent as raw response text.
    """

    def __init__(self):
        # Both are populated in initialize() from environment variables.
        self._endpoint = ""
        self._api_key = ""

    @property
    def name(self) -> str:
        return "openviking"

    def get_config_schema(self):
        return [
            {"key": "endpoint", "description": "OpenViking server URL", "required": True, "default": "http://127.0.0.1:1933"},
            {"key": "api_key", "description": "OpenViking API key (if server requires auth)", "secret": True, "env_var": "OPENVIKING_API_KEY"},
        ]

    def is_available(self) -> bool:
        # NOTE(review): requires OPENVIKING_ENDPOINT to be set explicitly,
        # while initialize() falls back to http://127.0.0.1:1933 — so the
        # provider never self-activates on the default endpoint. Looks
        # intentional (opt-in via env), but confirm.
        endpoint = os.environ.get("OPENVIKING_ENDPOINT", "")
        if not endpoint:
            return False
        # Quick health check
        try:
            import httpx
            resp = httpx.get(f"{endpoint}/health", timeout=3.0)
            return resp.status_code == 200
        except Exception:
            return False

    def initialize(self, session_id: str, **kwargs) -> None:
        """Resolve endpoint and optional API key from the environment."""
        self._endpoint = os.environ.get("OPENVIKING_ENDPOINT", "http://127.0.0.1:1933")
        self._api_key = os.environ.get("OPENVIKING_API_KEY", "")

    def _headers(self) -> dict:
        """Build request headers; X-API-Key only when a key is configured."""
        h = {"Content-Type": "application/json"}
        if self._api_key:
            h["X-API-Key"] = self._api_key
        return h

    def system_prompt_block(self) -> str:
        """Describe the OpenViking tools for the system prompt."""
        return (
            "# OpenViking Knowledge Base\n"
            f"Active. Endpoint: {self._endpoint}\n"
            "Use viking_search to find information, viking_read for details, "
            "viking_browse to explore the knowledge tree."
        )

    def prefetch(self, query: str) -> str:
        """OpenViking is tool-driven, no automatic prefetch."""
        return ""

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Expose the three viking_* tools."""
        return [SEARCH_SCHEMA, READ_SCHEMA, BROWSE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Dispatch a viking_* tool call; returns a JSON string.

        httpx is imported lazily so the plugin can load without it; any
        request failure is reported as {"error": ...} rather than raised.
        """
        try:
            import httpx
        except ImportError:
            return json.dumps({"error": "httpx not installed"})

        try:
            if tool_name == "viking_search":
                return self._search(httpx, args)
            elif tool_name == "viking_read":
                return self._read(httpx, args)
            elif tool_name == "viking_browse":
                return self._browse(httpx, args)
            return json.dumps({"error": f"Unknown tool: {tool_name}"})
        except Exception as e:
            return json.dumps({"error": str(e)})

    def _search(self, httpx, args: dict) -> str:
        """POST /v1/search — semantic query; scope/limit are optional."""
        query = args.get("query", "")
        if not query:
            return json.dumps({"error": "query is required"})
        payload = {"query": query, "mode": args.get("mode", "auto")}
        if args.get("scope"):
            payload["scope"] = args["scope"]
        if args.get("limit"):
            payload["limit"] = args["limit"]
        resp = httpx.post(
            f"{self._endpoint}/v1/search",
            json=payload, headers=self._headers(), timeout=30.0,
        )
        # Server response body is returned verbatim to the agent.
        return resp.text

    def _read(self, httpx, args: dict) -> str:
        """POST /v1/read — fetch one URI at the requested detail level."""
        uri = args.get("uri", "")
        if not uri:
            return json.dumps({"error": "uri is required"})
        level = args.get("level", "overview")
        resp = httpx.post(
            f"{self._endpoint}/v1/read",
            json={"uri": uri, "level": level},
            headers=self._headers(), timeout=30.0,
        )
        return resp.text

    def _browse(self, httpx, args: dict) -> str:
        """POST /v1/browse — tree/list/stat navigation from `path`."""
        action = args.get("action", "tree")
        path = args.get("path", "/")
        resp = httpx.post(
            f"{self._endpoint}/v1/browse",
            json={"action": action, "path": path},
            headers=self._headers(), timeout=30.0,
        )
        return resp.text
def register(ctx) -> None:
    """Register OpenViking as a memory provider plugin.

    Plugin entry point invoked by the loader; `ctx` is the plugin context.
    """
    ctx.register_memory_provider(OpenVikingMemoryProvider())
7
plugins/openviking-memory/plugin.yaml
Normal file
7
plugins/openviking-memory/plugin.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
name: openviking-memory
|
||||
version: 1.0.0
|
||||
description: >
|
||||
Read-only memory via OpenViking — semantic search, URI-based content
|
||||
reading, and filesystem browsing over a self-hosted knowledge server.
|
||||
requires_env:
|
||||
- OPENVIKING_ENDPOINT
|
||||
280
plugins/retaindb-memory/__init__.py
Normal file
280
plugins/retaindb-memory/__init__.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""RetainDB memory plugin — MemoryProvider interface.
|
||||
|
||||
Cross-session memory via RetainDB cloud API. Durable write-behind queue,
|
||||
semantic search with deduplication, and user profile retrieval.
|
||||
|
||||
Original PR #2732 by Alinxus, adapted to MemoryProvider ABC.
|
||||
|
||||
Config via environment variables:
|
||||
RETAINDB_API_KEY — API key (required)
|
||||
RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
|
||||
RETAINDB_PROJECT — Project identifier (default: hermes)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_BASE_URL = "https://api.retaindb.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# retaindb_profile: no-argument fetch of the stable user profile.
PROFILE_SCHEMA = {
    "name": "retaindb_profile",
    "description": "Get the user's stable profile — preferences, facts, and patterns.",
    "parameters": {"type": "object", "properties": {}, "required": []},
}

# retaindb_search: ranked semantic search (top_k capped at 20 in
# handle_tool_call).
SEARCH_SCHEMA = {
    "name": "retaindb_search",
    "description": (
        "Semantic search across stored memories. Returns ranked results "
        "with relevance scores."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "What to search for."},
            "top_k": {"type": "integer", "description": "Max results (default: 8, max: 20)."},
        },
        "required": ["query"],
    },
}

# retaindb_context: task-focused recall (fixed top_k=5 in handle_tool_call).
CONTEXT_SCHEMA = {
    "name": "retaindb_context",
    "description": "Synthesized 'what matters now' context block for the current task.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Current task or question."},
        },
        "required": ["query"],
    },
}

# retaindb_remember: explicit durable write with optional category/weight.
REMEMBER_SCHEMA = {
    "name": "retaindb_remember",
    "description": "Persist an explicit fact or preference to long-term memory.",
    "parameters": {
        "type": "object",
        "properties": {
            "content": {"type": "string", "description": "The fact to remember."},
            "memory_type": {
                "type": "string",
                "enum": ["preference", "fact", "decision", "context"],
                "description": "Category (default: fact).",
            },
            "importance": {
                "type": "number",
                "description": "Importance 0-1 (default: 0.5).",
            },
        },
        "required": ["content"],
    },
}

# retaindb_forget: hard delete of a single memory record.
FORGET_SCHEMA = {
    "name": "retaindb_forget",
    "description": "Delete a specific memory by ID.",
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {"type": "string", "description": "Memory ID to delete."},
        },
        "required": ["memory_id"],
    },
}
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RetainDBMemoryProvider(MemoryProvider):
    """RetainDB cloud memory with write-behind queue and semantic search.

    All network traffic goes through _api(); connection settings are
    resolved from environment variables in initialize().
    """

    def __init__(self):
        self._api_key = ""
        self._base_url = _DEFAULT_BASE_URL
        self._project = "hermes"
        self._user_id = ""
        # Fix: previously _session_id was assigned only in initialize(),
        # so calling sync_turn() before initialize() raised AttributeError.
        self._session_id = ""
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread = None

    @property
    def name(self) -> str:
        return "retaindb"

    def is_available(self) -> bool:
        # Available only when an API key is present in the environment.
        return bool(os.environ.get("RETAINDB_API_KEY"))

    def get_config_schema(self):
        return [
            {"key": "api_key", "description": "RetainDB API key", "secret": True, "required": True, "env_var": "RETAINDB_API_KEY", "url": "https://retaindb.com"},
            {"key": "base_url", "description": "API endpoint", "default": "https://api.retaindb.com"},
            {"key": "project", "description": "Project identifier", "default": "hermes"},
        ]

    def _headers(self) -> dict:
        """Bearer-auth JSON headers for every API call."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

    def _api(self, method: str, path: str, **kwargs):
        """Make an API call to RetainDB.

        Raises:
            requests.HTTPError: on non-2xx responses (via raise_for_status).
        """
        import requests
        url = f"{self._base_url}{path}"
        resp = requests.request(method, url, headers=self._headers(), timeout=30, **kwargs)
        resp.raise_for_status()
        return resp.json()

    def initialize(self, session_id: str, **kwargs) -> None:
        """Resolve connection settings from the environment."""
        self._api_key = os.environ.get("RETAINDB_API_KEY", "")
        self._base_url = os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL)
        self._project = os.environ.get("RETAINDB_PROJECT", "hermes")
        self._user_id = kwargs.get("user_id", "default")
        self._session_id = session_id

    def system_prompt_block(self) -> str:
        """Describe the RetainDB tools for the system prompt."""
        return (
            "# RetainDB Memory\n"
            f"Active. Project: {self._project}.\n"
            "Use retaindb_search to find memories, retaindb_remember to store facts, "
            "retaindb_profile for a user overview, retaindb_context for task-relevant context."
        )

    def prefetch(self, query: str) -> str:
        """Pop the buffered background-recall result (empty if none).

        `query` is unused here — the recall ran in queue_prefetch().
        """
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
        with self._prefetch_lock:
            result = self._prefetch_result
            self._prefetch_result = ""
        if not result:
            return ""
        return f"## RetainDB Memory\n{result}"

    def queue_prefetch(self, query: str) -> None:
        """Start a daemon thread that recalls memories and buffers the result."""
        def _run():
            try:
                data = self._api("POST", "/v1/recall", json={
                    "project": self._project,
                    "query": query,
                    "user_id": self._user_id,
                    "top_k": 5,
                })
                results = data.get("results", [])
                if results:
                    lines = [r.get("content", "") for r in results if r.get("content")]
                    with self._prefetch_lock:
                        self._prefetch_result = "\n".join(f"- {l}" for l in lines)
            except Exception as e:
                # Prefetch is opportunistic — failures are only debug-logged.
                logger.debug("RetainDB prefetch failed: %s", e)

        self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="retaindb-prefetch")
        self._prefetch_thread.start()

    def sync_turn(self, user_content: str, assistant_content: str) -> None:
        """Ship the turn to RetainDB for ingestion (best effort)."""
        try:
            self._api("POST", "/v1/ingest", json={
                "project": self._project,
                "user_id": self._user_id,
                "session_id": self._session_id,
                "messages": [
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": assistant_content},
                ],
            })
        except Exception as e:
            logger.warning("RetainDB sync failed: %s", e)

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Expose the five retaindb_* tools."""
        return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, REMEMBER_SCHEMA, FORGET_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Dispatch a retaindb_* tool call; always returns a JSON string."""
        try:
            if tool_name == "retaindb_profile":
                data = self._api("GET", f"/v1/profile/{self._project}/{self._user_id}")
                return json.dumps(data)

            elif tool_name == "retaindb_search":
                query = args.get("query", "")
                if not query:
                    return json.dumps({"error": "query is required"})
                data = self._api("POST", "/v1/search", json={
                    "project": self._project,
                    "user_id": self._user_id,
                    "query": query,
                    "top_k": min(int(args.get("top_k", 8)), 20),  # cap per SEARCH_SCHEMA
                })
                return json.dumps(data)

            elif tool_name == "retaindb_context":
                query = args.get("query", "")
                if not query:
                    return json.dumps({"error": "query is required"})
                data = self._api("POST", "/v1/recall", json={
                    "project": self._project,
                    "user_id": self._user_id,
                    "query": query,
                    "top_k": 5,
                })
                return json.dumps(data)

            elif tool_name == "retaindb_remember":
                content = args.get("content", "")
                if not content:
                    return json.dumps({"error": "content is required"})
                data = self._api("POST", "/v1/remember", json={
                    "project": self._project,
                    "user_id": self._user_id,
                    "content": content,
                    "memory_type": args.get("memory_type", "fact"),
                    "importance": float(args.get("importance", 0.5)),
                })
                return json.dumps(data)

            elif tool_name == "retaindb_forget":
                memory_id = args.get("memory_id", "")
                if not memory_id:
                    return json.dumps({"error": "memory_id is required"})
                data = self._api("DELETE", f"/v1/memory/{memory_id}")
                return json.dumps(data)

            return json.dumps({"error": f"Unknown tool: {tool_name}"})
        except Exception as e:
            return json.dumps({"error": str(e)})

    def on_memory_write(self, action: str, target: str, content: str) -> None:
        """Bridge built-in memory writes into RetainDB (adds only)."""
        if action == "add":
            try:
                self._api("POST", "/v1/remember", json={
                    "project": self._project,
                    "user_id": self._user_id,
                    "content": content,
                    # "user" target maps to a preference; everything else is a fact.
                    "memory_type": "preference" if target == "user" else "fact",
                })
            except Exception as e:
                logger.debug("RetainDB memory bridge failed: %s", e)

    def shutdown(self) -> None:
        """Give any in-flight prefetch worker a chance to finish."""
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=5.0)
def register(ctx) -> None:
|
||||
"""Register RetainDB as a memory provider plugin."""
|
||||
ctx.register_memory_provider(RetainDBMemoryProvider())
|
||||
7
plugins/retaindb-memory/plugin.yaml
Normal file
7
plugins/retaindb-memory/plugin.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
name: retaindb-memory
|
||||
version: 1.0.0
|
||||
description: >
|
||||
Cross-session memory via RetainDB — durable write-behind queue, semantic
|
||||
search with deduplication, user identity resolution, and profile retrieval.
|
||||
requires_env:
|
||||
- RETAINDB_API_KEY
|
||||
117
run_agent.py
117
run_agent.py
@@ -1043,7 +1043,20 @@ class AIAgent:
|
||||
self._memory_store.load_from_disk()
|
||||
except Exception:
|
||||
pass # Memory is optional -- don't break agent init
|
||||
|
||||
|
||||
# Memory provider manager — orchestrates built-in + plugin providers.
|
||||
# Existing Honcho code stays as-is (will migrate in a follow-up PR).
|
||||
# The manager provides the extension point for plugin memory backends.
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
self._memory_manager = MemoryManager()
|
||||
self._memory_manager.add_provider(BuiltinMemoryProvider(
|
||||
memory_store=self._memory_store,
|
||||
memory_enabled=self._memory_enabled,
|
||||
user_profile_enabled=self._user_profile_enabled,
|
||||
))
|
||||
# Plugin memory providers are added after Honcho init (below).
|
||||
|
||||
# Honcho AI-native memory (cross-session user modeling)
|
||||
# Reads $HERMES_HOME/honcho.json (instance) or ~/.honcho/config.json (global).
|
||||
self._honcho = None # HonchoSessionManager | None
|
||||
@@ -1114,6 +1127,53 @@ class AIAgent:
|
||||
self._user_profile_enabled = False
|
||||
logger.debug("peer %s memory_mode=honcho: local USER.md writes disabled", _hcfg.peer_name or "user")
|
||||
|
||||
# Register plugin memory providers with the manager.
|
||||
# At most ONE external provider is active at a time (configured via
|
||||
# memory.provider in config.yaml). Auto-detected from plugins that
|
||||
# called ctx.register_memory_provider() during discover_and_load().
|
||||
_configured_provider = mem_config.get("provider", "") if not skip_memory else ""
|
||||
if _configured_provider and not skip_memory:
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_memory_providers
|
||||
_found = False
|
||||
for plugin_provider in get_plugin_memory_providers():
|
||||
_pname = getattr(plugin_provider, "name", "unknown")
|
||||
if _pname != _configured_provider:
|
||||
logger.debug(
|
||||
"Memory provider '%s' skipped (config selects '%s')",
|
||||
_pname, _configured_provider,
|
||||
)
|
||||
continue
|
||||
try:
|
||||
if plugin_provider.is_available():
|
||||
self._memory_manager.add_provider(plugin_provider)
|
||||
plugin_provider.initialize(
|
||||
session_id=self.session_id or "",
|
||||
platform=self.platform,
|
||||
model=self.model,
|
||||
)
|
||||
_found = True
|
||||
if not self.quiet_mode:
|
||||
print(f" Memory provider: {_pname}")
|
||||
else:
|
||||
logger.warning(
|
||||
"Memory provider '%s' configured but not available "
|
||||
"(missing credentials or dependencies)", _pname,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Plugin memory provider '%s' init failed: %s", _pname, e,
|
||||
)
|
||||
if not _found and not self.quiet_mode:
|
||||
_available = [getattr(p, "name", "?") for p in get_plugin_memory_providers()]
|
||||
if _available:
|
||||
logger.warning(
|
||||
"memory.provider='%s' not found among registered providers: %s",
|
||||
_configured_provider, _available,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Plugin memory provider loading skipped: %s", e)
|
||||
|
||||
# Skills config: nudge interval for skill creation reminders
|
||||
self._skill_nudge_interval = 10
|
||||
try:
|
||||
@@ -2659,6 +2719,19 @@ class AIAgent:
|
||||
if user_block:
|
||||
prompt_parts.append(user_block)
|
||||
|
||||
# Plugin memory providers contribute system prompt blocks.
|
||||
# (Builtin provider's block is already handled above via _memory_store.)
|
||||
if hasattr(self, "_memory_manager"):
|
||||
for provider in self._memory_manager.providers:
|
||||
if provider.name == "builtin":
|
||||
continue # already handled above
|
||||
try:
|
||||
block = provider.system_prompt_block()
|
||||
if block and block.strip():
|
||||
prompt_parts.append(block)
|
||||
except Exception as e:
|
||||
logger.debug("Memory provider '%s' prompt block failed: %s", provider.name, e)
|
||||
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage'])
|
||||
if has_skills_tools:
|
||||
avail_toolsets = {
|
||||
@@ -6099,6 +6172,23 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.debug("Honcho prefetch failed (non-fatal): %s", e)
|
||||
|
||||
# Plugin memory provider prefetch (non-builtin providers only).
|
||||
self._plugin_memory_context = ""
|
||||
self._plugin_memory_turn_context = ""
|
||||
if hasattr(self, "_memory_manager"):
|
||||
for provider in self._memory_manager.providers:
|
||||
if provider.name == "builtin":
|
||||
continue
|
||||
try:
|
||||
ctx = provider.prefetch(original_user_message)
|
||||
if ctx and ctx.strip():
|
||||
if not conversation_history:
|
||||
self._plugin_memory_context += ("\n\n" + ctx if self._plugin_memory_context else ctx)
|
||||
else:
|
||||
self._plugin_memory_turn_context += ("\n\n" + ctx if self._plugin_memory_turn_context else ctx)
|
||||
except Exception as e:
|
||||
logger.debug("Memory provider '%s' prefetch failed: %s", provider.name, e)
|
||||
|
||||
# Add user message
|
||||
user_msg = {"role": "user", "content": user_message}
|
||||
messages.append(user_msg)
|
||||
@@ -6142,6 +6232,11 @@ class AIAgent:
|
||||
self._cached_system_prompt = (
|
||||
self._cached_system_prompt + "\n\n" + self._honcho_context
|
||||
).strip()
|
||||
# Plugin memory provider context (first turn → bake into system prompt)
|
||||
if self._plugin_memory_context:
|
||||
self._cached_system_prompt = (
|
||||
self._cached_system_prompt + "\n\n" + self._plugin_memory_context
|
||||
).strip()
|
||||
|
||||
# Plugin hook: on_session_start
|
||||
# Fired once when a brand-new session is created (not on
|
||||
@@ -6312,6 +6407,15 @@ class AIAgent:
|
||||
api_msg.get("content", ""), self._honcho_turn_context
|
||||
)
|
||||
|
||||
# Plugin memory provider turn context injection
|
||||
if idx == current_turn_user_idx and msg.get("role") == "user" and self._plugin_memory_turn_context:
|
||||
existing = api_msg.get("content", "")
|
||||
api_msg["content"] = (
|
||||
existing
|
||||
+ "\n\n[Plugin memory context — relevant memories for this turn.]\n\n"
|
||||
+ self._plugin_memory_turn_context
|
||||
)
|
||||
|
||||
# For ALL assistant messages, pass reasoning back to the API
|
||||
# This ensures multi-turn reasoning context is preserved
|
||||
if msg.get("role") == "assistant":
|
||||
@@ -7949,6 +8053,17 @@ class AIAgent:
|
||||
self._honcho_sync(original_user_message, final_response)
|
||||
self._queue_honcho_prefetch(original_user_message)
|
||||
|
||||
# Sync to plugin memory providers and queue prefetch for next turn
|
||||
if final_response and not interrupted and hasattr(self, "_memory_manager"):
|
||||
for provider in self._memory_manager.providers:
|
||||
if provider.name == "builtin":
|
||||
continue
|
||||
try:
|
||||
provider.sync_turn(original_user_message, final_response)
|
||||
provider.queue_prefetch(original_user_message)
|
||||
except Exception as e:
|
||||
logger.debug("Memory provider '%s' post-turn failed: %s", provider.name, e)
|
||||
|
||||
# Plugin hook: post_llm_call
|
||||
# Fired once per turn after the tool-calling loop completes.
|
||||
# Plugins can use this to persist conversation data (e.g. sync
|
||||
|
||||
341
tests/agent/test_memory_plugin_e2e.py
Normal file
341
tests/agent/test_memory_plugin_e2e.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""End-to-end test: a SQLite-backed memory plugin exercising the full interface.
|
||||
|
||||
This proves a real plugin can register as a MemoryProvider and get wired
|
||||
into the agent loop via MemoryManager. Uses SQLite + FTS5 (stdlib, no
|
||||
external deps, no API keys).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLite FTS5 memory provider — a real, minimal plugin implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SQLiteMemoryProvider(MemoryProvider):
|
||||
"""Minimal SQLite + FTS5 memory provider for testing.
|
||||
|
||||
Demonstrates the full MemoryProvider interface with a real backend.
|
||||
No external dependencies — just stdlib sqlite3.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = ":memory:"):
|
||||
self._db_path = db_path
|
||||
self._conn = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sqlite_memory"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True # SQLite is always available
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._conn = sqlite3.connect(self._db_path)
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS memories
|
||||
USING fts5(content, context, session_id)
|
||||
""")
|
||||
self._session_id = session_id
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
if not self._conn:
|
||||
return ""
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
if count == 0:
|
||||
return ""
|
||||
return (
|
||||
f"# SQLite Memory Plugin\n"
|
||||
f"Active. {count} memories stored.\n"
|
||||
f"Use sqlite_recall to search, sqlite_retain to store."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str) -> str:
|
||||
if not self._conn or not query:
|
||||
return ""
|
||||
# FTS5 search
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content FROM memories WHERE memories MATCH ? LIMIT 5",
|
||||
(query,)
|
||||
).fetchall()
|
||||
if not rows:
|
||||
return ""
|
||||
results = [row[0] for row in rows]
|
||||
return "## SQLite Memory\n" + "\n".join(f"- {r}" for r in results)
|
||||
except sqlite3.OperationalError:
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str) -> None:
|
||||
if not self._conn:
|
||||
return
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(combined, "conversation", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return [
|
||||
{
|
||||
"name": "sqlite_retain",
|
||||
"description": "Store a fact to SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "What to remember"},
|
||||
"context": {"type": "string", "description": "Category/context"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "sqlite_recall",
|
||||
"description": "Search SQLite memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
if tool_name == "sqlite_retain":
|
||||
content = args.get("content", "")
|
||||
context = args.get("context", "explicit")
|
||||
if not content:
|
||||
return json.dumps({"error": "content is required"})
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, context, self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return json.dumps({"result": "Stored."})
|
||||
|
||||
elif tool_name == "sqlite_recall":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required"})
|
||||
try:
|
||||
rows = self._conn.execute(
|
||||
"SELECT content, context FROM memories WHERE memories MATCH ? LIMIT 10",
|
||||
(query,)
|
||||
).fetchall()
|
||||
results = [{"content": r[0], "context": r[1]} for r in rows]
|
||||
return json.dumps({"results": results})
|
||||
except sqlite3.OperationalError:
|
||||
return json.dumps({"results": []})
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
def on_memory_write(self, action, target, content):
|
||||
"""Mirror built-in memory writes to SQLite."""
|
||||
if action == "add" and self._conn:
|
||||
self._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
(content, f"builtin_{target}", self._session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def shutdown(self):
|
||||
if self._conn:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSQLiteMemoryPlugin:
|
||||
"""Full lifecycle test with the SQLite provider."""
|
||||
|
||||
def test_full_lifecycle(self):
|
||||
"""Exercise init → store → recall → sync → prefetch → shutdown."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
|
||||
# Initialize
|
||||
mgr.initialize_all(session_id="test-session-1", platform="cli")
|
||||
assert sqlite_mem._conn is not None
|
||||
|
||||
# System prompt — empty at first
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "SQLite Memory Plugin" not in prompt
|
||||
|
||||
# Store via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_retain", {"content": "User prefers dark mode", "context": "preference"}
|
||||
))
|
||||
assert result["result"] == "Stored."
|
||||
|
||||
# System prompt now shows count
|
||||
prompt = mgr.build_system_prompt()
|
||||
assert "1 memories stored" in prompt
|
||||
|
||||
# Recall via tool call
|
||||
result = json.loads(mgr.handle_tool_call(
|
||||
"sqlite_recall", {"query": "dark mode"}
|
||||
))
|
||||
assert len(result["results"]) == 1
|
||||
assert "dark mode" in result["results"][0]["content"]
|
||||
|
||||
# Sync a turn (auto-stores conversation)
|
||||
mgr.sync_all("What's my theme?", "You prefer dark mode.")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 2 # 1 explicit + 1 synced
|
||||
|
||||
# Prefetch for next turn
|
||||
prefetched = mgr.prefetch_all("dark mode")
|
||||
assert "dark mode" in prefetched
|
||||
|
||||
# Memory bridge — mirroring builtin writes
|
||||
mgr.on_memory_write("add", "user", "Timezone: US Pacific")
|
||||
count = sqlite_mem._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 3
|
||||
|
||||
# Shutdown
|
||||
mgr.shutdown_all()
|
||||
assert sqlite_mem._conn is None
|
||||
|
||||
def test_tool_routing_with_builtin(self):
|
||||
"""Verify builtin + plugin tools coexist without conflict."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
sqlite_mem = SQLiteMemoryProvider()
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(sqlite_mem)
|
||||
mgr.initialize_all(session_id="test-2")
|
||||
|
||||
# Builtin has no tools
|
||||
assert len(builtin.get_tool_schemas()) == 0
|
||||
# SQLite has 2 tools
|
||||
schemas = mgr.get_all_tool_schemas()
|
||||
names = {s["name"] for s in schemas}
|
||||
assert names == {"sqlite_retain", "sqlite_recall"}
|
||||
|
||||
# Routing works
|
||||
assert mgr.has_tool("sqlite_retain")
|
||||
assert mgr.has_tool("sqlite_recall")
|
||||
assert not mgr.has_tool("memory") # builtin doesn't register this
|
||||
|
||||
def test_multiple_plugins_coexist(self):
|
||||
"""Two plugin providers can run simultaneously."""
|
||||
mgr = MemoryManager()
|
||||
p1 = SQLiteMemoryProvider()
|
||||
p2 = SQLiteMemoryProvider()
|
||||
# Hack name for p2
|
||||
p2._name_override = "sqlite_memory_2"
|
||||
original_name = p2.__class__.name
|
||||
type(p2).name = property(lambda self: getattr(self, '_name_override', 'sqlite_memory'))
|
||||
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
mgr.initialize_all(session_id="test-3")
|
||||
|
||||
# Store in p1
|
||||
p1._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
("fact from p1", "test", "test-3"),
|
||||
)
|
||||
p1._conn.commit()
|
||||
|
||||
# Store in p2
|
||||
p2._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
("fact from p2", "test", "test-3"),
|
||||
)
|
||||
p2._conn.commit()
|
||||
|
||||
# Prefetch merges both
|
||||
result = mgr.prefetch_all("fact")
|
||||
assert "fact from p1" in result
|
||||
assert "fact from p2" in result
|
||||
|
||||
# Sync goes to both
|
||||
mgr.sync_all("user msg", "assistant msg")
|
||||
count1 = p1._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
count2 = p2._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count1 == 2 # 1 explicit + 1 synced
|
||||
assert count2 == 2
|
||||
|
||||
# Restore class
|
||||
type(p2).name = original_name
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_provider_failure_isolation(self):
|
||||
"""Failing provider doesn't break others."""
|
||||
mgr = MemoryManager()
|
||||
good = SQLiteMemoryProvider()
|
||||
bad = SQLiteMemoryProvider()
|
||||
|
||||
mgr.add_provider(good)
|
||||
mgr.add_provider(bad)
|
||||
mgr.initialize_all(session_id="test-4")
|
||||
|
||||
# Break bad provider's connection
|
||||
bad._conn.close()
|
||||
bad._conn = None
|
||||
|
||||
# Good provider still works
|
||||
good._conn.execute(
|
||||
"INSERT INTO memories (content, context, session_id) VALUES (?, ?, ?)",
|
||||
("still works", "test", "test-4"),
|
||||
)
|
||||
good._conn.commit()
|
||||
|
||||
# Sync — bad fails silently, good succeeds
|
||||
mgr.sync_all("user", "assistant")
|
||||
count = good._conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
|
||||
assert count == 2
|
||||
|
||||
mgr.shutdown_all()
|
||||
|
||||
def test_plugin_registration_flow(self):
|
||||
"""Simulate the full plugin registration → agent init path."""
|
||||
from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest
|
||||
|
||||
# Simulate plugin discovery
|
||||
manager = PluginManager()
|
||||
manifest = PluginManifest(name="sqlite-memory", source="test")
|
||||
ctx = PluginContext(manifest, manager)
|
||||
|
||||
# Plugin registers its provider
|
||||
provider = SQLiteMemoryProvider()
|
||||
ctx.register_memory_provider(provider)
|
||||
|
||||
assert len(manager._memory_providers) == 1
|
||||
|
||||
# Simulate what AIAgent.__init__ does
|
||||
mem_mgr = MemoryManager()
|
||||
mem_mgr.add_provider(BuiltinMemoryProvider())
|
||||
for plugin_provider in manager._memory_providers:
|
||||
if plugin_provider.is_available():
|
||||
mem_mgr.add_provider(plugin_provider)
|
||||
plugin_provider.initialize(session_id="agent-session")
|
||||
|
||||
assert len(mem_mgr.providers) == 2
|
||||
assert mem_mgr.provider_names == ["builtin", "sqlite_memory"]
|
||||
assert provider._conn is not None # initialized = connection established
|
||||
|
||||
mem_mgr.shutdown_all()
|
||||
550
tests/agent/test_memory_provider.py
Normal file
550
tests/agent/test_memory_provider.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Tests for the memory provider interface, manager, and builtin provider."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete test provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeMemoryProvider(MemoryProvider):
|
||||
"""Minimal concrete provider for testing."""
|
||||
|
||||
def __init__(self, name="fake", available=True, tools=None):
|
||||
self._name = name
|
||||
self._available = available
|
||||
self._tools = tools or []
|
||||
self.initialized = False
|
||||
self.synced_turns = []
|
||||
self.prefetch_queries = []
|
||||
self.queued_prefetches = []
|
||||
self.turn_starts = []
|
||||
self.session_end_called = False
|
||||
self.pre_compress_called = False
|
||||
self.memory_writes = []
|
||||
self.shutdown_called = False
|
||||
self._prefetch_result = ""
|
||||
self._prompt_block = ""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._available
|
||||
|
||||
def initialize(self, session_id, **kwargs):
|
||||
self.initialized = True
|
||||
self._init_kwargs = {"session_id": session_id, **kwargs}
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
return self._prompt_block
|
||||
|
||||
def prefetch(self, query):
|
||||
self.prefetch_queries.append(query)
|
||||
return self._prefetch_result
|
||||
|
||||
def queue_prefetch(self, query):
|
||||
self.queued_prefetches.append(query)
|
||||
|
||||
def sync_turn(self, user_content, assistant_content):
|
||||
self.synced_turns.append((user_content, assistant_content))
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return self._tools
|
||||
|
||||
def handle_tool_call(self, tool_name, args, **kwargs):
|
||||
return json.dumps({"handled": tool_name, "args": args})
|
||||
|
||||
def shutdown(self):
|
||||
self.shutdown_called = True
|
||||
|
||||
def on_turn_start(self, turn_number, message):
|
||||
self.turn_starts.append((turn_number, message))
|
||||
|
||||
def on_session_end(self, messages):
|
||||
self.session_end_called = True
|
||||
|
||||
def on_pre_compress(self, messages):
|
||||
self.pre_compress_called = True
|
||||
|
||||
def on_memory_write(self, action, target, content):
|
||||
self.memory_writes.append((action, target, content))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider ABC tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryProviderABC:
|
||||
def test_cannot_instantiate_abstract(self):
|
||||
"""ABC cannot be instantiated directly."""
|
||||
with pytest.raises(TypeError):
|
||||
MemoryProvider()
|
||||
|
||||
def test_concrete_provider_works(self):
|
||||
"""Concrete implementation can be instantiated."""
|
||||
p = FakeMemoryProvider()
|
||||
assert p.name == "fake"
|
||||
assert p.is_available()
|
||||
|
||||
def test_default_optional_hooks_are_noop(self):
|
||||
"""Optional hooks have default no-op implementations."""
|
||||
p = FakeMemoryProvider()
|
||||
# These should not raise
|
||||
p.on_turn_start(1, "hello")
|
||||
p.on_session_end([])
|
||||
p.on_pre_compress([])
|
||||
p.on_memory_write("add", "memory", "test")
|
||||
p.queue_prefetch("query")
|
||||
p.sync_turn("user", "assistant")
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryManager:
|
||||
def test_empty_manager(self):
|
||||
mgr = MemoryManager()
|
||||
assert mgr.providers == []
|
||||
assert mgr.provider_names == []
|
||||
assert mgr.get_all_tool_schemas() == []
|
||||
assert mgr.build_system_prompt() == ""
|
||||
assert mgr.prefetch_all("test") == ""
|
||||
|
||||
def test_add_provider(self):
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("test1")
|
||||
mgr.add_provider(p)
|
||||
assert len(mgr.providers) == 1
|
||||
assert mgr.provider_names == ["test1"]
|
||||
|
||||
def test_get_provider_by_name(self):
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("test1")
|
||||
mgr.add_provider(p)
|
||||
assert mgr.get_provider("test1") is p
|
||||
assert mgr.get_provider("nonexistent") is None
|
||||
|
||||
def test_multiple_providers(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
assert mgr.provider_names == ["p1", "p2"]
|
||||
|
||||
def test_system_prompt_merges_blocks(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p1._prompt_block = "Block from p1"
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
p2._prompt_block = "Block from p2"
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
result = mgr.build_system_prompt()
|
||||
assert "Block from p1" in result
|
||||
assert "Block from p2" in result
|
||||
|
||||
def test_system_prompt_skips_empty(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p1._prompt_block = "Has content"
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
p2._prompt_block = ""
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
result = mgr.build_system_prompt()
|
||||
assert result == "Has content"
|
||||
|
||||
def test_prefetch_merges_results(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p1._prefetch_result = "Memory from p1"
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
p2._prefetch_result = "Memory from p2"
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
result = mgr.prefetch_all("what do you know?")
|
||||
assert "Memory from p1" in result
|
||||
assert "Memory from p2" in result
|
||||
assert p1.prefetch_queries == ["what do you know?"]
|
||||
assert p2.prefetch_queries == ["what do you know?"]
|
||||
|
||||
def test_prefetch_skips_empty(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p1._prefetch_result = "Has memories"
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
p2._prefetch_result = ""
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
result = mgr.prefetch_all("query")
|
||||
assert result == "Has memories"
|
||||
|
||||
def test_queue_prefetch_all(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
mgr.queue_prefetch_all("next turn")
|
||||
assert p1.queued_prefetches == ["next turn"]
|
||||
assert p2.queued_prefetches == ["next turn"]
|
||||
|
||||
def test_sync_all(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
mgr.sync_all("user msg", "assistant msg")
|
||||
assert p1.synced_turns == [("user msg", "assistant msg")]
|
||||
assert p2.synced_turns == [("user msg", "assistant msg")]
|
||||
|
||||
def test_sync_failure_doesnt_block_others(self):
|
||||
"""If one provider's sync fails, others still run."""
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1")
|
||||
p1.sync_turn = MagicMock(side_effect=RuntimeError("boom"))
|
||||
p2 = FakeMemoryProvider("p2")
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
mgr.sync_all("user", "assistant")
|
||||
# p1 failed but p2 still synced
|
||||
assert p2.synced_turns == [("user", "assistant")]
|
||||
|
||||
# -- Tool routing -------------------------------------------------------
|
||||
|
||||
def test_tool_schemas_collected(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1", tools=[
|
||||
{"name": "recall_p1", "description": "P1 recall", "parameters": {}}
|
||||
])
|
||||
p2 = FakeMemoryProvider("p2", tools=[
|
||||
{"name": "recall_p2", "description": "P2 recall", "parameters": {}}
|
||||
])
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
schemas = mgr.get_all_tool_schemas()
|
||||
names = {s["name"] for s in schemas}
|
||||
assert names == {"recall_p1", "recall_p2"}
|
||||
|
||||
def test_tool_name_conflict_first_wins(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1", tools=[
|
||||
{"name": "shared_tool", "description": "From P1", "parameters": {}}
|
||||
])
|
||||
p2 = FakeMemoryProvider("p2", tools=[
|
||||
{"name": "shared_tool", "description": "From P2", "parameters": {}}
|
||||
])
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
assert mgr.has_tool("shared_tool")
|
||||
result = json.loads(mgr.handle_tool_call("shared_tool", {"q": "test"}))
|
||||
assert result["handled"] == "shared_tool"
|
||||
# Should be handled by p1 (first registered)
|
||||
|
||||
def test_handle_unknown_tool(self):
|
||||
mgr = MemoryManager()
|
||||
result = json.loads(mgr.handle_tool_call("nonexistent", {}))
|
||||
assert "error" in result
|
||||
|
||||
def test_tool_routing(self):
|
||||
mgr = MemoryManager()
|
||||
p1 = FakeMemoryProvider("p1", tools=[
|
||||
{"name": "p1_tool", "description": "P1", "parameters": {}}
|
||||
])
|
||||
p2 = FakeMemoryProvider("p2", tools=[
|
||||
{"name": "p2_tool", "description": "P2", "parameters": {}}
|
||||
])
|
||||
mgr.add_provider(p1)
|
||||
mgr.add_provider(p2)
|
||||
|
||||
r1 = json.loads(mgr.handle_tool_call("p1_tool", {"a": 1}))
|
||||
assert r1["handled"] == "p1_tool"
|
||||
r2 = json.loads(mgr.handle_tool_call("p2_tool", {"b": 2}))
|
||||
assert r2["handled"] == "p2_tool"
|
||||
|
||||
# -- Lifecycle hooks -----------------------------------------------------
|
||||
|
||||
def test_on_turn_start(self):
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("p")
|
||||
mgr.add_provider(p)
|
||||
mgr.on_turn_start(3, "hello")
|
||||
assert p.turn_starts == [(3, "hello")]
|
||||
|
||||
def test_on_session_end(self):
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("p")
|
||||
mgr.add_provider(p)
|
||||
mgr.on_session_end([{"role": "user", "content": "hi"}])
|
||||
assert p.session_end_called
|
||||
|
||||
def test_on_pre_compress(self):
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("p")
|
||||
mgr.add_provider(p)
|
||||
mgr.on_pre_compress([{"role": "user", "content": "old"}])
|
||||
assert p.pre_compress_called
|
||||
|
||||
def test_on_memory_write_skips_builtin(self):
|
||||
"""on_memory_write should skip the builtin provider."""
|
||||
mgr = MemoryManager()
|
||||
builtin = BuiltinMemoryProvider()
|
||||
external = FakeMemoryProvider("external")
|
||||
mgr.add_provider(builtin)
|
||||
mgr.add_provider(external)
|
||||
|
||||
mgr.on_memory_write("add", "memory", "test fact")
|
||||
assert external.memory_writes == [("add", "memory", "test fact")]
|
||||
|
||||
def test_shutdown_all_reverse_order(self):
    """shutdown_all tears providers down in reverse registration order."""
    manager = MemoryManager()
    calls = []
    first = FakeMemoryProvider("p1")
    first.shutdown = lambda: calls.append("p1")
    second = FakeMemoryProvider("p2")
    second.shutdown = lambda: calls.append("p2")
    manager.add_provider(first)
    manager.add_provider(second)

    manager.shutdown_all()
    assert calls == ["p2", "p1"]  # reverse order
|
||||
|
||||
def test_initialize_all(self):
    """initialize_all initializes every provider with the same kwargs."""
    manager = MemoryManager()
    providers = [FakeMemoryProvider("p1"), FakeMemoryProvider("p2")]
    for prov in providers:
        manager.add_provider(prov)

    manager.initialize_all(session_id="test-123", platform="cli")
    assert all(prov.initialized for prov in providers)
    assert providers[0]._init_kwargs["session_id"] == "test-123"
    assert providers[0]._init_kwargs["platform"] == "cli"
|
||||
|
||||
# -- Error resilience ---------------------------------------------------
|
||||
|
||||
def test_prefetch_failure_doesnt_block(self):
    """One provider's prefetch failure must not hide another's results."""
    manager = MemoryManager()
    broken = FakeMemoryProvider("p1")
    broken.prefetch = MagicMock(side_effect=RuntimeError("network error"))
    healthy = FakeMemoryProvider("p2")
    healthy._prefetch_result = "p2 memory"
    manager.add_provider(broken)
    manager.add_provider(healthy)

    combined = manager.prefetch_all("query")
    assert "p2 memory" in combined
|
||||
|
||||
def test_system_prompt_failure_doesnt_block(self):
    """A provider raising in system_prompt_block must not hide others."""
    manager = MemoryManager()
    broken = FakeMemoryProvider("p1")
    broken.system_prompt_block = MagicMock(side_effect=RuntimeError("broken"))
    healthy = FakeMemoryProvider("p2")
    healthy._prompt_block = "works fine"
    manager.add_provider(broken)
    manager.add_provider(healthy)

    assert manager.build_system_prompt() == "works fine"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BuiltinMemoryProvider tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinMemoryProvider:
    """Contract of the always-on builtin (MEMORY.md/USER.md) provider."""

    def test_name(self):
        p = BuiltinMemoryProvider()
        assert p.name == "builtin"

    def test_always_available(self):
        p = BuiltinMemoryProvider()
        assert p.is_available()

    def test_no_tools(self):
        """Builtin provider exposes no tools (memory tool is agent-level)."""
        p = BuiltinMemoryProvider()
        assert p.get_tool_schemas() == []

    def test_system_prompt_with_store(self):
        store = MagicMock()
        # One labelled block per memory type. (The original conditional here
        # had identical branches for "memory" and everything else — a single
        # format expression is the same behavior.)
        store.format_for_system_prompt.side_effect = lambda t: f"BLOCK_{t}"

        p = BuiltinMemoryProvider(
            memory_store=store,
            memory_enabled=True,
            user_profile_enabled=True,
        )
        block = p.system_prompt_block()
        assert "BLOCK_memory" in block
        assert "BLOCK_user" in block

    def test_system_prompt_memory_disabled(self):
        store = MagicMock()
        store.format_for_system_prompt.return_value = "content"

        p = BuiltinMemoryProvider(
            memory_store=store,
            memory_enabled=False,
            user_profile_enabled=False,
        )
        # Both feature flags off → no prompt block at all.
        assert p.system_prompt_block() == ""

    def test_system_prompt_no_store(self):
        # Enabled but no backing store → still empty.
        p = BuiltinMemoryProvider(memory_store=None, memory_enabled=True)
        assert p.system_prompt_block() == ""

    def test_prefetch_returns_empty(self):
        p = BuiltinMemoryProvider()
        assert p.prefetch("anything") == ""

    def test_store_property(self):
        store = MagicMock()
        p = BuiltinMemoryProvider(memory_store=store)
        assert p.store is store

    def test_initialize_loads_from_disk(self):
        store = MagicMock()
        p = BuiltinMemoryProvider(memory_store=store)
        p.initialize(session_id="test")
        store.load_from_disk.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin registration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleProviderGating:
    """Only the configured provider should activate."""

    @staticmethod
    def _gate(manager, configured, candidates):
        # Mirrors run_agent.py's gating: add only the configured,
        # available plugin provider; an empty name adds nothing.
        if configured:
            for candidate in candidates:
                if candidate.name == configured and candidate.is_available():
                    manager.add_provider(candidate)

    def test_no_provider_configured_means_builtin_only(self):
        """When memory.provider is empty, no plugin providers activate."""
        manager = MemoryManager()
        manager.add_provider(BuiltinMemoryProvider())

        self._gate(manager, "", [
            FakeMemoryProvider("holographic"),
            FakeMemoryProvider("mem0"),
        ])
        assert manager.provider_names == ["builtin"]

    def test_configured_provider_activates(self):
        """Only the named provider should be added."""
        manager = MemoryManager()
        manager.add_provider(BuiltinMemoryProvider())

        holo = FakeMemoryProvider("holographic")
        self._gate(manager, "holographic", [
            holo,
            FakeMemoryProvider("mem0"),
            FakeMemoryProvider("hindsight"),
        ])
        assert manager.provider_names == ["builtin", "holographic"]
        assert holo.initialized is False  # not initialized by the gating logic itself

    def test_unavailable_provider_skipped(self):
        """If the configured provider is unavailable, it should be skipped."""
        manager = MemoryManager()
        manager.add_provider(BuiltinMemoryProvider())

        self._gate(manager, "holographic",
                   [FakeMemoryProvider("holographic", available=False)])
        assert manager.provider_names == ["builtin"]

    def test_nonexistent_provider_results_in_builtin_only(self):
        """If the configured name doesn't match any plugin, only builtin remains."""
        manager = MemoryManager()
        manager.add_provider(BuiltinMemoryProvider())

        self._gate(manager, "nonexistent", [
            FakeMemoryProvider("holographic"),
            FakeMemoryProvider("mem0"),
        ])
        assert manager.provider_names == ["builtin"]
|
||||
|
||||
|
||||
class TestPluginMemoryProviderRegistration:
    def test_register_memory_provider(self):
        """PluginContext.register_memory_provider adds to manager list."""
        from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest

        plugin_mgr = PluginManager()
        ctx = PluginContext(
            PluginManifest(
                name="test-plugin",
                version="1.0.0",
                description="Test",
                source="test",
            ),
            plugin_mgr,
        )

        provider = FakeMemoryProvider("test-mem")
        ctx.register_memory_provider(provider)

        assert len(plugin_mgr._memory_providers) == 1
        assert plugin_mgr._memory_providers[0] is provider

    def test_get_plugin_memory_providers(self):
        """get_plugin_memory_providers returns registered providers."""
        from hermes_cli.plugins import PluginManager, get_plugin_memory_providers

        with patch("hermes_cli.plugins.get_plugin_manager") as mock_get:
            plugin_mgr = PluginManager()
            first = FakeMemoryProvider("p1")
            second = FakeMemoryProvider("p2")
            plugin_mgr._memory_providers = [first, second]
            mock_get.return_value = plugin_mgr

            providers = get_plugin_memory_providers()
            assert len(providers) == 2
            assert providers[0] is first
            assert providers[1] is second
|
||||
0
tests/plugins/__init__.py
Normal file
0
tests/plugins/__init__.py
Normal file
248
tests/plugins/test_holographic.py
Normal file
248
tests/plugins/test_holographic.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Tests for holographic.py — pure HRR math operations.
|
||||
|
||||
All tests are synthetic: no filesystem, no database, no external state.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# Plugin path: prefer home dir install, fall back to in-repo copy
_plugin_dir = Path.home() / ".hermes" / "plugins" / "hermes-memory-store"
if not _plugin_dir.exists():
    # Fall back to <repo-root>/plugins/hermes-memory-store (three parents up
    # from tests/plugins/test_holographic.py).
    _plugin_dir = Path(__file__).resolve().parent.parent.parent / "plugins" / "hermes-memory-store"
# Prepend so the `from holographic import ...` below resolves to the plugin.
# NOTE(review): when ~/.hermes has an install, tests exercise that copy rather
# than the repo copy — confirm that is the intended precedence.
sys.path.insert(0, str(_plugin_dir))
|
||||
|
||||
from holographic import (
|
||||
_HAS_NUMPY,
|
||||
bind,
|
||||
bundle,
|
||||
bytes_to_phases,
|
||||
encode_atom,
|
||||
encode_fact,
|
||||
encode_text,
|
||||
phases_to_bytes,
|
||||
similarity,
|
||||
snr_estimate,
|
||||
unbind,
|
||||
)
|
||||
|
||||
|
||||
DIM = 256 # Smaller dim for fast tests; math properties hold at any dim.
|
||||
|
||||
|
||||
class TestEncodeAtom:
    def test_deterministic(self):
        """Same input always produces the identical vector."""
        np.testing.assert_array_equal(
            encode_atom("hello", DIM),
            encode_atom("hello", DIM),
        )

    def test_shape_and_dtype(self):
        vec = encode_atom("test", DIM)
        assert vec.shape == (DIM,)
        assert vec.dtype == np.float64

    def test_phase_range(self):
        """All phases must be in [0, 2π)."""
        vec = encode_atom("range_check", DIM)
        assert np.all(vec >= 0.0)
        assert np.all(vec < 2.0 * np.pi)

    def test_near_orthogonal(self):
        """Random unrelated words should have near-zero similarity."""
        words = ["apple", "quantum", "bicycle", "telescope", "jazz"]
        encoded = {word: encode_atom(word, DIM) for word in words}
        for idx, left in enumerate(words):
            for right in words[idx + 1:]:
                sim = similarity(encoded[left], encoded[right])
                assert abs(sim) < 0.15, f"'{left}' vs '{right}': sim={sim:.4f}"
|
||||
|
||||
|
||||
class TestBindUnbind:
    def test_roundtrip(self):
        """unbind(bind(a, b), b) should recover a exactly."""
        lhs = encode_atom("concept_a", DIM)
        rhs = encode_atom("concept_b", DIM)
        np.testing.assert_allclose(unbind(bind(lhs, rhs), rhs), lhs, atol=1e-10)

    def test_commutative(self):
        """bind(a, b) == bind(b, a) — phase addition is commutative."""
        lhs = encode_atom("alpha", DIM)
        rhs = encode_atom("beta", DIM)
        np.testing.assert_allclose(bind(lhs, rhs), bind(rhs, lhs), atol=1e-10)

    def test_bound_dissimilar_to_inputs(self):
        """The bound vector should be quasi-orthogonal to both inputs."""
        lhs = encode_atom("dog", DIM)
        rhs = encode_atom("cat", DIM)
        combined = bind(lhs, rhs)
        for component in (lhs, rhs):
            assert abs(similarity(combined, component)) < 0.15
|
||||
|
||||
|
||||
class TestBundle:
    def test_preserves_similarity(self):
        """Bundled vector should be similar to each of its components."""
        components = [encode_atom(f"item_{i}", DIM) for i in range(3)]
        combined = bundle(*components)
        for component in components:
            score = similarity(combined, component)
            assert score > 0.2, f"Bundle lost signal: sim={score:.4f}"

    def test_capacity_degrades(self):
        """Similarity to each component should decrease as more items are added."""
        target = encode_atom("target", DIM)
        scores = []
        for count in (2, 5, 10, 20):
            noise = [encode_atom(f"noise_{i}", DIM) for i in range(count - 1)]
            scores.append(similarity(bundle(target, *noise), target))
        # Only require end-to-end degradation (allow minor non-monotonicity).
        assert scores[0] > scores[-1], f"No degradation: {scores}"
|
||||
|
||||
|
||||
class TestSimilarity:
    def test_identity(self):
        """similarity(a, a) should be exactly 1.0."""
        vec = encode_atom("self", DIM)
        assert similarity(vec, vec) == pytest.approx(1.0)

    def test_orthogonal_near_zero(self):
        """Random vectors should have similarity near 0."""
        scores = [
            similarity(encode_atom(f"rand_a_{i}", DIM),
                       encode_atom(f"rand_b_{i}", DIM))
            for i in range(10)
        ]
        mean_score = np.mean(scores)
        assert abs(mean_score) < 0.1, f"Mean similarity too high: {mean_score:.4f}"
|
||||
|
||||
|
||||
class TestEncodeText:
    def test_order_invariant(self):
        """Bag-of-words should be order-invariant."""
        forward = encode_text("the quick brown fox", DIM)
        shuffled = encode_text("fox brown quick the", DIM)
        assert similarity(forward, shuffled) == pytest.approx(1.0, abs=1e-10)

    def test_similar_texts_high_similarity(self):
        """Texts sharing words should have high similarity."""
        score = similarity(
            encode_text("the cat sat on the mat", DIM),
            encode_text("the cat on the mat", DIM),
        )
        assert score > 0.5, f"Similar texts low sim: {score:.4f}"

    def test_empty_text(self):
        """Empty text should return a valid vector (the __hrr_empty__ atom)."""
        assert encode_text("", DIM).shape == (DIM,)
|
||||
|
||||
|
||||
class TestEncodeFact:
    """Role-filler encoding: a fact binds content and entities to role atoms."""

    def test_entity_extraction(self):
        """Unbinding entity from fact should recover content signal."""
        content = "prefers rust for systems programming"
        entities = ["peppi"]

        fact_vec = encode_fact(content, entities, DIM)
        content_vec = encode_text(content, DIM)

        # Unbind: fact - bind(entity, ROLE_ENTITY) should be similar to bind(content, ROLE_CONTENT)
        # NOTE(review): these role-atom names must match the ones used inside
        # encode_fact — confirm against the holographic module if this breaks.
        role_entity = encode_atom("__hrr_role_entity__", DIM)
        role_content = encode_atom("__hrr_role_content__", DIM)
        entity_vec = encode_atom("peppi", DIM)

        # Extract what's associated with peppi's entity role
        probe = unbind(fact_vec, bind(entity_vec, role_entity))

        # The extracted signal should have nonzero similarity to the content-role binding
        content_bound = bind(content_vec, role_content)
        sim = similarity(probe, content_bound)
        # At DIM=256, 2-component bundle: SNR≈11, but phase cosine similarity compresses
        # the signal. Noise baseline is ~0.035 std; signal should be above 0.03.
        assert sim > 0.03, f"Entity extraction failed: sim={sim:.4f}"

    def test_multiple_entities(self):
        """Facts with multiple entities should encode all of them."""
        fact_vec = encode_fact("loves pizza", ["alice", "bob"], DIM)
        assert fact_vec.shape == (DIM,)
        # Both entities should be recoverable (above noise floor)
        role_entity = encode_atom("__hrr_role_entity__", DIM)
        for name in ["alice", "bob"]:
            entity_vec = encode_atom(name, DIM)
            probe = unbind(fact_vec, bind(entity_vec, role_entity))
            # Just verify it's a valid vector (deeper tests would check signal)
            assert probe.shape == (DIM,)
|
||||
|
||||
|
||||
class TestSerialization:
    def test_roundtrip(self):
        """bytes_to_phases(phases_to_bytes(v)) should recover v exactly."""
        original = encode_atom("serialize_me", DIM)
        restored = bytes_to_phases(phases_to_bytes(original))
        np.testing.assert_array_equal(original, restored)

    def test_byte_size(self):
        """float64 * dim = 8 * dim bytes."""
        payload = phases_to_bytes(encode_atom("size_check", DIM))
        assert len(payload) == DIM * 8
|
||||
|
||||
|
||||
class TestSNREstimate:
    def test_formula(self):
        """SNR should match sqrt(dim / n_items)."""
        import math
        for dim, items in ((1024, 4), (1024, 256)):
            assert snr_estimate(dim, items) == pytest.approx(math.sqrt(dim / items))

    def test_empty(self):
        """Zero items → infinite SNR."""
        assert snr_estimate(1024, 0) == float("inf")

    def test_warning_logged(self, caplog):
        """SNR < 2.0 should emit a warning."""
        import logging
        with caplog.at_level(logging.WARNING):
            snr_estimate(4, 4)  # SNR = 1.0
        assert "near capacity" in caplog.text.lower()
|
||||
|
||||
|
||||
class TestNumpyGuard:
    def test_raises_without_numpy(self):
        """All public functions should raise RuntimeError when numpy is absent."""
        import holographic

        zeros = np.zeros(DIM)
        # One thunk per public entry point, called under the disabled flag.
        guarded_calls = [
            lambda: encode_atom("test", DIM),
            lambda: bind(zeros, zeros),
            lambda: unbind(zeros, zeros),
            lambda: bundle(zeros),
            lambda: similarity(zeros, zeros),
            lambda: encode_text("test", DIM),
            lambda: encode_fact("test", ["e"], DIM),
            lambda: phases_to_bytes(zeros),
            lambda: bytes_to_phases(b"\x00" * DIM * 8),
            lambda: snr_estimate(DIM, 1),
        ]

        original = holographic._HAS_NUMPY
        try:
            holographic._HAS_NUMPY = False
            for call in guarded_calls:
                with pytest.raises(RuntimeError, match="numpy is required"):
                    call()
        finally:
            holographic._HAS_NUMPY = original
|
||||
336
tests/plugins/test_holographic_provider.py
Normal file
336
tests/plugins/test_holographic_provider.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Tests for the holographic memory MemoryProvider adapter.
|
||||
|
||||
Tests the HolographicMemoryProvider interface — registration, tool handling,
|
||||
prefetch, session end hooks, and memory bridging.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Add plugin dir to path so imports work
_plugin_dir = Path(__file__).resolve().parent.parent.parent / "plugins" / "hermes-memory-store"
# NOTE(review): the entry is prepended on every import of this test module and
# never removed; harmless in an isolated pytest process, but it shadows any
# installed module of the same name.
sys.path.insert(0, str(_plugin_dir))
|
||||
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.builtin_memory_provider import BuiltinMemoryProvider
|
||||
|
||||
|
||||
def _make_provider(tmp_path, config=None):
    """Create a HolographicMemoryProvider with a temp DB.

    NOTE(review): this helper appears unused — the `provider` fixture below
    constructs the provider via importlib instead. The import strategy here
    is fragile: `import_module("plugins.hermes-memory-store")` relies on a
    hyphenated directory name resolving through the path finder — confirm it
    actually imports, or delete this helper in favor of the fixture.
    """
    # Import inside function to avoid module-level issues
    sys.path.insert(0, str(_plugin_dir))
    from plugins import HolographicMemoryProvider  # noqa: F811
    # Use the full import path
    from importlib import import_module
    init_mod = import_module("plugins.hermes-memory-store")

    cfg = config or {}
    # NOTE(review): setdefault mutates the caller's config dict when one is
    # passed in — confirm that side effect is acceptable.
    cfg.setdefault("db_path", str(tmp_path / "test.db"))
    provider = init_mod.HolographicMemoryProvider(config=cfg)
    provider.initialize(session_id="test-session")
    return provider
|
||||
|
||||
|
||||
@pytest.fixture
def provider(tmp_path):
    """Create an initialized holographic provider.

    The plugin package is loaded under a throwaway module name
    ("hermes_memory_store_test") so repeated runs cannot collide with a real
    install. The `store` and `retrieval` submodules are executed and placed
    in sys.modules *before* the package __init__ runs, because __init__
    imports them at load time — the ordering below is load-bearing.
    """
    sys.path.insert(0, str(_plugin_dir.parent))
    # Direct import
    spec_path = _plugin_dir / "__init__.py"
    import importlib.util
    spec = importlib.util.spec_from_file_location(
        "hermes_memory_store_test",
        spec_path,
        submodule_search_locations=[str(_plugin_dir)],
    )
    mod = importlib.util.module_from_spec(spec)
    sys.modules["hermes_memory_store_test"] = mod
    # Pre-populate submodule references
    store_spec = importlib.util.spec_from_file_location(
        "hermes_memory_store_test.store",
        _plugin_dir / "store.py",
    )
    store_mod = importlib.util.module_from_spec(store_spec)
    sys.modules["hermes_memory_store_test.store"] = store_mod
    store_spec.loader.exec_module(store_mod)

    retrieval_spec = importlib.util.spec_from_file_location(
        "hermes_memory_store_test.retrieval",
        _plugin_dir / "retrieval.py",
    )
    retrieval_mod = importlib.util.module_from_spec(retrieval_spec)
    sys.modules["hermes_memory_store_test.retrieval"] = retrieval_mod
    retrieval_spec.loader.exec_module(retrieval_mod)

    spec.loader.exec_module(mod)

    cfg = {"db_path": str(tmp_path / "test.db")}
    p = mod.HolographicMemoryProvider(config=cfg)
    p.initialize(session_id="test-session")
    yield p
    # Teardown runs after the test regardless of its outcome.
    # NOTE(review): if shutdown() raises, the sys.modules cleanup below is
    # skipped and leaks into later tests — consider try/finally.
    p.shutdown()

    # Cleanup
    for key in list(sys.modules):
        if key.startswith("hermes_memory_store_test"):
            del sys.modules[key]
|
||||
|
||||
|
||||
class TestProviderRegistration:
    def test_register_calls_register_memory_provider(self, tmp_path):
        """register(ctx) should call ctx.register_memory_provider()."""
        import importlib.util

        # Load the plugin under a throwaway name; submodules must be executed
        # and registered before the package __init__ runs, since it imports
        # them at load time.
        spec = importlib.util.spec_from_file_location(
            "hermes_memory_store_reg",
            _plugin_dir / "__init__.py",
            submodule_search_locations=[str(_plugin_dir)],
        )
        mod = importlib.util.module_from_spec(spec)
        sys.modules["hermes_memory_store_reg"] = mod

        try:
            store_spec = importlib.util.spec_from_file_location(
                "hermes_memory_store_reg.store", _plugin_dir / "store.py")
            store_mod = importlib.util.module_from_spec(store_spec)
            sys.modules["hermes_memory_store_reg.store"] = store_mod
            store_spec.loader.exec_module(store_mod)

            retrieval_spec = importlib.util.spec_from_file_location(
                "hermes_memory_store_reg.retrieval", _plugin_dir / "retrieval.py")
            retrieval_mod = importlib.util.module_from_spec(retrieval_spec)
            sys.modules["hermes_memory_store_reg.retrieval"] = retrieval_mod
            retrieval_spec.loader.exec_module(retrieval_mod)

            spec.loader.exec_module(mod)

            ctx = MagicMock()
            mod.register(ctx)
            ctx.register_memory_provider.assert_called_once()
            registered = ctx.register_memory_provider.call_args[0][0]
            assert registered.name == "holographic"
        finally:
            # Fix: always unregister the throwaway modules — previously a
            # failed assertion skipped this cleanup and poisoned later tests.
            for key in list(sys.modules):
                if key.startswith("hermes_memory_store_reg"):
                    del sys.modules[key]
|
||||
|
||||
|
||||
class TestToolHandling:
    """End-to-end exercise of the fact_store / fact_feedback tool handlers."""

    def test_add_and_search(self, provider):
        """Add a fact via tool call, then search for it."""
        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "User prefers vim over emacs"}
        ))
        assert "fact_id" in result
        # (fix: dropped an unused `fact_id = result["fact_id"]` local here)

        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "search", "query": "vim emacs"}
        ))
        assert result["count"] >= 1
        contents = [r["content"] for r in result["results"]]
        assert any("vim" in c for c in contents)

    def test_add_and_probe(self, provider):
        """Add facts about an entity, then probe it."""
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Peppi uses Rust for systems work"}
        )
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Peppi prefers Neovim"}
        )

        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "probe", "entity": "peppi"}
        ))
        assert result["count"] >= 1

    def test_related(self, provider):
        """Test related entity lookup."""
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Peppi uses Rust for systems work"}
        )
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Rust ensures memory safety"}
        )

        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "related", "entity": "rust"}
        ))
        assert "results" in result
        assert "count" in result

    def test_reason(self, provider):
        """Test compositional reasoning across entities."""
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Peppi uses Rust for backend work"}
        )
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "The backend handles API requests"}
        )

        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "reason", "entities": ["peppi", "backend"]}
        ))
        assert "results" in result

    def test_feedback(self, provider):
        """Test trust scoring via feedback."""
        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Test feedback fact"}
        ))
        fact_id = result["fact_id"]

        result = json.loads(provider.handle_tool_call(
            "fact_feedback", {"action": "helpful", "fact_id": fact_id}
        ))
        assert "error" not in result

    def test_update_and_remove(self, provider):
        """Test CRUD operations."""
        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "Will be updated"}
        ))
        fact_id = result["fact_id"]

        # Update
        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "update", "fact_id": fact_id, "content": "Updated content"}
        ))
        assert result["updated"]

        # Remove
        result = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "remove", "fact_id": fact_id}
        ))
        assert result["removed"]

    def test_all_handlers_return_json(self, provider):
        """Every tool call must return valid JSON."""
        # Add a fact first
        r = provider.handle_tool_call("fact_store", {"action": "add", "content": "JSON test"})
        parsed = json.loads(r)
        fact_id = parsed["fact_id"]

        # Test every action
        actions = [
            ("fact_store", {"action": "search", "query": "JSON"}),
            ("fact_store", {"action": "list"}),
            ("fact_store", {"action": "probe", "entity": "test"}),
            ("fact_store", {"action": "related", "entity": "test"}),
            ("fact_store", {"action": "reason", "entities": ["test"]}),
            ("fact_store", {"action": "contradict"}),
            ("fact_feedback", {"action": "helpful", "fact_id": fact_id}),
        ]
        for tool_name, args in actions:
            result = provider.handle_tool_call(tool_name, args)
            json.loads(result)  # Should not raise
|
||||
|
||||
|
||||
class TestPrefetch:
    def test_prefetch_returns_matching_facts(self, provider):
        """Prefetch should return facts matching the query."""
        provider.handle_tool_call(
            "fact_store", {"action": "add", "content": "The deploy pipeline uses Docker"}
        )
        prefetched = provider.prefetch("deploy pipeline")
        assert "Docker" in prefetched or "deploy" in prefetched

    def test_prefetch_empty_when_no_facts(self, provider):
        """An empty store prefetches nothing."""
        assert provider.prefetch("anything") == ""
|
||||
|
||||
|
||||
class TestSystemPromptBlock:
    def test_empty_when_no_facts(self, provider):
        """No stored facts -> no prompt block."""
        assert provider.system_prompt_block() == ""

    def test_shows_count_with_facts(self, provider):
        """The block should name the store and report the fact count."""
        for content in ("Fact one", "Fact two"):
            provider.handle_tool_call(
                "fact_store", {"action": "add", "content": content}
            )
        block = provider.system_prompt_block()
        assert "2 facts" in block
        assert "Holographic" in block
|
||||
|
||||
|
||||
class TestSessionEndHook:
    def test_extracts_preferences(self, provider):
        """on_session_end should extract preference patterns."""
        provider._config["auto_extract"] = True
        provider.on_session_end([
            {"role": "user", "content": "I prefer dark mode for all my editors"},
            {"role": "assistant", "content": "Noted, I'll remember that."},
        ])
        listing = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "list"}
        ))
        assert listing["count"] >= 1

    def test_skips_when_disabled(self, provider):
        """on_session_end should do nothing when auto_extract is False."""
        provider._config["auto_extract"] = False
        provider.on_session_end([
            {"role": "user", "content": "I prefer dark mode"},
        ])
        listing = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "list"}
        ))
        assert listing["count"] == 0

    def test_skips_assistant_messages(self, provider):
        """Only user messages should be scanned."""
        provider._config["auto_extract"] = True
        provider.on_session_end([
            {"role": "assistant", "content": "I prefer to help you with that"},
        ])
        listing = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "list"}
        ))
        assert listing["count"] == 0
|
||||
|
||||
|
||||
class TestMemoryBridge:
    def test_mirrors_builtin_writes(self, provider):
        """on_memory_write should store facts from the builtin memory tool."""
        provider.on_memory_write("add", "user", "Timezone: US Pacific")
        found = json.loads(provider.handle_tool_call(
            "fact_store", {"action": "search", "query": "timezone pacific"}
        ))
        assert found["count"] >= 1
|
||||
|
||||
|
||||
class TestManagerIntegration:
    def test_coexists_with_builtin(self, provider):
        """Holographic provider works alongside builtin in MemoryManager."""
        manager = MemoryManager()
        manager.add_provider(BuiltinMemoryProvider())
        manager.add_provider(provider)

        assert manager.provider_names == ["builtin", "holographic"]

        # Tools from holographic are available
        exposed = {schema["name"] for schema in manager.get_all_tool_schemas()}
        assert "fact_store" in exposed
        assert "fact_feedback" in exposed

        # Tool routing works
        added = json.loads(manager.handle_tool_call(
            "fact_store", {"action": "add", "content": "Manager integration test"}
        ))
        assert added["status"] == "added"

        # Memory bridge fires
        manager.on_memory_write("add", "memory", "Test fact from builtin")
        found = json.loads(manager.handle_tool_call(
            "fact_store", {"action": "search", "query": "test fact builtin"}
        ))
        assert found["count"] >= 1
|
||||
Reference in New Issue
Block a user