diff --git a/plugins/memory/retaindb/__init__.py b/plugins/memory/retaindb/__init__.py index 07f78d005fb..2a3b7a22960 100644 --- a/plugins/memory/retaindb/__init__.py +++ b/plugins/memory/retaindb/__init__.py @@ -336,52 +336,58 @@ class _WriteQueue: self._q: queue.Queue = queue.Queue() self._thread = threading.Thread(target=self._loop, name="retaindb-writer", daemon=True) self._db_path.parent.mkdir(parents=True, exist_ok=True) + # Thread-local connection cache — one connection per thread, reused. + self._local = threading.local() self._init_db() self._thread.start() # Replay any rows left from a previous crash for row_id, user_id, session_id, msgs_json in self._pending_rows(): self._q.put((row_id, user_id, session_id, json.loads(msgs_json))) - def _connect(self) -> sqlite3.Connection: - conn = sqlite3.connect(str(self._db_path), timeout=30) - conn.row_factory = sqlite3.Row + def _get_conn(self) -> sqlite3.Connection: + """Return a cached connection for the current thread.""" + conn = getattr(self._local, "conn", None) + if conn is None: + conn = sqlite3.connect(str(self._db_path), timeout=30) + conn.row_factory = sqlite3.Row + self._local.conn = conn return conn def _init_db(self) -> None: - with self._connect() as conn: - conn.execute("""CREATE TABLE IF NOT EXISTS pending ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id TEXT, session_id TEXT, messages_json TEXT, - created_at TEXT, last_error TEXT - )""") - conn.commit() + conn = self._get_conn() + conn.execute("""CREATE TABLE IF NOT EXISTS pending ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT, session_id TEXT, messages_json TEXT, + created_at TEXT, last_error TEXT + )""") + conn.commit() def _pending_rows(self) -> list: - with self._connect() as conn: - return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall() + conn = self._get_conn() + return conn.execute("SELECT id, user_id, session_id, messages_json FROM pending ORDER BY id ASC LIMIT 200").fetchall() def enqueue(self, user_id: str, session_id: str, messages: list) -> None: now = datetime.now(timezone.utc).isoformat() - with self._connect() as conn: - cur = conn.execute( - "INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)", - (user_id, session_id, json.dumps(messages, ensure_ascii=False), now), - ) - row_id = cur.lastrowid - conn.commit() + conn = self._get_conn() + cur = conn.execute( + "INSERT INTO pending (user_id, session_id, messages_json, created_at) VALUES (?,?,?,?)", + (user_id, session_id, json.dumps(messages, ensure_ascii=False), now), + ) + row_id = cur.lastrowid + conn.commit() self._q.put((row_id, user_id, session_id, messages)) def _flush_row(self, row_id: int, user_id: str, session_id: str, messages: list) -> None: try: self._client.ingest_session(user_id, session_id, messages) - with self._connect() as conn: - conn.execute("DELETE FROM pending WHERE id = ?", (row_id,)) - conn.commit() + conn = self._get_conn() + conn.execute("DELETE FROM pending WHERE id = ?", (row_id,)) + conn.commit() except Exception as exc: logger.warning("RetainDB ingest failed (will retry): %s", exc) - with self._connect() as conn: - conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id)) - conn.commit() + conn = self._get_conn() + conn.execute("UPDATE pending SET last_error = ? WHERE id = ?", (str(exc), row_id)) + conn.commit() time.sleep(2) def _loop(self) -> None: @@ -459,6 +465,9 @@ class RetainDBMemoryProvider(MemoryProvider): self._dialectic_result = "" self._agent_model: dict = {} + # Prefetch thread tracking — prevents accumulation on rapid calls + self._prefetch_threads: list[threading.Thread] = [] + # ── Core identity ────────────────────────────────────────────────────── @property @@ -533,9 +542,18 @@ class RetainDBMemoryProvider(MemoryProvider): """Fire context + dialectic + agent model prefetches in background.""" if not self._client: return - threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True).start() - threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True).start() - threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True).start() + # Wait for any still-running prefetch threads before spawning new ones. + # Prevents thread accumulation if turns fire faster than prefetches complete. + for t in self._prefetch_threads: + t.join(timeout=2.0) + threads = [ + threading.Thread(target=self._prefetch_context, args=(query,), name="retaindb-ctx", daemon=True), + threading.Thread(target=self._prefetch_dialectic, args=(query,), name="retaindb-dialectic", daemon=True), + threading.Thread(target=self._prefetch_agent_model, name="retaindb-agent-model", daemon=True), + ] + self._prefetch_threads = threads + for t in threads: + t.start() def _prefetch_context(self, query: str) -> None: try: @@ -736,6 +754,8 @@ class RetainDBMemoryProvider(MemoryProvider): logger.debug("RetainDB memory mirror failed: %s", exc) def shutdown(self) -> None: + for t in self._prefetch_threads: + t.join(timeout=3.0) if self._queue: self._queue.shutdown() diff --git a/tests/plugins/test_retaindb_plugin.py b/tests/plugins/test_retaindb_plugin.py new file mode 100644 index 00000000000..7e334709f68 --- /dev/null +++ b/tests/plugins/test_retaindb_plugin.py @@ -0,0 +1,776 @@ +"""Tests for the RetainDB memory plugin. + +Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter, +RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling. +""" + +import json +import os +import sqlite3 +import tempfile +import threading +import time +from pathlib import Path +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + + +# --------------------------------------------------------------------------- +# Imports — guarded since plugins/memory lives outside the standard test path +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _isolate_env(tmp_path, monkeypatch): + """Ensure HERMES_HOME and RETAINDB vars are isolated.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.delenv("RETAINDB_API_KEY", raising=False) + monkeypatch.delenv("RETAINDB_BASE_URL", raising=False) + monkeypatch.delenv("RETAINDB_PROJECT", raising=False) + + +# We need the repo root on sys.path so the plugin can import agent.memory_provider +import sys +_repo_root = str(Path(__file__).resolve().parents[2]) +if _repo_root not in sys.path: + sys.path.insert(0, _repo_root) + +from plugins.memory.retaindb import ( + _Client, + _WriteQueue, + _build_overlay, + RetainDBMemoryProvider, + _ASYNC_SHUTDOWN, + _DEFAULT_BASE_URL, +) + + +# =========================================================================== +# _Client tests +# =========================================================================== + +class TestClient: + """Test the HTTP client with mocked requests.""" + + def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"): + return _Client(api_key, base_url, project) + + def test_base_url_trailing_slash_stripped(self): + c = self._make_client(base_url="https://api.retaindb.com///") + assert c.base_url == "https://api.retaindb.com" + + def test_headers_include_auth(self): + c = self._make_client() + h = c._headers("/v1/files") + assert h["Authorization"] == "Bearer rdb-test-key" + assert "X-API-Key" not in h + + def test_headers_include_api_key_for_memory_path(self): + c = self._make_client() + h = c._headers("/v1/memory/search") + assert h["X-API-Key"] == "rdb-test-key" + + def test_headers_include_api_key_for_context_path(self): + c = self._make_client() + h = c._headers("/v1/context/query") + assert h["X-API-Key"] == "rdb-test-key" + + def test_headers_strip_bearer_prefix(self): + c = self._make_client(api_key="Bearer rdb-test-key") + h = c._headers("/v1/memory/search") + assert h["Authorization"] == "Bearer rdb-test-key" + assert h["X-API-Key"] == "rdb-test-key" + + def test_query_context_builds_correct_payload(self): + c = self._make_client() + with patch.object(c, "request") as mock_req: + mock_req.return_value = {"results": []} + c.query_context("user1", "sess1", "test query", max_tokens=500) + mock_req.assert_called_once_with("POST", "/v1/context/query", json_body={ + "project": "test", + "query": "test query", + "user_id": "user1", + "session_id": "sess1", + "include_memories": True, + "max_tokens": 500, + }) + + def test_search_builds_correct_payload(self): + c = self._make_client() + with patch.object(c, "request") as mock_req: + mock_req.return_value = {"results": []} + c.search("user1", "sess1", "find this", top_k=5) + mock_req.assert_called_once_with("POST", "/v1/memory/search", json_body={ + "project": "test", + "query": "find this", + "user_id": "user1", + "session_id": "sess1", + "top_k": 5, + "include_pending": True, + }) + + def test_add_memory_tries_fallback(self): + c = self._make_client() + call_count = 0 + def fake_request(method, path, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("404") + return {"id": "mem-1"} + + with patch.object(c, "request", side_effect=fake_request): + result = c.add_memory("u1", "s1", "test fact") + assert result == {"id": "mem-1"} + assert call_count == 2 + + def test_delete_memory_tries_fallback(self): + c = self._make_client() + call_count = 0 + def fake_request(method, path, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("404") + return {"deleted": True} + + with patch.object(c, "request", side_effect=fake_request): + result = c.delete_memory("mem-123") + assert result == {"deleted": True} + assert call_count == 2 + + def test_ingest_session_payload(self): + c = self._make_client() + with patch.object(c, "request") as mock_req: + mock_req.return_value = {"status": "ok"} + msgs = [{"role": "user", "content": "hi"}] + c.ingest_session("u1", "s1", msgs, timeout=10.0) + mock_req.assert_called_once_with("POST", "/v1/memory/ingest/session", json_body={ + "project": "test", + "session_id": "s1", + "user_id": "u1", + "messages": msgs, + "write_mode": "sync", + }, timeout=10.0) + + def test_ask_user_payload(self): + c = self._make_client() + with patch.object(c, "request") as mock_req: + mock_req.return_value = {"answer": "test answer"} + c.ask_user("u1", "who am i?", reasoning_level="medium") + mock_req.assert_called_once() + call_kwargs = mock_req.call_args + assert call_kwargs[1]["json_body"]["reasoning_level"] == "medium" + + def test_get_agent_model_path(self): + c = self._make_client() + with patch.object(c, "request") as mock_req: + mock_req.return_value = {"memory_count": 3} + c.get_agent_model("hermes") + mock_req.assert_called_once_with( + "GET", "/v1/memory/agent/hermes/model", + params={"project": "test"}, timeout=4.0 + ) + + +# =========================================================================== +# _WriteQueue tests +# =========================================================================== + +class TestWriteQueue: + """Test the SQLite-backed write queue with real SQLite.""" + + def _make_queue(self, tmp_path, client=None): + if client is None: + client = MagicMock() + client.ingest_session = MagicMock(return_value={"status": "ok"}) + db_path = tmp_path / "test_queue.db" + return _WriteQueue(client, db_path), client, db_path + + def test_enqueue_creates_row(self, tmp_path): + q, client, db_path = self._make_queue(tmp_path) + q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) + # Give the writer thread a moment to process + time.sleep(1) + q.shutdown() + # If ingest succeeded, the row should be deleted + client.ingest_session.assert_called_once() + + def test_enqueue_persists_to_sqlite(self, tmp_path): + client = MagicMock() + # Make ingest hang so the row stays in SQLite + client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5)) + db_path = tmp_path / "test_queue.db" + q = _WriteQueue(client, db_path) + q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}]) + # Check SQLite directly — row should exist since flush is slow + conn = sqlite3.connect(str(db_path)) + rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall() + conn.close() + assert len(rows) >= 1 + assert rows[0][0] == "user1" + q.shutdown() + + def test_flush_deletes_row_on_success(self, tmp_path): + q, client, db_path = self._make_queue(tmp_path) + q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) + time.sleep(1) + q.shutdown() + # Row should be gone + conn = sqlite3.connect(str(db_path)) + rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0] + conn.close() + assert rows == 0 + + def test_flush_records_error_on_failure(self, tmp_path): + client = MagicMock() + client.ingest_session = MagicMock(side_effect=RuntimeError("API down")) + db_path = tmp_path / "test_queue.db" + q = _WriteQueue(client, db_path) + q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) + time.sleep(3) # Allow retry + sleep(2) in _flush_row + q.shutdown() + # Row should still exist with error recorded + conn = sqlite3.connect(str(db_path)) + row = conn.execute("SELECT last_error FROM pending").fetchone() + conn.close() + assert row is not None + assert "API down" in row[0] + + def test_thread_local_connection_reuse(self, tmp_path): + q, _, _ = self._make_queue(tmp_path) + # Same thread should get same connection + conn1 = q._get_conn() + conn2 = q._get_conn() + assert conn1 is conn2 + q.shutdown() + + def test_crash_recovery_replays_pending(self, tmp_path): + """Simulate crash: create rows, then new queue should replay them.""" + db_path = tmp_path / "recovery_test.db" + # First: create a queue and insert rows, but don't let them flush + client1 = MagicMock() + client1.ingest_session = MagicMock(side_effect=RuntimeError("fail")) + q1 = _WriteQueue(client1, db_path) + q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}]) + time.sleep(3) + q1.shutdown() + + # Now create a new queue — it should replay the pending rows + client2 = MagicMock() + client2.ingest_session = MagicMock(return_value={"status": "ok"}) + q2 = _WriteQueue(client2, db_path) + time.sleep(2) + q2.shutdown() + + # The replayed row should have been ingested via client2 + client2.ingest_session.assert_called_once() + call_args = client2.ingest_session.call_args + assert call_args[0][0] == "user1" # user_id + + +# =========================================================================== +# _build_overlay tests +# =========================================================================== + +class TestBuildOverlay: + """Test the overlay formatter (pure function).""" + + def test_empty_inputs_returns_empty(self): + assert _build_overlay({}, {}) == "" + + def test_empty_memories_returns_empty(self): + assert _build_overlay({"memories": []}, {"results": []}) == "" + + def test_profile_items_included(self): + profile = {"memories": [{"content": "User likes Python"}]} + result = _build_overlay(profile, {}) + assert "User likes Python" in result + assert "[RetainDB Context]" in result + + def test_query_results_included(self): + query_result = {"results": [{"content": "Previous discussion about Rust"}]} + result = _build_overlay({}, query_result) + assert "Previous discussion about Rust" in result + + def test_deduplication_removes_duplicates(self): + profile = {"memories": [{"content": "User likes Python"}]} + query_result = {"results": [{"content": "User likes Python"}]} + result = _build_overlay(profile, query_result) + assert result.count("User likes Python") == 1 + + def test_local_entries_filter(self): + profile = {"memories": [{"content": "Already known fact"}]} + result = _build_overlay(profile, {}, local_entries=["Already known fact"]) + # The profile item matches a local entry, should be filtered + assert result == "" + + def test_max_five_items_per_section(self): + profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]} + result = _build_overlay(profile, {}) + # Should only include first 5 + assert "Fact 0" in result + assert "Fact 4" in result + assert "Fact 5" not in result + + def test_none_content_handled(self): + profile = {"memories": [{"content": None}, {"content": "Real fact"}]} + result = _build_overlay(profile, {}) + assert "Real fact" in result + + def test_truncation_at_320_chars(self): + long_content = "x" * 500 + profile = {"memories": [{"content": long_content}]} + result = _build_overlay(profile, {}) + # Each item is compacted to 320 chars max + for line in result.split("\n"): + if line.startswith("- "): + assert len(line) <= 322 # "- " + 320 + + +# =========================================================================== +# RetainDBMemoryProvider tests +# =========================================================================== + +class TestRetainDBMemoryProvider: + """Test the main plugin class.""" + + def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"): + monkeypatch.setenv("RETAINDB_API_KEY", api_key) + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + (tmp_path / ".hermes").mkdir(exist_ok=True) + provider = RetainDBMemoryProvider() + return provider + + def test_name(self): + p = RetainDBMemoryProvider() + assert p.name == "retaindb" + + def test_is_available_without_key(self): + p = RetainDBMemoryProvider() + assert p.is_available() is False + + def test_is_available_with_key(self, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test") + p = RetainDBMemoryProvider() + assert p.is_available() is True + + def test_config_schema(self): + p = RetainDBMemoryProvider() + schema = p.get_config_schema() + assert len(schema) == 3 + keys = [s["key"] for s in schema] + assert "api_key" in keys + assert "base_url" in keys + assert "project" in keys + + def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + assert p._client is not None + assert p._queue is not None + assert p._session_id == "test-session" + p.shutdown() + + def test_initialize_default_project(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + assert p._client.project == "default" + p.shutdown() + + def test_initialize_explicit_project(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_PROJECT", "my-project") + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + assert p._client.project == "my-project" + p.shutdown() + + def test_initialize_profile_project(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + profile_home = str(tmp_path / "profiles" / "coder") + p.initialize("test-session", hermes_home=profile_home) + assert p._client.project == "hermes-coder" + p.shutdown() + + def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + soul_path = tmp_path / ".hermes" / "SOUL.md" + soul_path.write_text("I am a helpful agent.") + with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed: + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + # Give thread time to start + time.sleep(0.5) + mock_seed.assert_called_once_with("I am a helpful agent.") + p.shutdown() + + def test_system_prompt_block(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + block = p.system_prompt_block() + assert "RetainDB Memory" in block + assert "Active" in block + p.shutdown() + + def test_tool_schemas_count(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + schemas = p.get_tool_schemas() + assert len(schemas) == 10 # 5 memory + 5 file tools + names = [s["name"] for s in schemas] + assert "retaindb_profile" in names + assert "retaindb_search" in names + assert "retaindb_context" in names + assert "retaindb_remember" in names + assert "retaindb_forget" in names + assert "retaindb_upload_file" in names + assert "retaindb_list_files" in names + assert "retaindb_read_file" in names + assert "retaindb_ingest_file" in names + assert "retaindb_delete_file" in names + + def test_handle_tool_call_not_initialized(self): + p = RetainDBMemoryProvider() + result = json.loads(p.handle_tool_call("retaindb_profile", {})) + assert "error" in result + assert "not initialized" in result["error"] + + def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_nonexistent", {})) + assert result == {"error": "Unknown tool: retaindb_nonexistent"} + p.shutdown() + + def test_dispatch_profile(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "get_profile", return_value={"memories": []}): + result = json.loads(p.handle_tool_call("retaindb_profile", {})) + assert "memories" in result + p.shutdown() + + def test_dispatch_search_requires_query(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_search", {})) + assert result == {"error": "query is required"} + p.shutdown() + + def test_dispatch_search(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}): + result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"})) + assert "results" in result + p.shutdown() + + def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "search") as mock_search: + mock_search.return_value = {"results": []} + p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100}) + # top_k should be capped at 20 + assert mock_search.call_args[1]["top_k"] == 20 + p.shutdown() + + def test_dispatch_remember(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}): + result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"})) + assert result["id"] == "mem-1" + p.shutdown() + + def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_remember", {})) + assert result == {"error": "content is required"} + p.shutdown() + + def test_dispatch_forget(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "delete_memory", return_value={"deleted": True}): + result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"})) + assert result["deleted"] is True + p.shutdown() + + def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_forget", {})) + assert result == {"error": "memory_id is required"} + p.shutdown() + + def test_dispatch_context(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \ + patch.object(p._client, "get_profile", return_value={"memories": []}): + result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"})) + assert "context" in result + assert "raw" in result + p.shutdown() + + def test_dispatch_file_list(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "list_files", return_value={"files": []}): + result = json.loads(p.handle_tool_call("retaindb_list_files", {})) + assert "files" in result + p.shutdown() + + def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_upload_file", {})) + assert "error" in result + + def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"})) + assert "File not found" in result["error"] + p.shutdown() + + def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_read_file", {})) + assert result == {"error": "file_id is required"} + p.shutdown() + + def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_ingest_file", {})) + assert result == {"error": "file_id is required"} + p.shutdown() + + def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + result = json.loads(p.handle_tool_call("retaindb_delete_file", {})) + assert result == {"error": "file_id is required"} + p.shutdown() + + def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch): + p = self._make_provider(tmp_path, monkeypatch) + p.initialize("test-session", hermes_home=str(tmp_path / ".hermes")) + with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")): + result = json.loads(p.handle_tool_call("retaindb_profile", {})) + assert "API exploded" in result["error"] + p.shutdown() + + +# =========================================================================== +# Prefetch and thread management tests +# =========================================================================== + +class TestPrefetch: + """Test background prefetch and thread accumulation prevention.""" + + def _make_initialized_provider(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + return p + + def test_queue_prefetch_skips_without_client(self): + p = RetainDBMemoryProvider() + p.queue_prefetch("test") # Should not raise + + def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch): + p = self._make_initialized_provider(tmp_path, monkeypatch) + result = p.prefetch("test") + assert result == "" + p.shutdown() + + def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch): + p = self._make_initialized_provider(tmp_path, monkeypatch) + # Manually set the cached result + with p._lock: + p._context_result = "[RetainDB Context]\nProfile:\n- User likes tests" + result = p.prefetch("test") + assert "User likes tests" in result + # Should be consumed + assert p.prefetch("test") == "" + p.shutdown() + + def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch): + p = self._make_initialized_provider(tmp_path, monkeypatch) + with p._lock: + p._dialectic_result = "User is a software engineer who prefers Python." + result = p.prefetch("test") + assert "[RetainDB User Synthesis]" in result + assert "software engineer" in result + p.shutdown() + + def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch): + p = self._make_initialized_provider(tmp_path, monkeypatch) + with p._lock: + p._agent_model = { + "memory_count": 5, + "persona": "Helpful coding assistant", + "persistent_instructions": ["Be concise", "Use Python"], + "working_style": "Direct and efficient", + } + result = p.prefetch("test") + assert "[RetainDB Agent Self-Model]" in result + assert "Helpful coding assistant" in result + assert "Be concise" in result + assert "Direct and efficient" in result + p.shutdown() + + def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch): + p = self._make_initialized_provider(tmp_path, monkeypatch) + with p._lock: + p._agent_model = {"memory_count": 0} + result = p.prefetch("test") + assert "Agent Self-Model" not in result + p.shutdown() + + def test_thread_accumulation_guard(self, tmp_path, monkeypatch): + """Verify old prefetch threads are joined before new ones spawn.""" + p = self._make_initialized_provider(tmp_path, monkeypatch) + # Mock the prefetch methods to be slow + with patch.object(p, "_prefetch_context", side_effect=lambda q: time.sleep(0.5)), \ + patch.object(p, "_prefetch_dialectic", side_effect=lambda q: time.sleep(0.5)), \ + patch.object(p, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)): + p.queue_prefetch("query 1") + first_threads = list(p._prefetch_threads) + assert len(first_threads) == 3 + + # Call again — should join first batch before spawning new + p.queue_prefetch("query 2") + second_threads = list(p._prefetch_threads) + assert len(second_threads) == 3 + # Should be different thread objects + for t in second_threads: + assert t not in first_threads + p.shutdown() + + def test_reasoning_level_short(self): + assert RetainDBMemoryProvider._reasoning_level("hi") == "low" + + def test_reasoning_level_medium(self): + assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium" + + def test_reasoning_level_long(self): + assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high" + + +# =========================================================================== +# sync_turn tests +# =========================================================================== + +class TestSyncTurn: + """Test turn synchronization via the write queue.""" + + def test_sync_turn_enqueues(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._queue, "enqueue") as mock_enqueue: + p.sync_turn("user msg", "assistant msg") + mock_enqueue.assert_called_once() + args = mock_enqueue.call_args[0] + assert args[0] == "default" # user_id + assert args[1] == "test-session" # session_id + msgs = args[2] + assert len(msgs) == 2 + assert msgs[0]["role"] == "user" + assert msgs[1]["role"] == "assistant" + p.shutdown() + + def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._queue, "enqueue") as mock_enqueue: + p.sync_turn("", "assistant msg") + mock_enqueue.assert_not_called() + p.shutdown() + + +# =========================================================================== +# on_memory_write hook tests +# =========================================================================== + +class TestOnMemoryWrite: + """Test the built-in memory mirror hook.""" + + def test_mirrors_add_action(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add: + p.on_memory_write("add", "user", "User prefers dark mode") + mock_add.assert_called_once() + assert mock_add.call_args[1]["memory_type"] == "preference" + p.shutdown() + + def test_skips_non_add_action(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._client, "add_memory") as mock_add: + p.on_memory_write("remove", "user", "something") + mock_add.assert_not_called() + p.shutdown() + + def test_skips_empty_content(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._client, "add_memory") as mock_add: + p.on_memory_write("add", "user", "") + mock_add.assert_not_called() + p.shutdown() + + def test_memory_target_maps_to_type(self, tmp_path, monkeypatch): + monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key") + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + p = RetainDBMemoryProvider() + p.initialize("test-session", hermes_home=str(hermes_home)) + with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add: + p.on_memory_write("add", "memory", "Some env fact") + assert mock_add.call_args[1]["memory_type"] == "factual" + p.shutdown() + + +# =========================================================================== +# register() test +# =========================================================================== + +class TestRegister: + def test_register_calls_register_memory_provider(self): + from plugins.memory.retaindb import register + ctx = MagicMock() + register(ctx) + ctx.register_memory_provider.assert_called_once() + arg = ctx.register_memory_provider.call_args[0][0] + assert isinstance(arg, RetainDBMemoryProvider)