mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 07:21:37 +08:00
Compare commits
3 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec6daae0e1 | ||
|
|
b3e406635f | ||
|
|
31764e175c |
@@ -1,29 +1,45 @@
|
||||
"""RetainDB memory plugin — MemoryProvider interface.
|
||||
|
||||
Cross-session memory via RetainDB cloud API. Durable write-behind queue,
|
||||
semantic search with deduplication, and user profile retrieval.
|
||||
Cross-session memory via RetainDB cloud API.
|
||||
|
||||
Original PR #2732 by Alinxus, adapted to MemoryProvider ABC.
|
||||
Features:
|
||||
- Correct API routes for all operations
|
||||
- Durable SQLite write-behind queue (crash-safe, async ingest)
|
||||
- Semantic search + user profile retrieval
|
||||
- Context query with deduplication overlay
|
||||
- Dialectic synthesis (LLM-powered user understanding, prefetched each turn)
|
||||
- Agent self-model (persona + instructions from SOUL.md, prefetched each turn)
|
||||
- Shared file store tools (upload, list, read, ingest, delete)
|
||||
- Explicit memory tools (profile, search, context, remember, forget)
|
||||
|
||||
Config via environment variables:
|
||||
RETAINDB_API_KEY — API key (required)
|
||||
RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
|
||||
RETAINDB_PROJECT — Project identifier (default: hermes)
|
||||
Config (env vars or hermes config.yaml under retaindb:):
|
||||
RETAINDB_API_KEY — API key (required)
|
||||
RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
|
||||
RETAINDB_PROJECT — Project identifier (optional — defaults to "default")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_BASE_URL = "https://api.retaindb.com"
|
||||
_ASYNC_SHUTDOWN = object()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -32,16 +48,13 @@ _DEFAULT_BASE_URL = "https://api.retaindb.com"
|
||||
|
||||
PROFILE_SCHEMA = {
|
||||
"name": "retaindb_profile",
|
||||
"description": "Get the user's stable profile — preferences, facts, and patterns.",
|
||||
"description": "Get the user's stable profile — preferences, facts, and patterns recalled from long-term memory.",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
SEARCH_SCHEMA = {
|
||||
"name": "retaindb_search",
|
||||
"description": (
|
||||
"Semantic search across stored memories. Returns ranked results "
|
||||
"with relevance scores."
|
||||
),
|
||||
"description": "Semantic search across stored memories. Returns ranked results with relevance scores.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -54,7 +67,7 @@ SEARCH_SCHEMA = {
|
||||
|
||||
CONTEXT_SCHEMA = {
|
||||
"name": "retaindb_context",
|
||||
"description": "Synthesized 'what matters now' context block for the current task.",
|
||||
"description": "Synthesized context block — what matters most for the current task, pulled from long-term memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -66,20 +79,17 @@ CONTEXT_SCHEMA = {
|
||||
|
||||
REMEMBER_SCHEMA = {
|
||||
"name": "retaindb_remember",
|
||||
"description": "Persist an explicit fact or preference to long-term memory.",
|
||||
"description": "Persist an explicit fact, preference, or decision to long-term memory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "The fact to remember."},
|
||||
"memory_type": {
|
||||
"type": "string",
|
||||
"enum": ["preference", "fact", "decision", "context"],
|
||||
"description": "Category (default: fact).",
|
||||
},
|
||||
"importance": {
|
||||
"type": "number",
|
||||
"description": "Importance 0-1 (default: 0.5).",
|
||||
"enum": ["factual", "preference", "goal", "instruction", "event", "opinion"],
|
||||
"description": "Category (default: factual).",
|
||||
},
|
||||
"importance": {"type": "number", "description": "Importance 0-1 (default: 0.7)."},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
@@ -97,23 +107,368 @@ FORGET_SCHEMA = {
|
||||
},
|
||||
}
|
||||
|
||||
FILE_UPLOAD_SCHEMA = {
|
||||
"name": "retaindb_upload_file",
|
||||
"description": "Upload a file to the shared RetainDB file store. Returns an rdb:// URI any agent can reference.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"local_path": {"type": "string", "description": "Local file path to upload."},
|
||||
"remote_path": {"type": "string", "description": "Destination path, e.g. /reports/q1.pdf"},
|
||||
"scope": {"type": "string", "enum": ["USER", "PROJECT", "ORG"], "description": "Access scope (default: PROJECT)."},
|
||||
"ingest": {"type": "boolean", "description": "Also extract memories from file after upload (default: false)."},
|
||||
},
|
||||
"required": ["local_path"],
|
||||
},
|
||||
}
|
||||
|
||||
FILE_LIST_SCHEMA = {
|
||||
"name": "retaindb_list_files",
|
||||
"description": "List files in the shared file store.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prefix": {"type": "string", "description": "Path prefix to filter by, e.g. /reports/"},
|
||||
"limit": {"type": "integer", "description": "Max results (default: 50)."},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
FILE_READ_SCHEMA = {
|
||||
"name": "retaindb_read_file",
|
||||
"description": "Read the text content of a stored file by its file ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {"type": "string", "description": "File ID returned from upload or list."},
|
||||
},
|
||||
"required": ["file_id"],
|
||||
},
|
||||
}
|
||||
|
||||
FILE_INGEST_SCHEMA = {
|
||||
"name": "retaindb_ingest_file",
|
||||
"description": "Chunk, embed, and extract memories from a stored file. Makes its contents searchable.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {"type": "string", "description": "File ID to ingest."},
|
||||
},
|
||||
"required": ["file_id"],
|
||||
},
|
||||
}
|
||||
|
||||
FILE_DELETE_SCHEMA = {
|
||||
"name": "retaindb_delete_file",
|
||||
"description": "Delete a stored file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {"type": "string", "description": "File ID to delete."},
|
||||
},
|
||||
"required": ["file_id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
# HTTP client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Client:
|
||||
def __init__(self, api_key: str, base_url: str, project: str):
|
||||
self.api_key = api_key
|
||||
self.base_url = re.sub(r"/+$", "", base_url)
|
||||
self.project = project
|
||||
|
||||
def _headers(self, path: str) -> dict:
|
||||
token = self.api_key.replace("Bearer ", "").strip()
|
||||
h = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
"x-sdk-runtime": "hermes-plugin",
|
||||
}
|
||||
if path.startswith("/v1/memory") or path.startswith("/v1/context"):
|
||||
h["X-API-Key"] = token
|
||||
return h
|
||||
|
||||
def request(self, method: str, path: str, *, params=None, json_body=None, timeout: float = 8.0) -> Any:
|
||||
import requests
|
||||
url = f"{self.base_url}{path}"
|
||||
resp = requests.request(
|
||||
method.upper(), url,
|
||||
params=params,
|
||||
json=json_body if method.upper() not in {"GET", "DELETE"} else None,
|
||||
headers=self._headers(path),
|
||||
timeout=timeout,
|
||||
)
|
||||
try:
|
||||
payload = resp.json()
|
||||
except Exception:
|
||||
payload = resp.text
|
||||
if not resp.ok:
|
||||
msg = ""
|
||||
if isinstance(payload, dict):
|
||||
msg = str(payload.get("message") or payload.get("error") or "")
|
||||
raise RuntimeError(f"RetainDB {method} {path} failed ({resp.status_code}): {msg or payload}")
|
||||
return payload
|
||||
|
||||
# ── Memory ────────────────────────────────────────────────────────────────
|
||||
|
||||
def query_context(self, user_id: str, session_id: str, query: str, max_tokens: int = 1200) -> dict:
|
||||
return self.request("POST", "/v1/context/query", json_body={
|
||||
"project": self.project,
|
||||
"query": query,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"include_memories": True,
|
||||
"max_tokens": max_tokens,
|
||||
})
|
||||
|
||||
def search(self, user_id: str, session_id: str, query: str, top_k: int = 8) -> dict:
|
||||
return self.request("POST", "/v1/memory/search", json_body={
|
||||
"project": self.project,
|
||||
"query": query,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"top_k": top_k,
|
||||
"include_pending": True,
|
||||
})
|
||||
|
||||
def get_profile(self, user_id: str) -> dict:
|
||||
try:
|
||||
return self.request("GET", f"/v1/memory/profile/{quote(user_id, safe='')}", params={"project": self.project, "include_pending": "true"})
|
||||
except Exception:
|
||||
return self.request("GET", "/v1/memories", params={"project": self.project, "user_id": user_id, "limit": "200"})
|
||||
|
||||
def add_memory(self, user_id: str, session_id: str, content: str, memory_type: str = "factual", importance: float = 0.7) -> dict:
|
||||
try:
|
||||
return self.request("POST", "/v1/memory", json_body={
|
||||
"project": self.project, "content": content, "memory_type": memory_type,
|
||||
"user_id": user_id, "session_id": session_id, "importance": importance, "write_mode": "sync",
|
||||
}, timeout=5.0)
|
||||
except Exception:
|
||||
return self.request("POST", "/v1/memories", json_body={
|
||||
"project": self.project, "content": content, "memory_type": memory_type,
|
||||
"user_id": user_id, "session_id": session_id, "importance": importance,
|
||||
}, timeout=5.0)
|
||||
|
||||
def delete_memory(self, memory_id: str) -> dict:
|
||||
try:
|
||||
return self.request("DELETE", f"/v1/memory/{quote(memory_id, safe='')}", timeout=5.0)
|
||||
except Exception:
|
||||
return self.request("DELETE", f"/v1/memories/{quote(memory_id, safe='')}", timeout=5.0)
|
||||
|
||||
def ingest_session(self, user_id: str, session_id: str, messages: list, timeout: float = 15.0) -> dict:
|
||||
return self.request("POST", "/v1/memory/ingest/session", json_body={
|
||||
"project": self.project, "session_id": session_id, "user_id": user_id,
|
||||
"messages": messages, "write_mode": "sync",
|
||||
}, timeout=timeout)
|
||||
|
||||
def ask_user(self, user_id: str, query: str, reasoning_level: str = "low") -> dict:
|
||||
return self.request("POST", f"/v1/memory/profile/{quote(user_id, safe='')}/ask", json_body={
|
||||
"project": self.project, "query": query, "reasoning_level": reasoning_level,
|
||||
}, timeout=8.0)
|
||||
|
||||
def get_agent_model(self, agent_id: str) -> dict:
|
||||
return self.request("GET", f"/v1/memory/agent/{quote(agent_id, safe='')}/model", params={"project": self.project}, timeout=4.0)
|
||||
|
||||
def seed_agent_identity(self, agent_id: str, content: str, source: str = "soul_md") -> dict:
|
||||
return self.request("POST", f"/v1/memory/agent/{quote(agent_id, safe='')}/seed", json_body={
|
||||
"project": self.project, "content": content, "source": source,
|
||||
}, timeout=20.0)
|
||||
|
||||
# ── Files ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def upload_file(self, data: bytes, filename: str, remote_path: str, mime_type: str, scope: str, project_id: str | None) -> dict:
|
||||
import io
|
||||
import requests
|
||||
url = f"{self.base_url}/v1/files"
|
||||
token = self.api_key.replace("Bearer ", "").strip()
|
||||
headers = {"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}
|
||||
fields = {"path": remote_path, "scope": scope.upper()}
|
||||
if project_id:
|
||||
fields["project_id"] = project_id
|
||||
resp = requests.post(url, files={"file": (filename, io.BytesIO(data), mime_type)}, data=fields, headers=headers, timeout=30)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
def list_files(self, prefix: str | None = None, limit: int = 50) -> dict:
|
||||
params: dict = {"limit": limit}
|
||||
if prefix:
|
||||
params["prefix"] = prefix
|
||||
return self.request("GET", "/v1/files", params=params)
|
||||
|
||||
def get_file(self, file_id: str) -> dict:
|
||||
return self.request("GET", f"/v1/files/{quote(file_id, safe='')}")
|
||||
|
||||
def read_file_content(self, file_id: str) -> bytes:
|
||||
import requests
|
||||
token = self.api_key.replace("Bearer ", "").strip()
|
||||
url = f"{self.base_url}/v1/files/{quote(file_id, safe='')}/content"
|
||||
resp = requests.get(url, headers={"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}, timeout=30, allow_redirects=True)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
|
||||
def ingest_file(self, file_id: str, user_id: str | None = None, agent_id: str | None = None) -> dict:
|
||||
body: dict = {}
|
||||
if user_id:
|
||||
body["user_id"] = user_id
|
||||
if agent_id:
|
||||
body["agent_id"] = agent_id
|
||||
return self.request("POST", f"/v1/files/{quote(file_id, safe='')}/ingest", json_body=body, timeout=60.0)
|
||||
|
||||
def delete_file(self, file_id: str) -> dict:
|
||||
return self.request("DELETE", f"/v1/files/{quote(file_id, safe='')}", timeout=5.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Durable write-behind queue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _WriteQueue:
|
||||
"""SQLite-backed async write queue. Survives crashes — pending rows replay on startup."""
|
||||
|
||||
def __init__(self, client: _Client, db_path: Path):
|
||||
self._client = client
|
||||
self._db_path = db_path
|
||||
self._q: queue.Queue = queue.Queue()
|
||||
self._thread = threading.Thread(target=self._loop, name="retaindb-writer", daemon=True)
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Thread-local connection cache — one connection per thread, reused.
|
||||
self._local = threading.local()
|
||||
self._init_db()
|
||||
self._thread.start()
|
||||
# Replay any rows left from a previous crash
|
||||
for row_id, user_id, session_id, msgs_json in self._pending_rows():
|
||||
self._q.put((row_id, user_id, session_id, json.loads(msgs_json)))
|
||||
|
||||
def _get_conn(self) -> sqlite3.Connection:
|
||||
"""Return a cached connection for the current thread."""
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn is None:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
self._local.conn = conn
|
||||
return conn
|
||||
|
||||
def _init_db(self) -> None:
|
||||
conn = self._get_conn()
|
||||
conn.execute("""CREATE TABLE IF NOT EXISTS pending (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT, session_id TEXT, messages_json TEXT,
|
||||
created_at TEXT, last_error TEXT
|
||||
)""")
|
||||
conn.commit()
|
||||
|
||||
def _pending_rows(self) -> list:
|
||||
conn = self._get_conn()
|
||||
return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall()
|
||||
|
||||
def enqueue(self, user_id: str, session_id: str, messages: list) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn = self._get_conn()
|
||||
cur = conn.execute(
|
||||
"INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)",
|
||||
(user_id, session_id, json.dumps(messages, ensure_ascii=False), now),
|
||||
)
|
||||
row_id = cur.lastrowid
|
||||
conn.commit()
|
||||
self._q.put((row_id, user_id, session_id, messages))
|
||||
|
||||
def _flush_row(self, row_id: int, user_id: str, session_id: str, messages: list) -> None:
|
||||
try:
|
||||
self._client.ingest_session(user_id, session_id, messages)
|
||||
conn = self._get_conn()
|
||||
conn.execute("DELETE FROM pending WHERE id = ?", (row_id,))
|
||||
conn.commit()
|
||||
except Exception as exc:
|
||||
logger.warning("RetainDB ingest failed (will retry): %s", exc)
|
||||
conn = self._get_conn()
|
||||
conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id))
|
||||
conn.commit()
|
||||
time.sleep(2)
|
||||
|
||||
def _loop(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
item = self._q.get(timeout=5)
|
||||
if item is _ASYNC_SHUTDOWN:
|
||||
break
|
||||
self._flush_row(*item)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
logger.error("RetainDB writer error: %s", exc)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._q.put(_ASYNC_SHUTDOWN)
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_overlay(profile: dict, query_result: dict, local_entries: list[str] | None = None) -> str:
|
||||
def _compact(s: str) -> str:
|
||||
return re.sub(r"\s+", " ", str(s or "")).strip()[:320]
|
||||
|
||||
def _norm(s: str) -> str:
|
||||
return re.sub(r"[^a-z0-9 ]", "", _compact(s).lower())
|
||||
|
||||
seen: list[str] = [_norm(e) for e in (local_entries or []) if _norm(e)]
|
||||
profile_items: list[str] = []
|
||||
for m in list((profile or {}).get("memories") or [])[:5]:
|
||||
c = _compact((m or {}).get("content") or "")
|
||||
n = _norm(c)
|
||||
if c and n not in seen:
|
||||
seen.append(n)
|
||||
profile_items.append(c)
|
||||
|
||||
query_items: list[str] = []
|
||||
for r in list((query_result or {}).get("results") or [])[:5]:
|
||||
c = _compact((r or {}).get("content") or "")
|
||||
n = _norm(c)
|
||||
if c and n not in seen:
|
||||
seen.append(n)
|
||||
query_items.append(c)
|
||||
|
||||
if not profile_items and not query_items:
|
||||
return ""
|
||||
|
||||
lines = ["[RetainDB Context]", "Profile:"]
|
||||
lines += [f"- {i}" for i in profile_items] or ["- None"]
|
||||
lines.append("Relevant memories:")
|
||||
lines += [f"- {i}" for i in query_items] or ["- None"]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main plugin class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RetainDBMemoryProvider(MemoryProvider):
|
||||
"""RetainDB cloud memory with write-behind queue and semantic search."""
|
||||
"""RetainDB cloud memory — durable queue, semantic search, dialectic synthesis, shared files."""
|
||||
|
||||
def __init__(self):
|
||||
self._api_key = ""
|
||||
self._base_url = _DEFAULT_BASE_URL
|
||||
self._project = "hermes"
|
||||
self._user_id = ""
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread = None
|
||||
self._sync_thread = None
|
||||
self._client: _Client | None = None
|
||||
self._queue: _WriteQueue | None = None
|
||||
self._user_id = "default"
|
||||
self._session_id = ""
|
||||
self._agent_id = "hermes"
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Prefetch caches
|
||||
self._context_result = ""
|
||||
self._dialectic_result = ""
|
||||
self._agent_model: dict = {}
|
||||
|
||||
# Prefetch thread tracking — prevents accumulation on rapid calls
|
||||
self._prefetch_threads: list[threading.Thread] = []
|
||||
|
||||
# ── Core identity ──────────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -122,179 +477,287 @@ class RetainDBMemoryProvider(MemoryProvider):
|
||||
def is_available(self) -> bool:
|
||||
return bool(os.environ.get("RETAINDB_API_KEY"))
|
||||
|
||||
def get_config_schema(self):
|
||||
def get_config_schema(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{"key": "api_key", "description": "RetainDB API key", "secret": True, "required": True, "env_var": "RETAINDB_API_KEY", "url": "https://retaindb.com"},
|
||||
{"key": "base_url", "description": "API endpoint", "default": "https://api.retaindb.com"},
|
||||
{"key": "project", "description": "Project identifier", "default": "hermes"},
|
||||
{"key": "base_url", "description": "API endpoint", "default": _DEFAULT_BASE_URL},
|
||||
{"key": "project", "description": "Project identifier (optional — uses 'default' project if not set)", "default": ""},
|
||||
]
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _api(self, method: str, path: str, **kwargs):
|
||||
"""Make an API call to RetainDB."""
|
||||
import requests
|
||||
url = f"{self._base_url}{path}"
|
||||
resp = requests.request(method, url, headers=self._headers(), timeout=30, **kwargs)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
# ── Lifecycle ──────────────────────────────────────────────────────────
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._api_key = os.environ.get("RETAINDB_API_KEY", "")
|
||||
self._base_url = os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL)
|
||||
self._user_id = kwargs.get("user_id", "default")
|
||||
self._session_id = session_id
|
||||
api_key = os.environ.get("RETAINDB_API_KEY", "")
|
||||
base_url = re.sub(r"/+$", "", os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL))
|
||||
|
||||
# Derive profile-scoped project name so different profiles don't
|
||||
# share server-side memory. Explicit RETAINDB_PROJECT always wins.
|
||||
explicit_project = os.environ.get("RETAINDB_PROJECT")
|
||||
if explicit_project:
|
||||
self._project = explicit_project
|
||||
# Project resolution: RETAINDB_PROJECT > hermes-<profile> > "default"
|
||||
# If unset, the API auto-creates and uses the "default" project — no config required.
|
||||
explicit = os.environ.get("RETAINDB_PROJECT")
|
||||
if explicit:
|
||||
project = explicit
|
||||
else:
|
||||
hermes_home = kwargs.get("hermes_home", "")
|
||||
hermes_home = str(kwargs.get("hermes_home", ""))
|
||||
profile_name = os.path.basename(hermes_home) if hermes_home else ""
|
||||
# Default profile (~/.hermes) → "hermes"; named profiles → "hermes-<name>"
|
||||
if profile_name and profile_name != ".hermes":
|
||||
self._project = f"hermes-{profile_name}"
|
||||
else:
|
||||
self._project = "hermes"
|
||||
project = f"hermes-{profile_name}" if (profile_name and profile_name not in {"", ".hermes"}) else "default"
|
||||
|
||||
self._client = _Client(api_key, base_url, project)
|
||||
self._session_id = session_id
|
||||
self._user_id = kwargs.get("user_id", "default") or "default"
|
||||
self._agent_id = kwargs.get("agent_id", "hermes") or "hermes"
|
||||
|
||||
hermes_home_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
db_path = hermes_home_path / "retaindb_queue.db"
|
||||
self._queue = _WriteQueue(self._client, db_path)
|
||||
|
||||
# Seed agent identity from SOUL.md in background
|
||||
soul_path = hermes_home_path / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
soul_content = soul_path.read_text(encoding="utf-8", errors="replace").strip()
|
||||
if soul_content:
|
||||
threading.Thread(
|
||||
target=self._seed_soul,
|
||||
args=(soul_content,),
|
||||
name="retaindb-soul-seed",
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
def _seed_soul(self, content: str) -> None:
|
||||
try:
|
||||
self._client.seed_agent_identity(self._agent_id, content, source="soul_md")
|
||||
except Exception as exc:
|
||||
logger.debug("RetainDB soul seed failed: %s", exc)
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
project = self._client.project if self._client else "retaindb"
|
||||
return (
|
||||
"# RetainDB Memory\n"
|
||||
f"Active. Project: {self._project}.\n"
|
||||
f"Active. Project: {project}.\n"
|
||||
"Use retaindb_search to find memories, retaindb_remember to store facts, "
|
||||
"retaindb_profile for a user overview, retaindb_context for task-relevant context."
|
||||
"retaindb_profile for a user overview, retaindb_context for current-task context."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
result = self._prefetch_result
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
return ""
|
||||
return f"## RetainDB Memory\n{result}"
|
||||
# ── Background prefetch (fires at turn-end, consumed next turn-start) ──
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
def _run():
|
||||
try:
|
||||
data = self._api("POST", "/v1/recall", json={
|
||||
"project": self._project,
|
||||
"query": query,
|
||||
"user_id": self._user_id,
|
||||
"top_k": 5,
|
||||
})
|
||||
results = data.get("results", [])
|
||||
if results:
|
||||
lines = [r.get("content", "") for r in results if r.get("content")]
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = "\n".join(f"- {l}" for l in lines)
|
||||
except Exception as e:
|
||||
logger.debug("RetainDB prefetch failed: %s", e)
|
||||
"""Fire context + dialectic + agent model prefetches in background."""
|
||||
if not self._client:
|
||||
return
|
||||
# Wait for any still-running prefetch threads before spawning new ones.
|
||||
# Prevents thread accumulation if turns fire faster than prefetches complete.
|
||||
for t in self._prefetch_threads:
|
||||
t.join(timeout=2.0)
|
||||
threads = [
|
||||
threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True),
|
||||
threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True),
|
||||
threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True),
|
||||
]
|
||||
self._prefetch_threads = threads
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="retaindb-prefetch")
|
||||
self._prefetch_thread.start()
|
||||
def _prefetch_context(self, query: str) -> None:
|
||||
try:
|
||||
query_result = self._client.query_context(self._user_id, self._session_id, query)
|
||||
profile = self._client.get_profile(self._user_id)
|
||||
overlay = _build_overlay(profile, query_result)
|
||||
with self._lock:
|
||||
self._context_result = overlay
|
||||
except Exception as exc:
|
||||
logger.debug("RetainDB context prefetch failed: %s", exc)
|
||||
|
||||
def _prefetch_dialectic(self, query: str) -> None:
|
||||
try:
|
||||
result = self._client.ask_user(self._user_id, query, reasoning_level=self._reasoning_level(query))
|
||||
answer = str(result.get("answer") or "")
|
||||
if answer:
|
||||
with self._lock:
|
||||
self._dialectic_result = answer
|
||||
except Exception as exc:
|
||||
logger.debug("RetainDB dialectic prefetch failed: %s", exc)
|
||||
|
||||
def _prefetch_agent_model(self) -> None:
|
||||
try:
|
||||
model = self._client.get_agent_model(self._agent_id)
|
||||
if model.get("memory_count", 0) > 0:
|
||||
with self._lock:
|
||||
self._agent_model = model
|
||||
except Exception as exc:
|
||||
logger.debug("RetainDB agent model prefetch failed: %s", exc)
|
||||
|
||||
@staticmethod
|
||||
def _reasoning_level(query: str) -> str:
|
||||
n = len(query)
|
||||
if n < 120:
|
||||
return "low"
|
||||
if n < 400:
|
||||
return "medium"
|
||||
return "high"
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Consume prefetched results and return them as a context block."""
|
||||
with self._lock:
|
||||
context = self._context_result
|
||||
dialectic = self._dialectic_result
|
||||
agent_model = self._agent_model
|
||||
self._context_result = ""
|
||||
self._dialectic_result = ""
|
||||
self._agent_model = {}
|
||||
|
||||
parts: list[str] = []
|
||||
if context:
|
||||
parts.append(context)
|
||||
if dialectic:
|
||||
parts.append(f"[RetainDB User Synthesis]\n{dialectic}")
|
||||
if agent_model and agent_model.get("memory_count", 0) > 0:
|
||||
model_lines: list[str] = []
|
||||
if agent_model.get("persona"):
|
||||
model_lines.append(f"Persona: {agent_model['persona']}")
|
||||
if agent_model.get("persistent_instructions"):
|
||||
model_lines.append("Instructions:\n" + "\n".join(f"- {i}" for i in agent_model["persistent_instructions"]))
|
||||
if agent_model.get("working_style"):
|
||||
model_lines.append(f"Working style: {agent_model['working_style']}")
|
||||
if model_lines:
|
||||
parts.append("[RetainDB Agent Self-Model]\n" + "\n".join(model_lines))
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
# ── Turn sync ──────────────────────────────────────────────────────────
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Ingest conversation turn in background (non-blocking)."""
|
||||
def _sync():
|
||||
try:
|
||||
self._api("POST", "/v1/ingest", json={
|
||||
"project": self._project,
|
||||
"user_id": self._user_id,
|
||||
"session_id": self._session_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": assistant_content},
|
||||
],
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning("RetainDB sync failed: %s", e)
|
||||
"""Queue turn for async ingest. Returns immediately."""
|
||||
if not self._queue or not user_content:
|
||||
return
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
self._queue.enqueue(
|
||||
self._user_id,
|
||||
session_id or self._session_id,
|
||||
[
|
||||
{"role": "user", "content": user_content, "timestamp": now},
|
||||
{"role": "assistant", "content": assistant_content, "timestamp": now},
|
||||
],
|
||||
)
|
||||
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="retaindb-sync")
|
||||
self._sync_thread.start()
|
||||
# ── Tools ──────────────────────────────────────────────────────────────
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, REMEMBER_SCHEMA, FORGET_SCHEMA]
|
||||
return [
|
||||
PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA,
|
||||
REMEMBER_SCHEMA, FORGET_SCHEMA,
|
||||
FILE_UPLOAD_SCHEMA, FILE_LIST_SCHEMA, FILE_READ_SCHEMA,
|
||||
FILE_INGEST_SCHEMA, FILE_DELETE_SCHEMA,
|
||||
]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
if not self._client:
|
||||
return json.dumps({"error": "RetainDB not initialized"})
|
||||
try:
|
||||
if tool_name == "retaindb_profile":
|
||||
data = self._api("GET", f"/v1/profile/{self._project}/{self._user_id}")
|
||||
return json.dumps(data)
|
||||
return json.dumps(self._dispatch(tool_name, args))
|
||||
except Exception as exc:
|
||||
return json.dumps({"error": str(exc)})
|
||||
|
||||
elif tool_name == "retaindb_search":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required"})
|
||||
data = self._api("POST", "/v1/search", json={
|
||||
"project": self._project,
|
||||
"user_id": self._user_id,
|
||||
"query": query,
|
||||
"top_k": min(int(args.get("top_k", 8)), 20),
|
||||
})
|
||||
return json.dumps(data)
|
||||
def _dispatch(self, tool_name: str, args: dict) -> Any:
|
||||
c = self._client
|
||||
|
||||
elif tool_name == "retaindb_context":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "query is required"})
|
||||
data = self._api("POST", "/v1/recall", json={
|
||||
"project": self._project,
|
||||
"user_id": self._user_id,
|
||||
"query": query,
|
||||
"top_k": 5,
|
||||
})
|
||||
return json.dumps(data)
|
||||
if tool_name == "retaindb_profile":
|
||||
return c.get_profile(self._user_id)
|
||||
|
||||
elif tool_name == "retaindb_remember":
|
||||
content = args.get("content", "")
|
||||
if not content:
|
||||
return json.dumps({"error": "content is required"})
|
||||
data = self._api("POST", "/v1/remember", json={
|
||||
"project": self._project,
|
||||
"user_id": self._user_id,
|
||||
"content": content,
|
||||
"memory_type": args.get("memory_type", "fact"),
|
||||
"importance": float(args.get("importance", 0.5)),
|
||||
})
|
||||
return json.dumps(data)
|
||||
if tool_name == "retaindb_search":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return {"error": "query is required"}
|
||||
return c.search(self._user_id, self._session_id, query, top_k=min(int(args.get("top_k", 8)), 20))
|
||||
|
||||
elif tool_name == "retaindb_forget":
|
||||
memory_id = args.get("memory_id", "")
|
||||
if not memory_id:
|
||||
return json.dumps({"error": "memory_id is required"})
|
||||
data = self._api("DELETE", f"/v1/memory/{memory_id}")
|
||||
return json.dumps(data)
|
||||
if tool_name == "retaindb_context":
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return {"error": "query is required"}
|
||||
query_result = c.query_context(self._user_id, self._session_id, query)
|
||||
profile = c.get_profile(self._user_id)
|
||||
overlay = _build_overlay(profile, query_result)
|
||||
return {"context": overlay, "raw": query_result}
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)})
|
||||
if tool_name == "retaindb_remember":
|
||||
content = args.get("content", "")
|
||||
if not content:
|
||||
return {"error": "content is required"}
|
||||
return c.add_memory(
|
||||
self._user_id, self._session_id, content,
|
||||
memory_type=args.get("memory_type", "factual"),
|
||||
importance=float(args.get("importance", 0.7)),
|
||||
)
|
||||
|
||||
if tool_name == "retaindb_forget":
|
||||
memory_id = args.get("memory_id", "")
|
||||
if not memory_id:
|
||||
return {"error": "memory_id is required"}
|
||||
return c.delete_memory(memory_id)
|
||||
|
||||
# ── File tools ──────────────────────────────────────────────────────
|
||||
|
||||
if tool_name == "retaindb_upload_file":
|
||||
local_path = args.get("local_path", "")
|
||||
if not local_path:
|
||||
return {"error": "local_path is required"}
|
||||
path_obj = Path(local_path)
|
||||
if not path_obj.exists():
|
||||
return {"error": f"File not found: {local_path}"}
|
||||
data = path_obj.read_bytes()
|
||||
import mimetypes
|
||||
mime = mimetypes.guess_type(path_obj.name)[0] or "application/octet-stream"
|
||||
remote_path = args.get("remote_path") or f"/{path_obj.name}"
|
||||
result = c.upload_file(data, path_obj.name, remote_path, mime, args.get("scope", "PROJECT"), None)
|
||||
if args.get("ingest") and result.get("file", {}).get("id"):
|
||||
ingest = c.ingest_file(result["file"]["id"], user_id=self._user_id, agent_id=self._agent_id)
|
||||
result["ingest"] = ingest
|
||||
return result
|
||||
|
||||
if tool_name == "retaindb_list_files":
|
||||
return c.list_files(prefix=args.get("prefix"), limit=int(args.get("limit", 50)))
|
||||
|
||||
if tool_name == "retaindb_read_file":
|
||||
file_id = args.get("file_id", "")
|
||||
if not file_id:
|
||||
return {"error": "file_id is required"}
|
||||
meta = c.get_file(file_id)
|
||||
file_info = meta.get("file") or {}
|
||||
mime = (file_info.get("mime_type") or "").lower()
|
||||
raw = c.read_file_content(file_id)
|
||||
if not (mime.startswith("text/") or any(file_info.get("name", "").endswith(e) for e in (".txt", ".md", ".json", ".csv", ".yaml", ".yml", ".xml", ".html"))):
|
||||
return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": None, "note": "Binary file — use retaindb_ingest_file to extract text into memory."}
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": text[:32000], "truncated": len(text) > 32000}
|
||||
|
||||
if tool_name == "retaindb_ingest_file":
|
||||
file_id = args.get("file_id", "")
|
||||
if not file_id:
|
||||
return {"error": "file_id is required"}
|
||||
return c.ingest_file(file_id, user_id=self._user_id, agent_id=self._agent_id)
|
||||
|
||||
if tool_name == "retaindb_delete_file":
|
||||
file_id = args.get("file_id", "")
|
||||
if not file_id:
|
||||
return {"error": "file_id is required"}
|
||||
return c.delete_file(file_id)
|
||||
|
||||
return {"error": f"Unknown tool: {tool_name}"}
|
||||
|
||||
# ── Optional hooks ─────────────────────────────────────────────────────
|
||||
|
||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
||||
if action == "add":
|
||||
try:
|
||||
self._api("POST", "/v1/remember", json={
|
||||
"project": self._project,
|
||||
"user_id": self._user_id,
|
||||
"content": content,
|
||||
"memory_type": "preference" if target == "user" else "fact",
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug("RetainDB memory bridge failed: %s", e)
|
||||
"""Mirror built-in memory writes to RetainDB."""
|
||||
if action != "add" or not content or not self._client:
|
||||
return
|
||||
try:
|
||||
memory_type = "preference" if target == "user" else "factual"
|
||||
self._client.add_memory(self._user_id, self._session_id, content, memory_type=memory_type)
|
||||
except Exception as exc:
|
||||
logger.debug("RetainDB memory mirror failed: %s", exc)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
for t in (self._prefetch_thread, self._sync_thread):
|
||||
if t and t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
for t in self._prefetch_threads:
|
||||
t.join(timeout=3.0)
|
||||
if self._queue:
|
||||
self._queue.shutdown()
|
||||
|
||||
|
||||
def register(ctx) -> None:
|
||||
|
||||
776
tests/plugins/test_retaindb_plugin.py
Normal file
776
tests/plugins/test_retaindb_plugin.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""Tests for the RetainDB memory plugin.
|
||||
|
||||
Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter,
|
||||
RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Imports — guarded since plugins/memory lives outside the standard test path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_env(tmp_path, monkeypatch):
|
||||
"""Ensure HERMES_HOME and RETAINDB vars are isolated."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("RETAINDB_API_KEY", raising=False)
|
||||
monkeypatch.delenv("RETAINDB_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
|
||||
|
||||
|
||||
# We need the repo root on sys.path so the plugin can import agent.memory_provider
|
||||
import sys
|
||||
_repo_root = str(Path(__file__).resolve().parents[2])
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
from plugins.memory.retaindb import (
|
||||
_Client,
|
||||
_WriteQueue,
|
||||
_build_overlay,
|
||||
RetainDBMemoryProvider,
|
||||
_ASYNC_SHUTDOWN,
|
||||
_DEFAULT_BASE_URL,
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _Client tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestClient:
|
||||
"""Test the HTTP client with mocked requests."""
|
||||
|
||||
def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
|
||||
return _Client(api_key, base_url, project)
|
||||
|
||||
def test_base_url_trailing_slash_stripped(self):
|
||||
c = self._make_client(base_url="https://api.retaindb.com///")
|
||||
assert c.base_url == "https://api.retaindb.com"
|
||||
|
||||
def test_headers_include_auth(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/files")
|
||||
assert h["Authorization"] == "Bearer rdb-test-key"
|
||||
assert "X-API-Key" not in h
|
||||
|
||||
def test_headers_include_api_key_for_memory_path(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/memory/search")
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_headers_include_api_key_for_context_path(self):
|
||||
c = self._make_client()
|
||||
h = c._headers("/v1/context/query")
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_headers_strip_bearer_prefix(self):
|
||||
c = self._make_client(api_key="Bearer rdb-test-key")
|
||||
h = c._headers("/v1/memory/search")
|
||||
assert h["Authorization"] == "Bearer rdb-test-key"
|
||||
assert h["X-API-Key"] == "rdb-test-key"
|
||||
|
||||
def test_query_context_builds_correct_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"results": []}
|
||||
c.query_context("user1", "sess1", "test query", max_tokens=500)
|
||||
mock_req.assert_called_once_with("POST", "/v1/context/query", json_body={
|
||||
"project": "test",
|
||||
"query": "test query",
|
||||
"user_id": "user1",
|
||||
"session_id": "sess1",
|
||||
"include_memories": True,
|
||||
"max_tokens": 500,
|
||||
})
|
||||
|
||||
def test_search_builds_correct_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"results": []}
|
||||
c.search("user1", "sess1", "find this", top_k=5)
|
||||
mock_req.assert_called_once_with("POST", "/v1/memory/search", json_body={
|
||||
"project": "test",
|
||||
"query": "find this",
|
||||
"user_id": "user1",
|
||||
"session_id": "sess1",
|
||||
"top_k": 5,
|
||||
"include_pending": True,
|
||||
})
|
||||
|
||||
def test_add_memory_tries_fallback(self):
|
||||
c = self._make_client()
|
||||
call_count = 0
|
||||
def fake_request(method, path, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("404")
|
||||
return {"id": "mem-1"}
|
||||
|
||||
with patch.object(c, "request", side_effect=fake_request):
|
||||
result = c.add_memory("u1", "s1", "test fact")
|
||||
assert result == {"id": "mem-1"}
|
||||
assert call_count == 2
|
||||
|
||||
def test_delete_memory_tries_fallback(self):
|
||||
c = self._make_client()
|
||||
call_count = 0
|
||||
def fake_request(method, path, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("404")
|
||||
return {"deleted": True}
|
||||
|
||||
with patch.object(c, "request", side_effect=fake_request):
|
||||
result = c.delete_memory("mem-123")
|
||||
assert result == {"deleted": True}
|
||||
assert call_count == 2
|
||||
|
||||
def test_ingest_session_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"status": "ok"}
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
c.ingest_session("u1", "s1", msgs, timeout=10.0)
|
||||
mock_req.assert_called_once_with("POST", "/v1/memory/ingest/session", json_body={
|
||||
"project": "test",
|
||||
"session_id": "s1",
|
||||
"user_id": "u1",
|
||||
"messages": msgs,
|
||||
"write_mode": "sync",
|
||||
}, timeout=10.0)
|
||||
|
||||
def test_ask_user_payload(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"answer": "test answer"}
|
||||
c.ask_user("u1", "who am i?", reasoning_level="medium")
|
||||
mock_req.assert_called_once()
|
||||
call_kwargs = mock_req.call_args
|
||||
assert call_kwargs[1]["json_body"]["reasoning_level"] == "medium"
|
||||
|
||||
def test_get_agent_model_path(self):
|
||||
c = self._make_client()
|
||||
with patch.object(c, "request") as mock_req:
|
||||
mock_req.return_value = {"memory_count": 3}
|
||||
c.get_agent_model("hermes")
|
||||
mock_req.assert_called_once_with(
|
||||
"GET", "/v1/memory/agent/hermes/model",
|
||||
params={"project": "test"}, timeout=4.0
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _WriteQueue tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestWriteQueue:
|
||||
"""Test the SQLite-backed write queue with real SQLite."""
|
||||
|
||||
def _make_queue(self, tmp_path, client=None):
|
||||
if client is None:
|
||||
client = MagicMock()
|
||||
client.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
return _WriteQueue(client, db_path), client, db_path
|
||||
|
||||
def test_enqueue_creates_row(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
# Give the writer thread a moment to process
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# If ingest succeeded, the row should be deleted
|
||||
client.ingest_session.assert_called_once()
|
||||
|
||||
def test_enqueue_persists_to_sqlite(self, tmp_path):
|
||||
client = MagicMock()
|
||||
# Make ingest hang so the row stays in SQLite
|
||||
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
|
||||
# Check SQLite directly — row should exist since flush is slow
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall()
|
||||
conn.close()
|
||||
assert len(rows) >= 1
|
||||
assert rows[0][0] == "user1"
|
||||
q.shutdown()
|
||||
|
||||
def test_flush_deletes_row_on_success(self, tmp_path):
|
||||
q, client, db_path = self._make_queue(tmp_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(1)
|
||||
q.shutdown()
|
||||
# Row should be gone
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
|
||||
conn.close()
|
||||
assert rows == 0
|
||||
|
||||
def test_flush_records_error_on_failure(self, tmp_path):
|
||||
client = MagicMock()
|
||||
client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
|
||||
db_path = tmp_path / "test_queue.db"
|
||||
q = _WriteQueue(client, db_path)
|
||||
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
|
||||
time.sleep(3) # Allow retry + sleep(2) in _flush_row
|
||||
q.shutdown()
|
||||
# Row should still exist with error recorded
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
row = conn.execute("SELECT last_error FROM pending").fetchone()
|
||||
conn.close()
|
||||
assert row is not None
|
||||
assert "API down" in row[0]
|
||||
|
||||
def test_thread_local_connection_reuse(self, tmp_path):
|
||||
q, _, _ = self._make_queue(tmp_path)
|
||||
# Same thread should get same connection
|
||||
conn1 = q._get_conn()
|
||||
conn2 = q._get_conn()
|
||||
assert conn1 is conn2
|
||||
q.shutdown()
|
||||
|
||||
def test_crash_recovery_replays_pending(self, tmp_path):
|
||||
"""Simulate crash: create rows, then new queue should replay them."""
|
||||
db_path = tmp_path / "recovery_test.db"
|
||||
# First: create a queue and insert rows, but don't let them flush
|
||||
client1 = MagicMock()
|
||||
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
|
||||
q1 = _WriteQueue(client1, db_path)
|
||||
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
|
||||
time.sleep(3)
|
||||
q1.shutdown()
|
||||
|
||||
# Now create a new queue — it should replay the pending rows
|
||||
client2 = MagicMock()
|
||||
client2.ingest_session = MagicMock(return_value={"status": "ok"})
|
||||
q2 = _WriteQueue(client2, db_path)
|
||||
time.sleep(2)
|
||||
q2.shutdown()
|
||||
|
||||
# The replayed row should have been ingested via client2
|
||||
client2.ingest_session.assert_called_once()
|
||||
call_args = client2.ingest_session.call_args
|
||||
assert call_args[0][0] == "user1" # user_id
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _build_overlay tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestBuildOverlay:
|
||||
"""Test the overlay formatter (pure function)."""
|
||||
|
||||
def test_empty_inputs_returns_empty(self):
|
||||
assert _build_overlay({}, {}) == ""
|
||||
|
||||
def test_empty_memories_returns_empty(self):
|
||||
assert _build_overlay({"memories": []}, {"results": []}) == ""
|
||||
|
||||
def test_profile_items_included(self):
|
||||
profile = {"memories": [{"content": "User likes Python"}]}
|
||||
result = _build_overlay(profile, {})
|
||||
assert "User likes Python" in result
|
||||
assert "[RetainDB Context]" in result
|
||||
|
||||
def test_query_results_included(self):
|
||||
query_result = {"results": [{"content": "Previous discussion about Rust"}]}
|
||||
result = _build_overlay({}, query_result)
|
||||
assert "Previous discussion about Rust" in result
|
||||
|
||||
def test_deduplication_removes_duplicates(self):
|
||||
profile = {"memories": [{"content": "User likes Python"}]}
|
||||
query_result = {"results": [{"content": "User likes Python"}]}
|
||||
result = _build_overlay(profile, query_result)
|
||||
assert result.count("User likes Python") == 1
|
||||
|
||||
def test_local_entries_filter(self):
|
||||
profile = {"memories": [{"content": "Already known fact"}]}
|
||||
result = _build_overlay(profile, {}, local_entries=["Already known fact"])
|
||||
# The profile item matches a local entry, should be filtered
|
||||
assert result == ""
|
||||
|
||||
def test_max_five_items_per_section(self):
|
||||
profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
|
||||
result = _build_overlay(profile, {})
|
||||
# Should only include first 5
|
||||
assert "Fact 0" in result
|
||||
assert "Fact 4" in result
|
||||
assert "Fact 5" not in result
|
||||
|
||||
def test_none_content_handled(self):
|
||||
profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
|
||||
result = _build_overlay(profile, {})
|
||||
assert "Real fact" in result
|
||||
|
||||
def test_truncation_at_320_chars(self):
|
||||
long_content = "x" * 500
|
||||
profile = {"memories": [{"content": long_content}]}
|
||||
result = _build_overlay(profile, {})
|
||||
# Each item is compacted to 320 chars max
|
||||
for line in result.split("\n"):
|
||||
if line.startswith("- "):
|
||||
assert len(line) <= 322 # "- " + 320
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# RetainDBMemoryProvider tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestRetainDBMemoryProvider:
|
||||
"""Test the main plugin class."""
|
||||
|
||||
def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", api_key)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
(tmp_path / ".hermes").mkdir(exist_ok=True)
|
||||
provider = RetainDBMemoryProvider()
|
||||
return provider
|
||||
|
||||
def test_name(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.name == "retaindb"
|
||||
|
||||
def test_is_available_without_key(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.is_available() is False
|
||||
|
||||
def test_is_available_with_key(self, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
|
||||
p = RetainDBMemoryProvider()
|
||||
assert p.is_available() is True
|
||||
|
||||
def test_config_schema(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
schema = p.get_config_schema()
|
||||
assert len(schema) == 3
|
||||
keys = [s["key"] for s in schema]
|
||||
assert "api_key" in keys
|
||||
assert "base_url" in keys
|
||||
assert "project" in keys
|
||||
|
||||
def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client is not None
|
||||
assert p._queue is not None
|
||||
assert p._session_id == "test-session"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_default_project(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client.project == "default"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_explicit_project(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
assert p._client.project == "my-project"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_profile_project(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
profile_home = str(tmp_path / "profiles" / "coder")
|
||||
p.initialize("test-session", hermes_home=profile_home)
|
||||
assert p._client.project == "hermes-coder"
|
||||
p.shutdown()
|
||||
|
||||
def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
soul_path = tmp_path / ".hermes" / "SOUL.md"
|
||||
soul_path.write_text("I am a helpful agent.")
|
||||
with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
# Give thread time to start
|
||||
time.sleep(0.5)
|
||||
mock_seed.assert_called_once_with("I am a helpful agent.")
|
||||
p.shutdown()
|
||||
|
||||
def test_system_prompt_block(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
block = p.system_prompt_block()
|
||||
assert "RetainDB Memory" in block
|
||||
assert "Active" in block
|
||||
p.shutdown()
|
||||
|
||||
def test_tool_schemas_count(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
schemas = p.get_tool_schemas()
|
||||
assert len(schemas) == 10 # 5 memory + 5 file tools
|
||||
names = [s["name"] for s in schemas]
|
||||
assert "retaindb_profile" in names
|
||||
assert "retaindb_search" in names
|
||||
assert "retaindb_context" in names
|
||||
assert "retaindb_remember" in names
|
||||
assert "retaindb_forget" in names
|
||||
assert "retaindb_upload_file" in names
|
||||
assert "retaindb_list_files" in names
|
||||
assert "retaindb_read_file" in names
|
||||
assert "retaindb_ingest_file" in names
|
||||
assert "retaindb_delete_file" in names
|
||||
|
||||
def test_handle_tool_call_not_initialized(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "error" in result
|
||||
assert "not initialized" in result["error"]
|
||||
|
||||
def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
|
||||
assert result == {"error": "Unknown tool: retaindb_nonexistent"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_profile(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "get_profile", return_value={"memories": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "memories" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_search", {}))
|
||||
assert result == {"error": "query is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
|
||||
assert "results" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "search") as mock_search:
|
||||
mock_search.return_value = {"results": []}
|
||||
p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
|
||||
# top_k should be capped at 20
|
||||
assert mock_search.call_args[1]["top_k"] == 20
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_remember(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
|
||||
assert result["id"] == "mem-1"
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_remember", {}))
|
||||
assert result == {"error": "content is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_forget(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
|
||||
assert result["deleted"] is True
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_forget", {}))
|
||||
assert result == {"error": "memory_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_context(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
|
||||
patch.object(p._client, "get_profile", return_value={"memories": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
|
||||
assert "context" in result
|
||||
assert "raw" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_list(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "list_files", return_value={"files": []}):
|
||||
result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
|
||||
assert "files" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
|
||||
assert "error" in result
|
||||
|
||||
def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
|
||||
assert "File not found" in result["error"]
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
|
||||
assert result == {"error": "file_id is required"}
|
||||
p.shutdown()
|
||||
|
||||
def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
|
||||
p = self._make_provider(tmp_path, monkeypatch)
|
||||
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
|
||||
with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
|
||||
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
|
||||
assert "API exploded" in result["error"]
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Prefetch and thread management tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestPrefetch:
|
||||
"""Test background prefetch and thread accumulation prevention."""
|
||||
|
||||
def _make_initialized_provider(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
return p
|
||||
|
||||
def test_queue_prefetch_skips_without_client(self):
|
||||
p = RetainDBMemoryProvider()
|
||||
p.queue_prefetch("test") # Should not raise
|
||||
|
||||
def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
result = p.prefetch("test")
|
||||
assert result == ""
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
# Manually set the cached result
|
||||
with p._lock:
|
||||
p._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
|
||||
result = p.prefetch("test")
|
||||
assert "User likes tests" in result
|
||||
# Should be consumed
|
||||
assert p.prefetch("test") == ""
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._dialectic_result = "User is a software engineer who prefers Python."
|
||||
result = p.prefetch("test")
|
||||
assert "[RetainDB User Synthesis]" in result
|
||||
assert "software engineer" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._agent_model = {
|
||||
"memory_count": 5,
|
||||
"persona": "Helpful coding assistant",
|
||||
"persistent_instructions": ["Be concise", "Use Python"],
|
||||
"working_style": "Direct and efficient",
|
||||
}
|
||||
result = p.prefetch("test")
|
||||
assert "[RetainDB Agent Self-Model]" in result
|
||||
assert "Helpful coding assistant" in result
|
||||
assert "Be concise" in result
|
||||
assert "Direct and efficient" in result
|
||||
p.shutdown()
|
||||
|
||||
def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
with p._lock:
|
||||
p._agent_model = {"memory_count": 0}
|
||||
result = p.prefetch("test")
|
||||
assert "Agent Self-Model" not in result
|
||||
p.shutdown()
|
||||
|
||||
def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
|
||||
"""Verify old prefetch threads are joined before new ones spawn."""
|
||||
p = self._make_initialized_provider(tmp_path, monkeypatch)
|
||||
# Mock the prefetch methods to be slow
|
||||
with patch.object(p, "_prefetch_context", side_effect=lambda q: time.sleep(0.5)), \
|
||||
patch.object(p, "_prefetch_dialectic", side_effect=lambda q: time.sleep(0.5)), \
|
||||
patch.object(p, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)):
|
||||
p.queue_prefetch("query 1")
|
||||
first_threads = list(p._prefetch_threads)
|
||||
assert len(first_threads) == 3
|
||||
|
||||
# Call again — should join first batch before spawning new
|
||||
p.queue_prefetch("query 2")
|
||||
second_threads = list(p._prefetch_threads)
|
||||
assert len(second_threads) == 3
|
||||
# Should be different thread objects
|
||||
for t in second_threads:
|
||||
assert t not in first_threads
|
||||
p.shutdown()
|
||||
|
||||
def test_reasoning_level_short(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("hi") == "low"
|
||||
|
||||
def test_reasoning_level_medium(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium"
|
||||
|
||||
def test_reasoning_level_long(self):
|
||||
assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# sync_turn tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestSyncTurn:
|
||||
"""Test turn synchronization via the write queue."""
|
||||
|
||||
def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._queue, "enqueue") as mock_enqueue:
|
||||
p.sync_turn("user msg", "assistant msg")
|
||||
mock_enqueue.assert_called_once()
|
||||
args = mock_enqueue.call_args[0]
|
||||
assert args[0] == "default" # user_id
|
||||
assert args[1] == "test-session" # session_id
|
||||
msgs = args[2]
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0]["role"] == "user"
|
||||
assert msgs[1]["role"] == "assistant"
|
||||
p.shutdown()
|
||||
|
||||
def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._queue, "enqueue") as mock_enqueue:
|
||||
p.sync_turn("", "assistant msg")
|
||||
mock_enqueue.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# on_memory_write hook tests
|
||||
# ===========================================================================
|
||||
|
||||
class TestOnMemoryWrite:
|
||||
"""Test the built-in memory mirror hook."""
|
||||
|
||||
def test_mirrors_add_action(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
|
||||
p.on_memory_write("add", "user", "User prefers dark mode")
|
||||
mock_add.assert_called_once()
|
||||
assert mock_add.call_args[1]["memory_type"] == "preference"
|
||||
p.shutdown()
|
||||
|
||||
def test_skips_non_add_action(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory") as mock_add:
|
||||
p.on_memory_write("remove", "user", "something")
|
||||
mock_add.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
def test_skips_empty_content(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory") as mock_add:
|
||||
p.on_memory_write("add", "user", "")
|
||||
mock_add.assert_not_called()
|
||||
p.shutdown()
|
||||
|
||||
def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir(exist_ok=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
p = RetainDBMemoryProvider()
|
||||
p.initialize("test-session", hermes_home=str(hermes_home))
|
||||
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
|
||||
p.on_memory_write("add", "memory", "Some env fact")
|
||||
assert mock_add.call_args[1]["memory_type"] == "factual"
|
||||
p.shutdown()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# register() test
|
||||
# ===========================================================================
|
||||
|
||||
class TestRegister:
|
||||
def test_register_calls_register_memory_provider(self):
|
||||
from plugins.memory.retaindb import register
|
||||
ctx = MagicMock()
|
||||
register(ctx)
|
||||
ctx.register_memory_provider.assert_called_once()
|
||||
arg = ctx.register_memory_provider.call_args[0][0]
|
||||
assert isinstance(arg, RetainDBMemoryProvider)
|
||||
Reference in New Issue
Block a user