Compare commits

...

3 Commits

Author SHA1 Message Date
Teknium
ec6daae0e1 fix: follow-up improvements for salvaged PR #5456
- SQLite write queue: thread-local connection pooling instead of
  creating+closing a new connection per operation
- Prefetch threads: join previous batch before spawning new ones to
  prevent thread accumulation on rapid queue_prefetch() calls
- Shutdown: join prefetch threads before stopping write queue
- Add 73 tests covering _Client HTTP payloads, _WriteQueue crash
  recovery & connection reuse, _build_overlay deduplication,
  RetainDBMemoryProvider lifecycle/tools/prefetch/hooks, thread
  accumulation guard, and reasoning_level heuristic
2026-04-06 00:49:27 -07:00
Alinxus
b3e406635f fix(retaindb): make project optional, default to 'default' project 2026-04-06 00:44:09 -07:00
Alinxus
31764e175c fix(retaindb): fix API routes, add write queue, dialectic, agent model, file tools
The previous implementation hit endpoints that do not exist on the RetainDB
API (/v1/recall, /v1/ingest, /v1/remember, /v1/search, /v1/profile/:p/:u).
Every operation silently failed with a 404. This rewrites the plugin against
the real API surface and adds several new capabilities.

API route fixes:
- Context query: POST /v1/context/query (was /v1/recall)
- Session ingest: POST /v1/memory/ingest/session (was /v1/ingest)
- Memory write: POST /v1/memory with legacy fallback to /v1/memories (was /v1/remember)
- Memory search: POST /v1/memory/search (was /v1/search)
- User profile: GET /v1/memory/profile/:userId (was /v1/profile/:project/:userId)
- Memory delete: DELETE /v1/memory/:id with legacy fallback to /v1/memories/:id (was /v1/memory/:id against the wrong base)

Durable write-behind queue:
- SQLite spool at ~/.hermes/retaindb_queue.db
- Turn ingest is fully async — zero blocking on the hot path
- Pending rows replay automatically on restart after a crash
- Per-row error marking with retry backoff

Background prefetch (fires at turn-end, ready for next turn-start):
- Context: profile + semantic query, deduped overlay block
- Dialectic synthesis: LLM-powered synthesis of what is known about the
  user for the current query, with dynamic reasoning level based on
  message length (low / medium / high)
- Agent self-model: persona, persistent instructions, working style
  derived from AGENT-scoped memories
- All three run in parallel daemon threads, consumed atomically at
  turn-start within the prefetch timeout budget

Agent identity seeding:
- SOUL.md content ingested as AGENT-scoped memories on startup
- Enables persistent cross-session agent self-knowledge

Shared file store tools (new):
- retaindb_upload_file: upload local file, optional auto-ingest
- retaindb_list_files: directory listing with prefix filter
- retaindb_read_file: fetch and decode text content
- retaindb_ingest_file: chunk + embed + extract memories from stored file
- retaindb_delete_file: soft delete

Built-in memory mirror:
- on_memory_write() now hits the correct write endpoint
2026-04-06 00:44:09 -07:00
2 changed files with 1406 additions and 167 deletions

View File: plugins/memory/retaindb.py

@@ -1,29 +1,45 @@
"""RetainDB memory plugin — MemoryProvider interface.
Cross-session memory via RetainDB cloud API. Durable write-behind queue,
semantic search with deduplication, and user profile retrieval.
Cross-session memory via RetainDB cloud API.
Original PR #2732 by Alinxus, adapted to MemoryProvider ABC.
Features:
- Correct API routes for all operations
- Durable SQLite write-behind queue (crash-safe, async ingest)
- Semantic search + user profile retrieval
- Context query with deduplication overlay
- Dialectic synthesis (LLM-powered user understanding, prefetched each turn)
- Agent self-model (persona + instructions from SOUL.md, prefetched each turn)
- Shared file store tools (upload, list, read, ingest, delete)
- Explicit memory tools (profile, search, context, remember, forget)
Config via environment variables:
RETAINDB_API_KEY — API key (required)
RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
RETAINDB_PROJECT — Project identifier (default: hermes)
Config (env vars or hermes config.yaml under retaindb:):
RETAINDB_API_KEY — API key (required)
RETAINDB_BASE_URL — API endpoint (default: https://api.retaindb.com)
RETAINDB_PROJECT — Project identifier (optional — defaults to "default")
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
import queue
import re
import sqlite3
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import quote
from agent.memory_provider import MemoryProvider
logger = logging.getLogger(__name__)
_DEFAULT_BASE_URL = "https://api.retaindb.com"
_ASYNC_SHUTDOWN = object()
# ---------------------------------------------------------------------------
@@ -32,16 +48,13 @@ _DEFAULT_BASE_URL = "https://api.retaindb.com"
PROFILE_SCHEMA = {
"name": "retaindb_profile",
"description": "Get the user's stable profile — preferences, facts, and patterns.",
"description": "Get the user's stable profile — preferences, facts, and patterns recalled from long-term memory.",
"parameters": {"type": "object", "properties": {}, "required": []},
}
SEARCH_SCHEMA = {
"name": "retaindb_search",
"description": (
"Semantic search across stored memories. Returns ranked results "
"with relevance scores."
),
"description": "Semantic search across stored memories. Returns ranked results with relevance scores.",
"parameters": {
"type": "object",
"properties": {
@@ -54,7 +67,7 @@ SEARCH_SCHEMA = {
CONTEXT_SCHEMA = {
"name": "retaindb_context",
"description": "Synthesized 'what matters now' context block for the current task.",
"description": "Synthesized context block — what matters most for the current task, pulled from long-term memory.",
"parameters": {
"type": "object",
"properties": {
@@ -66,20 +79,17 @@ CONTEXT_SCHEMA = {
REMEMBER_SCHEMA = {
"name": "retaindb_remember",
"description": "Persist an explicit fact or preference to long-term memory.",
"description": "Persist an explicit fact, preference, or decision to long-term memory.",
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "The fact to remember."},
"memory_type": {
"type": "string",
"enum": ["preference", "fact", "decision", "context"],
"description": "Category (default: fact).",
},
"importance": {
"type": "number",
"description": "Importance 0-1 (default: 0.5).",
"enum": ["factual", "preference", "goal", "instruction", "event", "opinion"],
"description": "Category (default: factual).",
},
"importance": {"type": "number", "description": "Importance 0-1 (default: 0.7)."},
},
"required": ["content"],
},
@@ -97,23 +107,368 @@ FORGET_SCHEMA = {
},
}
FILE_UPLOAD_SCHEMA = {
"name": "retaindb_upload_file",
"description": "Upload a file to the shared RetainDB file store. Returns an rdb:// URI any agent can reference.",
"parameters": {
"type": "object",
"properties": {
"local_path": {"type": "string", "description": "Local file path to upload."},
"remote_path": {"type": "string", "description": "Destination path, e.g. /reports/q1.pdf"},
"scope": {"type": "string", "enum": ["USER", "PROJECT", "ORG"], "description": "Access scope (default: PROJECT)."},
"ingest": {"type": "boolean", "description": "Also extract memories from file after upload (default: false)."},
},
"required": ["local_path"],
},
}
FILE_LIST_SCHEMA = {
"name": "retaindb_list_files",
"description": "List files in the shared file store.",
"parameters": {
"type": "object",
"properties": {
"prefix": {"type": "string", "description": "Path prefix to filter by, e.g. /reports/"},
"limit": {"type": "integer", "description": "Max results (default: 50)."},
},
"required": [],
},
}
FILE_READ_SCHEMA = {
"name": "retaindb_read_file",
"description": "Read the text content of a stored file by its file ID.",
"parameters": {
"type": "object",
"properties": {
"file_id": {"type": "string", "description": "File ID returned from upload or list."},
},
"required": ["file_id"],
},
}
FILE_INGEST_SCHEMA = {
"name": "retaindb_ingest_file",
"description": "Chunk, embed, and extract memories from a stored file. Makes its contents searchable.",
"parameters": {
"type": "object",
"properties": {
"file_id": {"type": "string", "description": "File ID to ingest."},
},
"required": ["file_id"],
},
}
FILE_DELETE_SCHEMA = {
"name": "retaindb_delete_file",
"description": "Delete a stored file.",
"parameters": {
"type": "object",
"properties": {
"file_id": {"type": "string", "description": "File ID to delete."},
},
"required": ["file_id"],
},
}
# ---------------------------------------------------------------------------
# MemoryProvider implementation
# HTTP client
# ---------------------------------------------------------------------------
class _Client:
def __init__(self, api_key: str, base_url: str, project: str):
self.api_key = api_key
self.base_url = re.sub(r"/+$", "", base_url)
self.project = project
def _headers(self, path: str) -> dict:
token = self.api_key.replace("Bearer ", "").strip()
h = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
"x-sdk-runtime": "hermes-plugin",
}
if path.startswith("/v1/memory") or path.startswith("/v1/context"):
h["X-API-Key"] = token
return h
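# Resulting headers, sketched with a hypothetical key "rdb-k":
#   _headers("/v1/memory/search")  ->  Authorization: Bearer rdb-k, X-API-Key: rdb-k
#   _headers("/v1/files")          ->  Authorization: Bearer rdb-k  (no X-API-Key)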
def request(self, method: str, path: str, *, params=None, json_body=None, timeout: float = 8.0) -> Any:
import requests
url = f"{self.base_url}{path}"
resp = requests.request(
method.upper(), url,
params=params,
json=json_body if method.upper() not in {"GET", "DELETE"} else None,
headers=self._headers(path),
timeout=timeout,
)
try:
payload = resp.json()
except Exception:
payload = resp.text
if not resp.ok:
msg = ""
if isinstance(payload, dict):
msg = str(payload.get("message") or payload.get("error") or "")
raise RuntimeError(f"RetainDB {method} {path} failed ({resp.status_code}): {msg or payload}")
return payload
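# Failure sketch: non-2xx responses raise instead of failing silently.
# With a hypothetical server reply, a bad route surfaces as:
#   client.request("POST", "/v1/context/query", json_body={})
#   RuntimeError: RetainDB POST /v1/context/query failed (404): route not found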
# ── Memory ────────────────────────────────────────────────────────────────
def query_context(self, user_id: str, session_id: str, query: str, max_tokens: int = 1200) -> dict:
return self.request("POST", "/v1/context/query", json_body={
"project": self.project,
"query": query,
"user_id": user_id,
"session_id": session_id,
"include_memories": True,
"max_tokens": max_tokens,
})
def search(self, user_id: str, session_id: str, query: str, top_k: int = 8) -> dict:
return self.request("POST", "/v1/memory/search", json_body={
"project": self.project,
"query": query,
"user_id": user_id,
"session_id": session_id,
"top_k": top_k,
"include_pending": True,
})
def get_profile(self, user_id: str) -> dict:
try:
return self.request("GET", f"/v1/memory/profile/{quote(user_id, safe='')}", params={"project": self.project, "include_pending": "true"})
except Exception:
return self.request("GET", "/v1/memories", params={"project": self.project, "user_id": user_id, "limit": "200"})
def add_memory(self, user_id: str, session_id: str, content: str, memory_type: str = "factual", importance: float = 0.7) -> dict:
try:
return self.request("POST", "/v1/memory", json_body={
"project": self.project, "content": content, "memory_type": memory_type,
"user_id": user_id, "session_id": session_id, "importance": importance, "write_mode": "sync",
}, timeout=5.0)
except Exception:
return self.request("POST", "/v1/memories", json_body={
"project": self.project, "content": content, "memory_type": memory_type,
"user_id": user_id, "session_id": session_id, "importance": importance,
}, timeout=5.0)
def delete_memory(self, memory_id: str) -> dict:
try:
return self.request("DELETE", f"/v1/memory/{quote(memory_id, safe='')}", timeout=5.0)
except Exception:
return self.request("DELETE", f"/v1/memories/{quote(memory_id, safe='')}", timeout=5.0)
def ingest_session(self, user_id: str, session_id: str, messages: list, timeout: float = 15.0) -> dict:
return self.request("POST", "/v1/memory/ingest/session", json_body={
"project": self.project, "session_id": session_id, "user_id": user_id,
"messages": messages, "write_mode": "sync",
}, timeout=timeout)
def ask_user(self, user_id: str, query: str, reasoning_level: str = "low") -> dict:
return self.request("POST", f"/v1/memory/profile/{quote(user_id, safe='')}/ask", json_body={
"project": self.project, "query": query, "reasoning_level": reasoning_level,
}, timeout=8.0)
def get_agent_model(self, agent_id: str) -> dict:
return self.request("GET", f"/v1/memory/agent/{quote(agent_id, safe='')}/model", params={"project": self.project}, timeout=4.0)
def seed_agent_identity(self, agent_id: str, content: str, source: str = "soul_md") -> dict:
return self.request("POST", f"/v1/memory/agent/{quote(agent_id, safe='')}/seed", json_body={
"project": self.project, "content": content, "source": source,
}, timeout=20.0)
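# Usage sketch for the memory surface (all ids and content hypothetical):
#   client = _Client("rdb-xxxxxxxx", _DEFAULT_BASE_URL, "hermes")
#   client.add_memory("user-1", "sess-1", "Prefers Rust for CLIs", memory_type="preference")
#   hits = client.search("user-1", "sess-1", "language preferences", top_k=5)
#   ctx = client.query_context("user-1", "sess-1", "which language for the new tool?")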
# ── Files ─────────────────────────────────────────────────────────────────
def upload_file(self, data: bytes, filename: str, remote_path: str, mime_type: str, scope: str, project_id: str | None) -> dict:
import io
import requests
url = f"{self.base_url}/v1/files"
token = self.api_key.replace("Bearer ", "").strip()
headers = {"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}
fields = {"path": remote_path, "scope": scope.upper()}
if project_id:
fields["project_id"] = project_id
resp = requests.post(url, files={"file": (filename, io.BytesIO(data), mime_type)}, data=fields, headers=headers, timeout=30)
resp.raise_for_status()
return resp.json()
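# Response shape assumption: downstream code (handle_tool_call's upload branch)
# reads {"file": {"id": ..., "rdb_uri": ..., "name": ...}} from this payload
# when chaining into ingest_file.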
def list_files(self, prefix: str | None = None, limit: int = 50) -> dict:
params: dict = {"limit": limit}
if prefix:
params["prefix"] = prefix
return self.request("GET", "/v1/files", params=params)
def get_file(self, file_id: str) -> dict:
return self.request("GET", f"/v1/files/{quote(file_id, safe='')}")
def read_file_content(self, file_id: str) -> bytes:
import requests
token = self.api_key.replace("Bearer ", "").strip()
url = f"{self.base_url}/v1/files/{quote(file_id, safe='')}/content"
resp = requests.get(url, headers={"Authorization": f"Bearer {token}", "x-sdk-runtime": "hermes-plugin"}, timeout=30, allow_redirects=True)
resp.raise_for_status()
return resp.content
def ingest_file(self, file_id: str, user_id: str | None = None, agent_id: str | None = None) -> dict:
body: dict = {}
if user_id:
body["user_id"] = user_id
if agent_id:
body["agent_id"] = agent_id
return self.request("POST", f"/v1/files/{quote(file_id, safe='')}/ingest", json_body=body, timeout=60.0)
def delete_file(self, file_id: str) -> dict:
return self.request("DELETE", f"/v1/files/{quote(file_id, safe='')}", timeout=5.0)
# ---------------------------------------------------------------------------
# Durable write-behind queue
# ---------------------------------------------------------------------------
class _WriteQueue:
"""SQLite-backed async write queue. Survives crashes — pending rows replay on startup."""
def __init__(self, client: _Client, db_path: Path):
self._client = client
self._db_path = db_path
self._q: queue.Queue = queue.Queue()
self._thread = threading.Thread(target=self._loop, name="retaindb-writer", daemon=True)
self._db_path.parent.mkdir(parents=True, exist_ok=True)
# Thread-local connection cache — one connection per thread, reused.
self._local = threading.local()
self._init_db()
self._thread.start()
# Replay any rows left from a previous crash
for row_id, user_id, session_id, msgs_json in self._pending_rows():
self._q.put((row_id, user_id, session_id, json.loads(msgs_json)))
def _get_conn(self) -> sqlite3.Connection:
"""Return a cached connection for the current thread."""
conn = getattr(self._local, "conn", None)
if conn is None:
conn = sqlite3.connect(str(self._db_path), timeout=30)
conn.row_factory = sqlite3.Row
self._local.conn = conn
return conn
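# Note: sqlite3 connections default to check_same_thread=True, so the writer
# thread and enqueue() callers each need their own handle; caching one per
# thread also avoids the old per-operation connect/close churn.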
def _init_db(self) -> None:
conn = self._get_conn()
conn.execute("""CREATE TABLE IF NOT EXISTS pending (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT, session_id TEXT, messages_json TEXT,
created_at TEXT, last_error TEXT
)""")
conn.commit()
def _pending_rows(self) -> list:
conn = self._get_conn()
return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall()
def enqueue(self, user_id: str, session_id: str, messages: list) -> None:
now = datetime.now(timezone.utc).isoformat()
conn = self._get_conn()
cur = conn.execute(
"INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)",
(user_id, session_id, json.dumps(messages, ensure_ascii=False), now),
)
row_id = cur.lastrowid
conn.commit()
self._q.put((row_id, user_id, session_id, messages))
def _flush_row(self, row_id: int, user_id: str, session_id: str, messages: list) -> None:
try:
self._client.ingest_session(user_id, session_id, messages)
conn = self._get_conn()
conn.execute("DELETE FROM pending WHERE id = ?", (row_id,))
conn.commit()
except Exception as exc:
logger.warning("RetainDB ingest failed (will retry): %s", exc)
conn = self._get_conn()
conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id))
conn.commit()
time.sleep(2)
def _loop(self) -> None:
while True:
try:
item = self._q.get(timeout=5)
if item is _ASYNC_SHUTDOWN:
break
self._flush_row(*item)
except queue.Empty:
continue
except Exception as exc:
logger.error("RetainDB writer error: %s", exc)
def shutdown(self) -> None:
self._q.put(_ASYNC_SHUTDOWN)
self._thread.join(timeout=10)
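# Lifecycle sketch (hypothetical values): enqueue() persists the row before
# handing it to the writer thread, so a crash mid-flight loses nothing and
# the next _WriteQueue(...) replays whatever is still pending.
#   wq = _WriteQueue(client, Path.home() / ".hermes" / "retaindb_queue.db")
#   wq.enqueue("user-1", "sess-1", [{"role": "user", "content": "hi"}])
#   wq.shutdown()  # stops the writer; unsent rows stay in the pending table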
# ---------------------------------------------------------------------------
# Overlay formatter
# ---------------------------------------------------------------------------
def _build_overlay(profile: dict, query_result: dict, local_entries: list[str] | None = None) -> str:
def _compact(s: str) -> str:
return re.sub(r"\s+", " ", str(s or "")).strip()[:320]
def _norm(s: str) -> str:
return re.sub(r"[^a-z0-9 ]", "", _compact(s).lower())
seen: list[str] = [_norm(e) for e in (local_entries or []) if _norm(e)]
profile_items: list[str] = []
for m in list((profile or {}).get("memories") or [])[:5]:
c = _compact((m or {}).get("content") or "")
n = _norm(c)
if c and n not in seen:
seen.append(n)
profile_items.append(c)
query_items: list[str] = []
for r in list((query_result or {}).get("results") or [])[:5]:
c = _compact((r or {}).get("content") or "")
n = _norm(c)
if c and n not in seen:
seen.append(n)
query_items.append(c)
if not profile_items and not query_items:
return ""
lines = ["[RetainDB Context]", "Profile:"]
lines += [f"- {i}" for i in profile_items] or ["- None"]
lines.append("Relevant memories:")
lines += [f"- {i}" for i in query_items] or ["- None"]
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main plugin class
# ---------------------------------------------------------------------------
class RetainDBMemoryProvider(MemoryProvider):
"""RetainDB cloud memory with write-behind queue and semantic search."""
"""RetainDB cloud memory — durable queue, semantic search, dialectic synthesis, shared files."""
def __init__(self):
self._api_key = ""
self._base_url = _DEFAULT_BASE_URL
self._project = "hermes"
self._user_id = ""
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread = None
self._sync_thread = None
self._client: _Client | None = None
self._queue: _WriteQueue | None = None
self._user_id = "default"
self._session_id = ""
self._agent_id = "hermes"
self._lock = threading.Lock()
# Prefetch caches
self._context_result = ""
self._dialectic_result = ""
self._agent_model: dict = {}
# Prefetch thread tracking — prevents accumulation on rapid calls
self._prefetch_threads: list[threading.Thread] = []
# ── Core identity ──────────────────────────────────────────────────────
@property
def name(self) -> str:
@@ -122,179 +477,287 @@ class RetainDBMemoryProvider(MemoryProvider):
def is_available(self) -> bool:
return bool(os.environ.get("RETAINDB_API_KEY"))
def get_config_schema(self):
def get_config_schema(self) -> List[Dict[str, Any]]:
return [
{"key": "api_key", "description": "RetainDB API key", "secret": True, "required": True, "env_var": "RETAINDB_API_KEY", "url": "https://retaindb.com"},
{"key": "base_url", "description": "API endpoint", "default": "https://api.retaindb.com"},
{"key": "project", "description": "Project identifier", "default": "hermes"},
{"key": "base_url", "description": "API endpoint", "default": _DEFAULT_BASE_URL},
{"key": "project", "description": "Project identifier (optional — uses 'default' project if not set)", "default": ""},
]
def _headers(self) -> dict:
return {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
def _api(self, method: str, path: str, **kwargs):
"""Make an API call to RetainDB."""
import requests
url = f"{self._base_url}{path}"
resp = requests.request(method, url, headers=self._headers(), timeout=30, **kwargs)
resp.raise_for_status()
return resp.json()
# ── Lifecycle ──────────────────────────────────────────────────────────
def initialize(self, session_id: str, **kwargs) -> None:
self._api_key = os.environ.get("RETAINDB_API_KEY", "")
self._base_url = os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL)
self._user_id = kwargs.get("user_id", "default")
self._session_id = session_id
api_key = os.environ.get("RETAINDB_API_KEY", "")
base_url = re.sub(r"/+$", "", os.environ.get("RETAINDB_BASE_URL", _DEFAULT_BASE_URL))
# Derive profile-scoped project name so different profiles don't
# share server-side memory. Explicit RETAINDB_PROJECT always wins.
explicit_project = os.environ.get("RETAINDB_PROJECT")
if explicit_project:
self._project = explicit_project
# Project resolution: RETAINDB_PROJECT > hermes-<profile> > "default"
# If unset, the API auto-creates and uses the "default" project — no config required.
explicit = os.environ.get("RETAINDB_PROJECT")
if explicit:
project = explicit
else:
hermes_home = kwargs.get("hermes_home", "")
hermes_home = str(kwargs.get("hermes_home", ""))
profile_name = os.path.basename(hermes_home) if hermes_home else ""
# Default profile (~/.hermes) → "hermes"; named profiles → "hermes-<name>"
if profile_name and profile_name != ".hermes":
self._project = f"hermes-{profile_name}"
else:
self._project = "hermes"
project = f"hermes-{profile_name}" if (profile_name and profile_name not in {"", ".hermes"}) else "default"
self._client = _Client(api_key, base_url, project)
self._session_id = session_id
self._user_id = kwargs.get("user_id", "default") or "default"
self._agent_id = kwargs.get("agent_id", "hermes") or "hermes"
hermes_home_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
db_path = hermes_home_path / "retaindb_queue.db"
self._queue = _WriteQueue(self._client, db_path)
# Seed agent identity from SOUL.md in background
soul_path = hermes_home_path / "SOUL.md"
if soul_path.exists():
soul_content = soul_path.read_text(encoding="utf-8", errors="replace").strip()
if soul_content:
threading.Thread(
target=self._seed_soul,
args=(soul_content,),
name="retaindb-soul-seed",
daemon=True,
).start()
def _seed_soul(self, content: str) -> None:
try:
self._client.seed_agent_identity(self._agent_id, content, source="soul_md")
except Exception as exc:
logger.debug("RetainDB soul seed failed: %s", exc)
def system_prompt_block(self) -> str:
project = self._client.project if self._client else "retaindb"
return (
"# RetainDB Memory\n"
f"Active. Project: {self._project}.\n"
f"Active. Project: {project}.\n"
"Use retaindb_search to find memories, retaindb_remember to store facts, "
"retaindb_profile for a user overview, retaindb_context for task-relevant context."
"retaindb_profile for a user overview, retaindb_context for current-task context."
)
def prefetch(self, query: str, *, session_id: str = "") -> str:
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:
result = self._prefetch_result
self._prefetch_result = ""
if not result:
return ""
return f"## RetainDB Memory\n{result}"
# ── Background prefetch (fires at turn-end, consumed next turn-start) ──
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
def _run():
try:
data = self._api("POST", "/v1/recall", json={
"project": self._project,
"query": query,
"user_id": self._user_id,
"top_k": 5,
})
results = data.get("results", [])
if results:
lines = [r.get("content", "") for r in results if r.get("content")]
with self._prefetch_lock:
self._prefetch_result = "\n".join(f"- {l}" for l in lines)
except Exception as e:
logger.debug("RetainDB prefetch failed: %s", e)
"""Fire context + dialectic + agent model prefetches in background."""
if not self._client:
return
# Wait for any still-running prefetch threads before spawning new ones.
# Prevents thread accumulation if turns fire faster than prefetches complete.
for t in self._prefetch_threads:
t.join(timeout=2.0)
threads = [
threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True),
threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True),
threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True),
]
self._prefetch_threads = threads
for t in threads:
t.start()
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="retaindb-prefetch")
self._prefetch_thread.start()
def _prefetch_context(self, query: str) -> None:
try:
query_result = self._client.query_context(self._user_id, self._session_id, query)
profile = self._client.get_profile(self._user_id)
overlay = _build_overlay(profile, query_result)
with self._lock:
self._context_result = overlay
except Exception as exc:
logger.debug("RetainDB context prefetch failed: %s", exc)
def _prefetch_dialectic(self, query: str) -> None:
try:
result = self._client.ask_user(self._user_id, query, reasoning_level=self._reasoning_level(query))
answer = str(result.get("answer") or "")
if answer:
with self._lock:
self._dialectic_result = answer
except Exception as exc:
logger.debug("RetainDB dialectic prefetch failed: %s", exc)
def _prefetch_agent_model(self) -> None:
try:
model = self._client.get_agent_model(self._agent_id)
if model.get("memory_count", 0) > 0:
with self._lock:
self._agent_model = model
except Exception as exc:
logger.debug("RetainDB agent model prefetch failed: %s", exc)
@staticmethod
def _reasoning_level(query: str) -> str:
n = len(query)
if n < 120:
return "low"
if n < 400:
return "medium"
return "high"
def prefetch(self, query: str, *, session_id: str = "") -> str:
"""Consume prefetched results and return them as a context block."""
with self._lock:
context = self._context_result
dialectic = self._dialectic_result
agent_model = self._agent_model
self._context_result = ""
self._dialectic_result = ""
self._agent_model = {}
parts: list[str] = []
if context:
parts.append(context)
if dialectic:
parts.append(f"[RetainDB User Synthesis]\n{dialectic}")
if agent_model and agent_model.get("memory_count", 0) > 0:
model_lines: list[str] = []
if agent_model.get("persona"):
model_lines.append(f"Persona: {agent_model['persona']}")
if agent_model.get("persistent_instructions"):
model_lines.append("Instructions:\n" + "\n".join(f"- {i}" for i in agent_model["persistent_instructions"]))
if agent_model.get("working_style"):
model_lines.append(f"Working style: {agent_model['working_style']}")
if model_lines:
parts.append("[RetainDB Agent Self-Model]\n" + "\n".join(model_lines))
return "\n\n".join(parts)
# ── Turn sync ──────────────────────────────────────────────────────────
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Ingest conversation turn in background (non-blocking)."""
def _sync():
try:
self._api("POST", "/v1/ingest", json={
"project": self._project,
"user_id": self._user_id,
"session_id": self._session_id,
"messages": [
{"role": "user", "content": user_content},
{"role": "assistant", "content": assistant_content},
],
})
except Exception as e:
logger.warning("RetainDB sync failed: %s", e)
"""Queue turn for async ingest. Returns immediately."""
if not self._queue or not user_content:
return
now = datetime.now(timezone.utc).isoformat()
self._queue.enqueue(
self._user_id,
session_id or self._session_id,
[
{"role": "user", "content": user_content, "timestamp": now},
{"role": "assistant", "content": assistant_content, "timestamp": now},
],
)
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="retaindb-sync")
self._sync_thread.start()
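# Each queued turn is two timestamped messages (sketch, placeholder values):
#   [{"role": "user", "content": "...", "timestamp": "<iso-8601 utc>"},
#    {"role": "assistant", "content": "...", "timestamp": "<iso-8601 utc>"}]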
# ── Tools ──────────────────────────────────────────────────────────────
def get_tool_schemas(self) -> List[Dict[str, Any]]:
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA, REMEMBER_SCHEMA, FORGET_SCHEMA]
return [
PROFILE_SCHEMA, SEARCH_SCHEMA, CONTEXT_SCHEMA,
REMEMBER_SCHEMA, FORGET_SCHEMA,
FILE_UPLOAD_SCHEMA, FILE_LIST_SCHEMA, FILE_READ_SCHEMA,
FILE_INGEST_SCHEMA, FILE_DELETE_SCHEMA,
]
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
if not self._client:
return json.dumps({"error": "RetainDB not initialized"})
try:
if tool_name == "retaindb_profile":
data = self._api("GET", f"/v1/profile/{self._project}/{self._user_id}")
return json.dumps(data)
return json.dumps(self._dispatch(tool_name, args))
except Exception as exc:
return json.dumps({"error": str(exc)})
elif tool_name == "retaindb_search":
query = args.get("query", "")
if not query:
return json.dumps({"error": "query is required"})
data = self._api("POST", "/v1/search", json={
"project": self._project,
"user_id": self._user_id,
"query": query,
"top_k": min(int(args.get("top_k", 8)), 20),
})
return json.dumps(data)
def _dispatch(self, tool_name: str, args: dict) -> Any:
c = self._client
elif tool_name == "retaindb_context":
query = args.get("query", "")
if not query:
return json.dumps({"error": "query is required"})
data = self._api("POST", "/v1/recall", json={
"project": self._project,
"user_id": self._user_id,
"query": query,
"top_k": 5,
})
return json.dumps(data)
if tool_name == "retaindb_profile":
return c.get_profile(self._user_id)
elif tool_name == "retaindb_remember":
content = args.get("content", "")
if not content:
return json.dumps({"error": "content is required"})
data = self._api("POST", "/v1/remember", json={
"project": self._project,
"user_id": self._user_id,
"content": content,
"memory_type": args.get("memory_type", "fact"),
"importance": float(args.get("importance", 0.5)),
})
return json.dumps(data)
if tool_name == "retaindb_search":
query = args.get("query", "")
if not query:
return {"error": "query is required"}
return c.search(self._user_id, self._session_id, query, top_k=min(int(args.get("top_k", 8)), 20))
elif tool_name == "retaindb_forget":
memory_id = args.get("memory_id", "")
if not memory_id:
return json.dumps({"error": "memory_id is required"})
data = self._api("DELETE", f"/v1/memory/{memory_id}")
return json.dumps(data)
if tool_name == "retaindb_context":
query = args.get("query", "")
if not query:
return {"error": "query is required"}
query_result = c.query_context(self._user_id, self._session_id, query)
profile = c.get_profile(self._user_id)
overlay = _build_overlay(profile, query_result)
return {"context": overlay, "raw": query_result}
return json.dumps({"error": f"Unknown tool: {tool_name}"})
except Exception as e:
return json.dumps({"error": str(e)})
if tool_name == "retaindb_remember":
content = args.get("content", "")
if not content:
return {"error": "content is required"}
return c.add_memory(
self._user_id, self._session_id, content,
memory_type=args.get("memory_type", "factual"),
importance=float(args.get("importance", 0.7)),
)
if tool_name == "retaindb_forget":
memory_id = args.get("memory_id", "")
if not memory_id:
return {"error": "memory_id is required"}
return c.delete_memory(memory_id)
# ── File tools ──────────────────────────────────────────────────────
if tool_name == "retaindb_upload_file":
local_path = args.get("local_path", "")
if not local_path:
return {"error": "local_path is required"}
path_obj = Path(local_path)
if not path_obj.exists():
return {"error": f"File not found: {local_path}"}
data = path_obj.read_bytes()
import mimetypes
mime = mimetypes.guess_type(path_obj.name)[0] or "application/octet-stream"
remote_path = args.get("remote_path") or f"/{path_obj.name}"
result = c.upload_file(data, path_obj.name, remote_path, mime, args.get("scope", "PROJECT"), None)
if args.get("ingest") and result.get("file", {}).get("id"):
ingest = c.ingest_file(result["file"]["id"], user_id=self._user_id, agent_id=self._agent_id)
result["ingest"] = ingest
return result
if tool_name == "retaindb_list_files":
return c.list_files(prefix=args.get("prefix"), limit=int(args.get("limit", 50)))
if tool_name == "retaindb_read_file":
file_id = args.get("file_id", "")
if not file_id:
return {"error": "file_id is required"}
meta = c.get_file(file_id)
file_info = meta.get("file") or {}
mime = (file_info.get("mime_type") or "").lower()
raw = c.read_file_content(file_id)
if not (mime.startswith("text/") or any(file_info.get("name", "").endswith(e) for e in (".txt", ".md", ".json", ".csv", ".yaml", ".yml", ".xml", ".html"))):
return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": None, "note": "Binary file — use retaindb_ingest_file to extract text into memory."}
text = raw.decode("utf-8", errors="replace")
return {"file_id": file_id, "rdb_uri": file_info.get("rdb_uri"), "name": file_info.get("name"), "content": text[:32000], "truncated": len(text) > 32000}
if tool_name == "retaindb_ingest_file":
file_id = args.get("file_id", "")
if not file_id:
return {"error": "file_id is required"}
return c.ingest_file(file_id, user_id=self._user_id, agent_id=self._agent_id)
if tool_name == "retaindb_delete_file":
file_id = args.get("file_id", "")
if not file_id:
return {"error": "file_id is required"}
return c.delete_file(file_id)
return {"error": f"Unknown tool: {tool_name}"}
# ── Optional hooks ─────────────────────────────────────────────────────
def on_memory_write(self, action: str, target: str, content: str) -> None:
if action == "add":
try:
self._api("POST", "/v1/remember", json={
"project": self._project,
"user_id": self._user_id,
"content": content,
"memory_type": "preference" if target == "user" else "fact",
})
except Exception as e:
logger.debug("RetainDB memory bridge failed: %s", e)
"""Mirror built-in memory writes to RetainDB."""
if action != "add" or not content or not self._client:
return
try:
memory_type = "preference" if target == "user" else "factual"
self._client.add_memory(self._user_id, self._session_id, content, memory_type=memory_type)
except Exception as exc:
logger.debug("RetainDB memory mirror failed: %s", exc)
def shutdown(self) -> None:
for t in (self._prefetch_thread, self._sync_thread):
if t and t.is_alive():
t.join(timeout=5.0)
for t in self._prefetch_threads:
t.join(timeout=3.0)
if self._queue:
self._queue.shutdown()
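# Ordering note: prefetch threads are joined before the queue stops so no
# late prefetch can race the writer during teardown.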
def register(ctx) -> None:

View File

@@ -0,0 +1,776 @@
"""Tests for the RetainDB memory plugin.
Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter,
RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling.
"""
import json
import os
import sqlite3
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
# ---------------------------------------------------------------------------
# Imports: plugins/memory lives outside the standard test path, so the repo root is put on sys.path first
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _isolate_env(tmp_path, monkeypatch):
"""Ensure HERMES_HOME and RETAINDB vars are isolated."""
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("RETAINDB_API_KEY", raising=False)
monkeypatch.delenv("RETAINDB_BASE_URL", raising=False)
monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
# We need the repo root on sys.path so the plugin can import agent.memory_provider
import sys
_repo_root = str(Path(__file__).resolve().parents[2])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from plugins.memory.retaindb import (
_Client,
_WriteQueue,
_build_overlay,
RetainDBMemoryProvider,
_ASYNC_SHUTDOWN,
_DEFAULT_BASE_URL,
)
# ===========================================================================
# _Client tests
# ===========================================================================
class TestClient:
"""Test the HTTP client with mocked requests."""
def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
return _Client(api_key, base_url, project)
def test_base_url_trailing_slash_stripped(self):
c = self._make_client(base_url="https://api.retaindb.com///")
assert c.base_url == "https://api.retaindb.com"
def test_headers_include_auth(self):
c = self._make_client()
h = c._headers("/v1/files")
assert h["Authorization"] == "Bearer rdb-test-key"
assert "X-API-Key" not in h
def test_headers_include_api_key_for_memory_path(self):
c = self._make_client()
h = c._headers("/v1/memory/search")
assert h["X-API-Key"] == "rdb-test-key"
def test_headers_include_api_key_for_context_path(self):
c = self._make_client()
h = c._headers("/v1/context/query")
assert h["X-API-Key"] == "rdb-test-key"
def test_headers_strip_bearer_prefix(self):
c = self._make_client(api_key="Bearer rdb-test-key")
h = c._headers("/v1/memory/search")
assert h["Authorization"] == "Bearer rdb-test-key"
assert h["X-API-Key"] == "rdb-test-key"
def test_query_context_builds_correct_payload(self):
c = self._make_client()
with patch.object(c, "request") as mock_req:
mock_req.return_value = {"results": []}
c.query_context("user1", "sess1", "test query", max_tokens=500)
mock_req.assert_called_once_with("POST", "/v1/context/query", json_body={
"project": "test",
"query": "test query",
"user_id": "user1",
"session_id": "sess1",
"include_memories": True,
"max_tokens": 500,
})
def test_search_builds_correct_payload(self):
c = self._make_client()
with patch.object(c, "request") as mock_req:
mock_req.return_value = {"results": []}
c.search("user1", "sess1", "find this", top_k=5)
mock_req.assert_called_once_with("POST", "/v1/memory/search", json_body={
"project": "test",
"query": "find this",
"user_id": "user1",
"session_id": "sess1",
"top_k": 5,
"include_pending": True,
})
def test_add_memory_tries_fallback(self):
c = self._make_client()
call_count = 0
def fake_request(method, path, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("404")
return {"id": "mem-1"}
with patch.object(c, "request", side_effect=fake_request):
result = c.add_memory("u1", "s1", "test fact")
assert result == {"id": "mem-1"}
assert call_count == 2
def test_delete_memory_tries_fallback(self):
c = self._make_client()
call_count = 0
def fake_request(method, path, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise RuntimeError("404")
return {"deleted": True}
with patch.object(c, "request", side_effect=fake_request):
result = c.delete_memory("mem-123")
assert result == {"deleted": True}
assert call_count == 2
def test_ingest_session_payload(self):
c = self._make_client()
with patch.object(c, "request") as mock_req:
mock_req.return_value = {"status": "ok"}
msgs = [{"role": "user", "content": "hi"}]
c.ingest_session("u1", "s1", msgs, timeout=10.0)
mock_req.assert_called_once_with("POST", "/v1/memory/ingest/session", json_body={
"project": "test",
"session_id": "s1",
"user_id": "u1",
"messages": msgs,
"write_mode": "sync",
}, timeout=10.0)
def test_ask_user_payload(self):
c = self._make_client()
with patch.object(c, "request") as mock_req:
mock_req.return_value = {"answer": "test answer"}
c.ask_user("u1", "who am i?", reasoning_level="medium")
mock_req.assert_called_once()
call_kwargs = mock_req.call_args
assert call_kwargs[1]["json_body"]["reasoning_level"] == "medium"
def test_get_agent_model_path(self):
c = self._make_client()
with patch.object(c, "request") as mock_req:
mock_req.return_value = {"memory_count": 3}
c.get_agent_model("hermes")
mock_req.assert_called_once_with(
"GET", "/v1/memory/agent/hermes/model",
params={"project": "test"}, timeout=4.0
)
# ===========================================================================
# _WriteQueue tests
# ===========================================================================
class TestWriteQueue:
"""Test the SQLite-backed write queue with real SQLite."""
def _make_queue(self, tmp_path, client=None):
if client is None:
client = MagicMock()
client.ingest_session = MagicMock(return_value={"status": "ok"})
db_path = tmp_path / "test_queue.db"
return _WriteQueue(client, db_path), client, db_path
def test_enqueue_creates_row(self, tmp_path):
q, client, db_path = self._make_queue(tmp_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
# Give the writer thread a moment to process
time.sleep(1)
q.shutdown()
# If ingest succeeded, the row should be deleted
client.ingest_session.assert_called_once()
def test_enqueue_persists_to_sqlite(self, tmp_path):
client = MagicMock()
# Make ingest hang so the row stays in SQLite
client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
db_path = tmp_path / "test_queue.db"
q = _WriteQueue(client, db_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
# Check SQLite directly — row should exist since flush is slow
conn = sqlite3.connect(str(db_path))
rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall()
conn.close()
assert len(rows) >= 1
assert rows[0][0] == "user1"
q.shutdown()
def test_flush_deletes_row_on_success(self, tmp_path):
q, client, db_path = self._make_queue(tmp_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
time.sleep(1)
q.shutdown()
# Row should be gone
conn = sqlite3.connect(str(db_path))
rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
conn.close()
assert rows == 0
def test_flush_records_error_on_failure(self, tmp_path):
client = MagicMock()
client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
db_path = tmp_path / "test_queue.db"
q = _WriteQueue(client, db_path)
q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
time.sleep(3)  # Allow the failed flush and its sleep(2) backoff in _flush_row
q.shutdown()
# Row should still exist with error recorded
conn = sqlite3.connect(str(db_path))
row = conn.execute("SELECT last_error FROM pending").fetchone()
conn.close()
assert row is not None
assert "API down" in row[0]
def test_thread_local_connection_reuse(self, tmp_path):
q, _, _ = self._make_queue(tmp_path)
# Same thread should get same connection
conn1 = q._get_conn()
conn2 = q._get_conn()
assert conn1 is conn2
q.shutdown()
def test_crash_recovery_replays_pending(self, tmp_path):
"""Simulate crash: create rows, then new queue should replay them."""
db_path = tmp_path / "recovery_test.db"
# First: create a queue and insert rows, but don't let them flush
client1 = MagicMock()
client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
q1 = _WriteQueue(client1, db_path)
q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
time.sleep(3)
q1.shutdown()
# Now create a new queue — it should replay the pending rows
client2 = MagicMock()
client2.ingest_session = MagicMock(return_value={"status": "ok"})
q2 = _WriteQueue(client2, db_path)
time.sleep(2)
q2.shutdown()
# The replayed row should have been ingested via client2
client2.ingest_session.assert_called_once()
call_args = client2.ingest_session.call_args
assert call_args[0][0] == "user1" # user_id
# ===========================================================================
# _build_overlay tests
# ===========================================================================
class TestBuildOverlay:
"""Test the overlay formatter (pure function)."""
def test_empty_inputs_returns_empty(self):
assert _build_overlay({}, {}) == ""
def test_empty_memories_returns_empty(self):
assert _build_overlay({"memories": []}, {"results": []}) == ""
def test_profile_items_included(self):
profile = {"memories": [{"content": "User likes Python"}]}
result = _build_overlay(profile, {})
assert "User likes Python" in result
assert "[RetainDB Context]" in result
def test_query_results_included(self):
query_result = {"results": [{"content": "Previous discussion about Rust"}]}
result = _build_overlay({}, query_result)
assert "Previous discussion about Rust" in result
def test_deduplication_removes_duplicates(self):
profile = {"memories": [{"content": "User likes Python"}]}
query_result = {"results": [{"content": "User likes Python"}]}
result = _build_overlay(profile, query_result)
assert result.count("User likes Python") == 1
def test_local_entries_filter(self):
profile = {"memories": [{"content": "Already known fact"}]}
result = _build_overlay(profile, {}, local_entries=["Already known fact"])
# The profile item matches a local entry, so it should be filtered out
assert result == ""
def test_max_five_items_per_section(self):
profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
result = _build_overlay(profile, {})
# Should only include first 5
assert "Fact 0" in result
assert "Fact 4" in result
assert "Fact 5" not in result
def test_none_content_handled(self):
profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
result = _build_overlay(profile, {})
assert "Real fact" in result
def test_truncation_at_320_chars(self):
long_content = "x" * 500
profile = {"memories": [{"content": long_content}]}
result = _build_overlay(profile, {})
# Each item is compacted to 320 chars max
for line in result.split("\n"):
if line.startswith("- "):
assert len(line) <= 322 # "- " + 320
# ===========================================================================
# RetainDBMemoryProvider tests
# ===========================================================================
class TestRetainDBMemoryProvider:
"""Test the main plugin class."""
def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
monkeypatch.setenv("RETAINDB_API_KEY", api_key)
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
(tmp_path / ".hermes").mkdir(exist_ok=True)
provider = RetainDBMemoryProvider()
return provider
def test_name(self):
p = RetainDBMemoryProvider()
assert p.name == "retaindb"
def test_is_available_without_key(self):
p = RetainDBMemoryProvider()
assert p.is_available() is False
def test_is_available_with_key(self, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
p = RetainDBMemoryProvider()
assert p.is_available() is True
def test_config_schema(self):
p = RetainDBMemoryProvider()
schema = p.get_config_schema()
assert len(schema) == 3
keys = [s["key"] for s in schema]
assert "api_key" in keys
assert "base_url" in keys
assert "project" in keys
def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
assert p._client is not None
assert p._queue is not None
assert p._session_id == "test-session"
p.shutdown()
def test_initialize_default_project(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
assert p._client.project == "default"
p.shutdown()
def test_initialize_explicit_project(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
assert p._client.project == "my-project"
p.shutdown()
def test_initialize_profile_project(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
profile_home = str(tmp_path / "profiles" / "coder")
p.initialize("test-session", hermes_home=profile_home)
assert p._client.project == "hermes-coder"
p.shutdown()
def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
soul_path = tmp_path / ".hermes" / "SOUL.md"
soul_path.write_text("I am a helpful agent.")
with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
# Give thread time to start
time.sleep(0.5)
mock_seed.assert_called_once_with("I am a helpful agent.")
p.shutdown()
def test_system_prompt_block(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
block = p.system_prompt_block()
assert "RetainDB Memory" in block
assert "Active" in block
p.shutdown()
def test_tool_schemas_count(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
schemas = p.get_tool_schemas()
assert len(schemas) == 10 # 5 memory + 5 file tools
names = [s["name"] for s in schemas]
assert "retaindb_profile" in names
assert "retaindb_search" in names
assert "retaindb_context" in names
assert "retaindb_remember" in names
assert "retaindb_forget" in names
assert "retaindb_upload_file" in names
assert "retaindb_list_files" in names
assert "retaindb_read_file" in names
assert "retaindb_ingest_file" in names
assert "retaindb_delete_file" in names
def test_handle_tool_call_not_initialized(self):
p = RetainDBMemoryProvider()
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
assert "error" in result
assert "not initialized" in result["error"]
def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
assert result == {"error": "Unknown tool: retaindb_nonexistent"}
p.shutdown()
def test_dispatch_profile(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "get_profile", return_value={"memories": []}):
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
assert "memories" in result
p.shutdown()
def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_search", {}))
assert result == {"error": "query is required"}
p.shutdown()
def test_dispatch_search(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
assert "results" in result
p.shutdown()
def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "search") as mock_search:
mock_search.return_value = {"results": []}
p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
# top_k should be capped at 20
assert mock_search.call_args[1]["top_k"] == 20
p.shutdown()
def test_dispatch_remember(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
assert result["id"] == "mem-1"
p.shutdown()
def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_remember", {}))
assert result == {"error": "content is required"}
p.shutdown()
def test_dispatch_forget(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
assert result["deleted"] is True
p.shutdown()
def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_forget", {}))
assert result == {"error": "memory_id is required"}
p.shutdown()
def test_dispatch_context(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
patch.object(p._client, "get_profile", return_value={"memories": []}):
result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
assert "context" in result
assert "raw" in result
p.shutdown()
def test_dispatch_file_list(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "list_files", return_value={"files": []}):
result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
assert "files" in result
p.shutdown()
def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
assert "error" in result
def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
assert "File not found" in result["error"]
p.shutdown()
def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
assert result == {"error": "file_id is required"}
p.shutdown()
def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
assert result == {"error": "file_id is required"}
p.shutdown()
def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
assert result == {"error": "file_id is required"}
p.shutdown()
def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
p = self._make_provider(tmp_path, monkeypatch)
p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
result = json.loads(p.handle_tool_call("retaindb_profile", {}))
assert "API exploded" in result["error"]
p.shutdown()
# ===========================================================================
# Prefetch and thread management tests
# ===========================================================================
class TestPrefetch:
"""Test background prefetch and thread accumulation prevention."""
def _make_initialized_provider(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
return p
def test_queue_prefetch_skips_without_client(self):
p = RetainDBMemoryProvider()
p.queue_prefetch("test") # Should not raise
def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
p = self._make_initialized_provider(tmp_path, monkeypatch)
result = p.prefetch("test")
assert result == ""
p.shutdown()
def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
p = self._make_initialized_provider(tmp_path, monkeypatch)
# Manually set the cached result
with p._lock:
p._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
result = p.prefetch("test")
assert "User likes tests" in result
# Should be consumed
assert p.prefetch("test") == ""
p.shutdown()
def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
p = self._make_initialized_provider(tmp_path, monkeypatch)
with p._lock:
p._dialectic_result = "User is a software engineer who prefers Python."
result = p.prefetch("test")
assert "[RetainDB User Synthesis]" in result
assert "software engineer" in result
p.shutdown()
def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
p = self._make_initialized_provider(tmp_path, monkeypatch)
with p._lock:
p._agent_model = {
"memory_count": 5,
"persona": "Helpful coding assistant",
"persistent_instructions": ["Be concise", "Use Python"],
"working_style": "Direct and efficient",
}
result = p.prefetch("test")
assert "[RetainDB Agent Self-Model]" in result
assert "Helpful coding assistant" in result
assert "Be concise" in result
assert "Direct and efficient" in result
p.shutdown()
def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
p = self._make_initialized_provider(tmp_path, monkeypatch)
with p._lock:
p._agent_model = {"memory_count": 0}
result = p.prefetch("test")
assert "Agent Self-Model" not in result
p.shutdown()
def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
"""Verify old prefetch threads are joined before new ones spawn."""
p = self._make_initialized_provider(tmp_path, monkeypatch)
# Mock the prefetch methods to be slow
with patch.object(p, "_prefetch_context", side_effect=lambda q: time.sleep(0.5)), \
patch.object(p, "_prefetch_dialectic", side_effect=lambda q: time.sleep(0.5)), \
patch.object(p, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)):
p.queue_prefetch("query 1")
first_threads = list(p._prefetch_threads)
assert len(first_threads) == 3
# Call again — should join first batch before spawning new
p.queue_prefetch("query 2")
second_threads = list(p._prefetch_threads)
assert len(second_threads) == 3
# Should be different thread objects
for t in second_threads:
assert t not in first_threads
p.shutdown()
def test_reasoning_level_short(self):
assert RetainDBMemoryProvider._reasoning_level("hi") == "low"
def test_reasoning_level_medium(self):
assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium"
def test_reasoning_level_long(self):
assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high"
# ===========================================================================
# sync_turn tests
# ===========================================================================
class TestSyncTurn:
"""Test turn synchronization via the write queue."""
def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._queue, "enqueue") as mock_enqueue:
p.sync_turn("user msg", "assistant msg")
mock_enqueue.assert_called_once()
args = mock_enqueue.call_args[0]
assert args[0] == "default" # user_id
assert args[1] == "test-session" # session_id
msgs = args[2]
assert len(msgs) == 2
assert msgs[0]["role"] == "user"
assert msgs[1]["role"] == "assistant"
p.shutdown()
def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._queue, "enqueue") as mock_enqueue:
p.sync_turn("", "assistant msg")
mock_enqueue.assert_not_called()
p.shutdown()
# ===========================================================================
# on_memory_write hook tests
# ===========================================================================
class TestOnMemoryWrite:
"""Test the built-in memory mirror hook."""
def test_mirrors_add_action(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
p.on_memory_write("add", "user", "User prefers dark mode")
mock_add.assert_called_once()
assert mock_add.call_args[1]["memory_type"] == "preference"
p.shutdown()
def test_skips_non_add_action(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._client, "add_memory") as mock_add:
p.on_memory_write("remove", "user", "something")
mock_add.assert_not_called()
p.shutdown()
def test_skips_empty_content(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._client, "add_memory") as mock_add:
p.on_memory_write("add", "user", "")
mock_add.assert_not_called()
p.shutdown()
def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir(exist_ok=True)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
p = RetainDBMemoryProvider()
p.initialize("test-session", hermes_home=str(hermes_home))
with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
p.on_memory_write("add", "memory", "Some env fact")
assert mock_add.call_args[1]["memory_type"] == "factual"
p.shutdown()
# ===========================================================================
# register() test
# ===========================================================================
class TestRegister:
def test_register_calls_register_memory_provider(self):
from plugins.memory.retaindb import register
ctx = MagicMock()
register(ctx)
ctx.register_memory_provider.assert_called_once()
arg = ctx.register_memory_provider.call_args[0][0]
assert isinstance(arg, RetainDBMemoryProvider)