mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-01 08:21:50 +08:00
Follow-up to the cherry-picked PR #17447. The original flush spawned a bare threading.Thread for the buffer-flush path, overwriting self._sync_thread — which is aliased to the long-lived writer thread. Two consequences: 1. No serialization with the writer queue. If old-session retains were still queued in _retain_queue, the flush ran concurrently with the writer and both threads could call aretain_batch against the same document_id. 2. The pre-spawn 'self._sync_thread.join(timeout=5.0)' tried to join the long-lived writer, which never exits, so the join was a no-op that just timed out — never actually serialized anything. Fix: enqueue the flush closure on _retain_queue via _ensure_writer + put(). Natural FIFO ordering behind any pending retains, no new thread, no broken join. Shutdown-aware so it doesn't enqueue after teardown. Tests updated to drain via _retain_queue.join() instead of the stale _sync_thread.join(). Added regression guard test_flush_serializes_behind_pending_retains_via_writer_queue that blocks the writer mid-retain to prove the flush waits in FIFO behind the old retain. Also seeds _retain_queue / _shutting_down / stubbed _ensure_writer on the bare-object test helper in test_memory_session_switch.py so that path doesn't blow up under the new queue-enqueue. tests/plugins/memory/test_hindsight_provider.py + tests/agent/test_memory_session_switch.py: 103/103 passing.
1448 lines
58 KiB
Python
"""Tests for the Hindsight memory provider plugin.
|
|
|
|
Tests cover config loading, tool handlers (tags, max_tokens, types),
|
|
prefetch (auto_recall, preamble, query truncation), sync_turn (auto_retain,
|
|
turn counting, tags), and schema completeness.
|
|
"""
|
|
|
|
import json
|
|
import re
|
|
import sys
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from plugins.memory.hindsight import (
|
|
HindsightMemoryProvider,
|
|
RECALL_SCHEMA,
|
|
REFLECT_SCHEMA,
|
|
RETAIN_SCHEMA,
|
|
_load_config,
|
|
_build_embedded_profile_env,
|
|
_normalize_retain_tags,
|
|
_resolve_bank_id_template,
|
|
_sanitize_bank_segment,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Delete every Hindsight env var so each test starts from a clean slate."""
    stale_keys = (
        "HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
        "HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_TIMEOUT",
        "HINDSIGHT_IDLE_TIMEOUT", "HINDSIGHT_LLM_API_KEY",
        "HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_SOURCE",
        "HINDSIGHT_RETAIN_USER_PREFIX", "HINDSIGHT_RETAIN_ASSISTANT_PREFIX",
    )
    for stale_key in stale_keys:
        monkeypatch.delenv(stale_key, raising=False)
|
|
|
|
|
|
def _make_mock_client():
|
|
"""Create a mock Hindsight client with async methods."""
|
|
async def _aretain(
|
|
bank_id,
|
|
content,
|
|
timestamp=None,
|
|
context=None,
|
|
document_id=None,
|
|
metadata=None,
|
|
entities=None,
|
|
tags=None,
|
|
update_mode=None,
|
|
retain_async=None,
|
|
):
|
|
return SimpleNamespace(ok=True)
|
|
|
|
client = MagicMock()
|
|
client.aretain = AsyncMock(side_effect=_aretain)
|
|
client.arecall = AsyncMock(
|
|
return_value=SimpleNamespace(
|
|
results=[
|
|
SimpleNamespace(text="Memory 1"),
|
|
SimpleNamespace(text="Memory 2"),
|
|
]
|
|
)
|
|
)
|
|
client.areflect = AsyncMock(
|
|
return_value=SimpleNamespace(text="Synthesized answer")
|
|
)
|
|
client.aretain_batch = AsyncMock()
|
|
client.aclose = AsyncMock()
|
|
return client
|
|
|
|
|
|
class _FakeSessionDB:
|
|
def __init__(self, messages=None):
|
|
self._messages = list(messages or [])
|
|
|
|
def get_messages_as_conversation(self, session_id):
|
|
return list(self._messages)
|
|
|
|
|
|
@pytest.fixture()
def provider(tmp_path, monkeypatch):
    """Return an initialized HindsightMemoryProvider backed by a mock client."""
    cfg_payload = {
        "mode": "cloud",
        "apiKey": "test-key",
        "api_url": "http://localhost:9999",
        "bank_id": "test-bank",
        "budget": "mid",
        "memory_mode": "hybrid",
    }
    cfg_file = tmp_path / "hindsight" / "config.json"
    cfg_file.parent.mkdir(parents=True, exist_ok=True)
    cfg_file.write_text(json.dumps(cfg_payload))

    # Redirect the provider's home lookup at the temp dir holding config.json.
    monkeypatch.setattr(
        "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
    )

    prov = HindsightMemoryProvider()
    prov.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
    prov._client = _make_mock_client()
    return prov
|
|
|
|
|
|
@pytest.fixture()
def provider_with_config(tmp_path, monkeypatch):
    """Return a factory that builds providers with config overrides merged in."""
    base_config = {
        "mode": "cloud",
        "apiKey": "test-key",
        "api_url": "http://localhost:9999",
        "bank_id": "test-bank",
        "budget": "mid",
        "memory_mode": "hybrid",
    }

    def _make(**overrides):
        merged = {**base_config, **overrides}
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps(merged))

        # Redirect the provider's home lookup at the temp dir per call.
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
        )

        prov = HindsightMemoryProvider()
        prov.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
        prov._client = _make_mock_client()
        return prov

    return _make
|
|
|
|
|
|
def test_normalize_retain_tags_accepts_csv_and_dedupes():
    """A comma-separated string is split, trimmed, and de-duplicated in order."""
    raw = "agent:fakeassistantname, source_system:hermes-agent, agent:fakeassistantname"
    expected = ["agent:fakeassistantname", "source_system:hermes-agent"]
    assert _normalize_retain_tags(raw) == expected
|
|
|
|
|
def test_normalize_retain_tags_accepts_json_array_string():
    """A JSON-encoded array of tags round-trips through normalization."""
    tags = ["agent:fakeassistantname", "source_system:hermes-agent"]
    assert _normalize_retain_tags(json.dumps(tags)) == tags
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSchemas:
    """Shape checks on the three tool schemas and schema-exposure rules."""

    def test_retain_schema_has_content(self):
        props = RETAIN_SCHEMA["parameters"]["properties"]
        assert RETAIN_SCHEMA["name"] == "hindsight_retain"
        assert "content" in props
        assert "tags" in props
        assert "content" in RETAIN_SCHEMA["parameters"]["required"]

    def test_recall_schema_has_query(self):
        params = RECALL_SCHEMA["parameters"]
        assert RECALL_SCHEMA["name"] == "hindsight_recall"
        assert "query" in params["properties"]
        assert "query" in params["required"]

    def test_reflect_schema_has_query(self):
        assert REFLECT_SCHEMA["name"] == "hindsight_reflect"
        assert "query" in REFLECT_SCHEMA["parameters"]["properties"]

    def test_get_tool_schemas_returns_three(self, provider):
        schemas = provider.get_tool_schemas()
        assert len(schemas) == 3
        assert {s["name"] for s in schemas} == {
            "hindsight_retain", "hindsight_recall", "hindsight_reflect"
        }

    def test_context_mode_returns_no_tools(self, provider_with_config):
        # In pure "context" mode no tools are surfaced to the agent.
        assert provider_with_config(memory_mode="context").get_tool_schemas() == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfig:
    """Config handling: defaults, file overrides, env fallback, and how
    idle_timeout flows from config into the embedded client."""

    def test_default_values(self, provider):
        # Values the provider falls back to when config.json omits them.
        assert provider._auto_retain is True
        assert provider._auto_recall is True
        assert provider._retain_every_n_turns == 1
        assert provider._recall_max_tokens == 4096
        assert provider._recall_max_input_chars == 800
        assert provider._tags is None
        assert provider._recall_tags is None
        assert provider._bank_mission == ""
        assert provider._bank_retain_mission is None
        assert provider._retain_context == "conversation between Hermes Agent and the User"

    def test_custom_config_values(self, provider_with_config):
        # Every supported override key should land on its provider attribute.
        p = provider_with_config(
            retain_tags=["tag1", "tag2"],
            retain_source="hermes",
            retain_user_prefix="User (fakeusername)",
            retain_assistant_prefix="Assistant (fakeassistantname)",
            recall_tags=["recall-tag"],
            recall_tags_match="all",
            auto_retain=False,
            auto_recall=False,
            retain_every_n_turns=3,
            retain_context="custom-ctx",
            bank_retain_mission="Extract key facts",
            recall_max_tokens=2048,
            recall_types=["world", "experience"],
            recall_prompt_preamble="Custom preamble:",
            recall_max_input_chars=500,
            bank_mission="Test agent mission",
        )
        # retain_tags feeds both the generic _tags and the retain-specific list.
        assert p._tags == ["tag1", "tag2"]
        assert p._retain_tags == ["tag1", "tag2"]
        assert p._retain_source == "hermes"
        assert p._retain_user_prefix == "User (fakeusername)"
        assert p._retain_assistant_prefix == "Assistant (fakeassistantname)"
        assert p._recall_tags == ["recall-tag"]
        assert p._recall_tags_match == "all"
        assert p._auto_retain is False
        assert p._auto_recall is False
        assert p._retain_every_n_turns == 3
        assert p._retain_context == "custom-ctx"
        assert p._bank_retain_mission == "Extract key facts"
        assert p._recall_max_tokens == 2048
        assert p._recall_types == ["world", "experience"]
        assert p._recall_prompt_preamble == "Custom preamble:"
        assert p._recall_max_input_chars == 500
        assert p._bank_mission == "Test agent mission"

    def test_config_from_env_fallback(self, tmp_path, monkeypatch):
        """When no config file exists, falls back to env vars."""
        # Point the home lookup at a directory with no config.json in it.
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        monkeypatch.setenv("HINDSIGHT_MODE", "cloud")
        monkeypatch.setenv("HINDSIGHT_API_KEY", "env-key")
        monkeypatch.setenv("HINDSIGHT_BANK_ID", "env-bank")
        monkeypatch.setenv("HINDSIGHT_BUDGET", "high")

        cfg = _load_config()
        assert cfg["apiKey"] == "env-key"
        # Env-derived bank settings are nested under the "hermes" bank entry.
        assert cfg["banks"]["hermes"]["bankId"] == "env-bank"
        assert cfg["banks"]["hermes"]["budget"] == "high"

    def test_embedded_profile_env_includes_idle_timeout_from_config(self):
        env = _build_embedded_profile_env({
            "llm_provider": "openai",
            "llm_model": "gpt-4o-mini",
            "idle_timeout": 0,
        })

        # idle_timeout 0 must survive as the string "0", not be dropped as falsy.
        assert env["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] == "0"

    def test_embedded_profile_env_includes_idle_timeout_from_env(self, monkeypatch):
        monkeypatch.setenv("HINDSIGHT_IDLE_TIMEOUT", "42")

        env = _build_embedded_profile_env({
            "llm_provider": "openai",
            "llm_model": "gpt-4o-mini",
        })

        assert env["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] == "42"

    def test_get_client_passes_idle_timeout_to_hindsight_embedded(self, monkeypatch):
        captured = {}

        class FakeHindsightEmbedded:
            # Records the kwargs _get_client constructs the client with.
            def __init__(self, **kwargs):
                captured.update(kwargs)

        # Stub the hindsight module so no real embedded runtime is required.
        monkeypatch.setitem(sys.modules, "hindsight", SimpleNamespace(HindsightEmbedded=FakeHindsightEmbedded))
        monkeypatch.setattr("plugins.memory.hindsight._check_local_runtime", lambda: (True, ""))

        p = HindsightMemoryProvider()
        p._mode = "local_embedded"
        p._config = {
            "profile": "hermes",
            "llm_provider": "openai_compatible",
            "llm_api_key": "test-key",
            "llm_model": "test-model",
            "idle_timeout": 0,
        }
        p._llm_base_url = "http://localhost:8060/v1"

        p._get_client()

        assert captured["idle_timeout"] == 0
        # The test asserts "openai_compatible" is passed through as "openai".
        assert captured["llm_provider"] == "openai"
|
|
|
|
|
class TestPostSetup:
    """post_setup flow for local_embedded mode: profile-env materialization,
    profile naming, and preservation of existing keys/config on blank input."""

    def test_local_embedded_setup_materializes_profile_env(self, tmp_path, monkeypatch):
        hermes_home = tmp_path / "hermes-home"
        user_home = tmp_path / "user-home"
        user_home.mkdir()
        monkeypatch.setenv("HOME", str(user_home))

        # Drive the two curses pickers in order: mode, then LLM provider.
        selections = iter([1, 0])  # local_embedded, openai
        monkeypatch.setattr("hermes_cli.memory_setup._curses_select", lambda *args, **kwargs: next(selections))
        monkeypatch.setattr("shutil.which", lambda name: None)
        monkeypatch.setattr("builtins.input", lambda prompt="": "")
        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
        monkeypatch.setattr("getpass.getpass", lambda prompt="": "sk-local-test")
        saved_configs = []
        monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: saved_configs.append(cfg.copy()))

        provider = HindsightMemoryProvider()
        provider.post_setup(str(hermes_home), {"memory": {}})

        assert saved_configs[-1]["memory"]["provider"] == "hindsight"
        env_text = (hermes_home / ".env").read_text()
        assert "HINDSIGHT_LLM_API_KEY=sk-local-test\n" in env_text
        assert "HINDSIGHT_TIMEOUT=120\n" in env_text
        assert "HINDSIGHT_IDLE_TIMEOUT=300\n" in env_text

        # The per-profile env file is written under $HOME/.hindsight/profiles/.
        profile_env = user_home / ".hindsight" / "profiles" / "hermes.env"
        assert profile_env.exists()
        assert profile_env.read_text() == (
            "HINDSIGHT_API_LLM_PROVIDER=openai\n"
            "HINDSIGHT_API_LLM_API_KEY=sk-local-test\n"
            "HINDSIGHT_API_LLM_MODEL=gpt-4o-mini\n"
            "HINDSIGHT_API_LOG_LEVEL=info\n"
            "HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT=300\n"
        )

    def test_local_embedded_setup_respects_existing_profile_name(self, tmp_path, monkeypatch):
        hermes_home = tmp_path / "hermes-home"
        user_home = tmp_path / "user-home"
        user_home.mkdir()
        monkeypatch.setenv("HOME", str(user_home))

        selections = iter([1, 0])  # local_embedded, openai
        monkeypatch.setattr("hermes_cli.memory_setup._curses_select", lambda *args, **kwargs: next(selections))
        monkeypatch.setattr("shutil.which", lambda name: None)
        monkeypatch.setattr("builtins.input", lambda prompt="": "")
        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
        monkeypatch.setattr("getpass.getpass", lambda prompt="": "sk-local-test")
        monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)

        provider = HindsightMemoryProvider()
        # Pre-save a config naming the profile "coder"; setup must reuse it
        # instead of falling back to the default "hermes" profile name.
        provider.save_config({"profile": "coder"}, str(hermes_home))
        provider.post_setup(str(hermes_home), {"memory": {}})

        coder_env = user_home / ".hindsight" / "profiles" / "coder.env"
        hermes_env = user_home / ".hindsight" / "profiles" / "hermes.env"
        assert coder_env.exists()
        assert not hermes_env.exists()

    def test_local_embedded_setup_preserves_existing_key_when_input_left_blank(self, tmp_path, monkeypatch):
        hermes_home = tmp_path / "hermes-home"
        user_home = tmp_path / "user-home"
        user_home.mkdir()
        monkeypatch.setenv("HOME", str(user_home))

        selections = iter([1, 0])  # local_embedded, openai
        monkeypatch.setattr("hermes_cli.memory_setup._curses_select", lambda *args, **kwargs: next(selections))
        monkeypatch.setattr("shutil.which", lambda name: None)
        monkeypatch.setattr("builtins.input", lambda prompt="": "")
        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
        # Blank getpass input simulates pressing Enter at the API-key prompt.
        monkeypatch.setattr("getpass.getpass", lambda prompt="": "")
        monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)

        # Seed an existing .env so there is a prior key to preserve.
        env_path = hermes_home / ".env"
        env_path.parent.mkdir(parents=True, exist_ok=True)
        env_path.write_text("HINDSIGHT_LLM_API_KEY=existing-key\n")

        provider = HindsightMemoryProvider()
        provider.post_setup(str(hermes_home), {"memory": {}})

        profile_env = user_home / ".hindsight" / "profiles" / "hermes.env"
        assert profile_env.exists()
        assert "HINDSIGHT_API_LLM_API_KEY=existing-key\n" in profile_env.read_text()


    def test_local_embedded_setup_blank_inputs_preserve_existing_config(self, tmp_path, monkeypatch):
        """Pressing Enter through setup should keep existing Hindsight values."""
        hermes_home = tmp_path / "hermes-home"
        user_home = tmp_path / "user-home"
        user_home.mkdir()
        monkeypatch.setenv("HOME", str(user_home))
        monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: hermes_home)

        # A fully-populated prior config, including raw HINDSIGHT_* passthrough
        # keys and a zero idle_timeout (falsy — must not be dropped).
        existing_config = {
            "mode": "local_embedded",
            "llm_provider": "openai_compatible",
            "llm_base_url": "http://192.168.1.161:8060/v1",
            "llm_api_key": "9913",
            "llm_model": "gemma-4-26B-A4B-it-heretic-oQ4",
            "bank_id": "hermes",
            "recall_budget": "mid",
            "idle_timeout": 0,
            "HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT": "0",
            "HINDSIGHT_API_CONSOLIDATION_LLM_BATCH_SIZE": "1",
            "timeout": 120,
        }
        provider = HindsightMemoryProvider()
        provider.save_config(existing_config, str(hermes_home))

        # Simulate pressing Enter at the mode and LLM-provider pickers, which
        # should select their current values, and pressing Enter at text prompts.
        monkeypatch.setattr("hermes_cli.memory_setup._curses_select", lambda *args, **kwargs: kwargs.get("default", 0))
        monkeypatch.setattr("shutil.which", lambda name: None)
        monkeypatch.setattr("builtins.input", lambda prompt="": "")
        monkeypatch.setattr("sys.stdin.isatty", lambda: True)
        monkeypatch.setattr("getpass.getpass", lambda prompt="": "")
        monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)

        provider = HindsightMemoryProvider()
        provider.post_setup(str(hermes_home), {"memory": {}})

        # Every pre-existing value must survive the blank-input setup pass.
        saved = json.loads((hermes_home / "hindsight" / "config.json").read_text())
        assert saved["mode"] == "local_embedded"
        assert saved["llm_provider"] == "openai_compatible"
        assert saved["llm_base_url"] == "http://192.168.1.161:8060/v1"
        assert saved["llm_api_key"] == "9913"
        assert saved["llm_model"] == "gemma-4-26B-A4B-it-heretic-oQ4"
        assert saved["idle_timeout"] == 0
        assert saved["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] == "0"
        assert saved["HINDSIGHT_API_CONSOLIDATION_LLM_BATCH_SIZE"] == "1"
        assert saved["timeout"] == 120
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool handler tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestToolHandlers:
    """Dispatch and error-path coverage for the three hindsight_* tools."""

    @staticmethod
    def _call(p, tool_name, args):
        # Every handler returns a JSON string; decode it once here.
        return json.loads(p.handle_tool_call(tool_name, args))

    def test_retain_success(self, provider):
        payload = self._call(provider, "hindsight_retain", {"content": "user likes dark mode"})
        assert payload["result"] == "Memory stored successfully."
        provider._client.aretain.assert_called_once()
        kwargs = provider._client.aretain.call_args.kwargs
        assert kwargs["bank_id"] == "test-bank"
        assert kwargs["content"] == "user likes dark mode"

    def test_retain_with_tags(self, provider_with_config):
        p = provider_with_config(retain_tags=["pref", "ui"])
        p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
        assert p._client.aretain.call_args.kwargs["tags"] == ["pref", "ui"]

    def test_retain_merges_per_call_tags_with_config_tags(self, provider_with_config):
        p = provider_with_config(retain_tags=["pref", "ui"])
        p.handle_tool_call(
            "hindsight_retain",
            {"content": "likes dark mode", "tags": ["client:x", "ui"]},
        )
        # Config tags come first; per-call tags are appended minus duplicates.
        assert p._client.aretain.call_args.kwargs["tags"] == ["pref", "ui", "client:x"]

    def test_retain_without_tags(self, provider):
        provider.handle_tool_call("hindsight_retain", {"content": "hello"})
        assert "tags" not in provider._client.aretain.call_args.kwargs

    def test_retain_missing_content(self, provider):
        assert "error" in self._call(provider, "hindsight_retain", {})

    def test_recall_success(self, provider):
        payload = self._call(provider, "hindsight_recall", {"query": "dark mode"})
        assert "Memory 1" in payload["result"]
        assert "Memory 2" in payload["result"]

    def test_recall_passes_max_tokens(self, provider_with_config):
        p = provider_with_config(recall_max_tokens=2048)
        p.handle_tool_call("hindsight_recall", {"query": "test"})
        assert p._client.arecall.call_args.kwargs["max_tokens"] == 2048

    def test_recall_passes_tags(self, provider_with_config):
        p = provider_with_config(recall_tags=["tag1"], recall_tags_match="all")
        p.handle_tool_call("hindsight_recall", {"query": "test"})
        kwargs = p._client.arecall.call_args.kwargs
        assert kwargs["tags"] == ["tag1"]
        assert kwargs["tags_match"] == "all"

    def test_recall_passes_types(self, provider_with_config):
        p = provider_with_config(recall_types=["world", "experience"])
        p.handle_tool_call("hindsight_recall", {"query": "test"})
        assert p._client.arecall.call_args.kwargs["types"] == ["world", "experience"]

    def test_recall_no_results(self, provider):
        provider._client.arecall.return_value = SimpleNamespace(results=[])
        payload = self._call(provider, "hindsight_recall", {"query": "test"})
        assert payload["result"] == "No relevant memories found."

    def test_recall_missing_query(self, provider):
        assert "error" in self._call(provider, "hindsight_recall", {})

    def test_reflect_success(self, provider):
        payload = self._call(provider, "hindsight_reflect", {"query": "summarize"})
        assert payload["result"] == "Synthesized answer"

    def test_reflect_missing_query(self, provider):
        assert "error" in self._call(provider, "hindsight_reflect", {})

    def test_unknown_tool(self, provider):
        assert "error" in self._call(provider, "hindsight_unknown", {})

    def test_retain_error_handling(self, provider):
        provider._client.aretain.side_effect = RuntimeError("connection failed")
        payload = self._call(provider, "hindsight_retain", {"content": "test"})
        assert "error" in payload
        assert "connection failed" in payload["error"]

    def test_recall_error_handling(self, provider):
        provider._client.arecall.side_effect = RuntimeError("timeout")
        assert "error" in self._call(provider, "hindsight_recall", {"query": "test"})

    def test_local_embedded_recall_reconnects_after_idle_shutdown(self, provider, monkeypatch):
        # First client fails with a connection error; _get_client is patched to
        # hand out the broken client, then a healthy replacement.
        broken = _make_mock_client()
        broken.arecall.side_effect = RuntimeError("Cannot connect to host 127.0.0.1:8888")
        healthy = _make_mock_client()
        healthy.arecall.return_value = SimpleNamespace(
            results=[SimpleNamespace(text="Recovered memory")]
        )
        replacements = iter([broken, healthy])

        provider._mode = "local_embedded"
        provider._client = broken
        monkeypatch.setattr(provider, "_get_client", lambda: next(replacements))

        payload = self._call(provider, "hindsight_recall", {"query": "test"})

        assert payload["result"] == "1. Recovered memory"
        assert provider._client is healthy
        broken.arecall.assert_called_once()
        healthy.arecall.assert_called_once()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prefetch tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPrefetch:
    """queue_prefetch / prefetch: gating, preamble rendering, truncation, params."""

    def test_prefetch_returns_empty_when_no_result(self, provider):
        assert provider.prefetch("test") == ""

    def test_prefetch_default_preamble(self, provider):
        provider._prefetch_result = "- some memory"
        rendered = provider.prefetch("test")
        assert "Hindsight Memory" in rendered
        assert "- some memory" in rendered

    def test_prefetch_custom_preamble(self, provider_with_config):
        p = provider_with_config(recall_prompt_preamble="Custom header:")
        p._prefetch_result = "- memory line"
        rendered = p.prefetch("test")
        assert rendered.startswith("Custom header:")
        assert "- memory line" in rendered

    def test_queue_prefetch_skipped_in_tools_mode(self, provider_with_config):
        p = provider_with_config(memory_mode="tools")
        p.queue_prefetch("test")
        # No background prefetch thread should have been spawned.
        assert p._prefetch_thread is None

    def test_queue_prefetch_skipped_when_auto_recall_off(self, provider_with_config):
        p = provider_with_config(auto_recall=False)
        p.queue_prefetch("test")
        assert p._prefetch_thread is None

    def test_queue_prefetch_truncates_query(self, provider_with_config):
        p = provider_with_config(recall_max_input_chars=10)
        # Capture the query arecall actually receives.
        seen_query = None

        def _capture_recall(**kwargs):
            nonlocal seen_query
            seen_query = kwargs.get("query", "")
            return SimpleNamespace(results=[])

        p._client.arecall = AsyncMock(side_effect=_capture_recall)

        p.queue_prefetch("a" * 100)
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=5.0)

        # NOTE(review): the two guards make this assertion vacuous if no
        # thread starts or arecall is never reached — consider asserting
        # seen_query is not None once that's confirmed safe.
        if seen_query is not None:
            assert len(seen_query) <= 10

    def test_queue_prefetch_passes_recall_params(self, provider_with_config):
        p = provider_with_config(
            recall_tags=["t1"],
            recall_tags_match="all",
            recall_max_tokens=1024,
            recall_types=["world"],
        )
        p.queue_prefetch("test query")
        if p._prefetch_thread:
            p._prefetch_thread.join(timeout=5.0)

        kwargs = p._client.arecall.call_args.kwargs
        assert kwargs["max_tokens"] == 1024
        assert kwargs["tags"] == ["t1"]
        assert kwargs["tags_match"] == "all"
        assert kwargs["types"] == ["world"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# sync_turn tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSyncTurn:
|
|
    def test_sync_turn_retains_metadata_rich_turn(self, provider_with_config):
        # End-to-end check that one sync_turn carries configured prefixes,
        # tags, per-message timestamps, and the full metadata block through
        # to aretain_batch.
        p = provider_with_config(
            retain_tags=["conv", "session1"],
            retain_source="hermes",
            retain_user_prefix="User (fakeusername)",
            retain_assistant_prefix="Assistant (fakeassistantname)",
        )
        # Re-initialize with rich platform/user/chat metadata so every
        # metadata assertion below has a concrete source value.
        p.initialize(
            session_id="session-1",
            platform="discord",
            user_id="fakeusername-123",
            user_name="fakeusername",
            chat_id="1485316232612941897",
            chat_name="fakeassistantname-forums",
            chat_type="thread",
            thread_id="1491249007475949698",
            agent_identity="fakeassistantname",
        )
        p._client = _make_mock_client()

        p.sync_turn("hello", "hi there")
        # Drain the writer queue so the batch call has happened before asserting.
        p._retain_queue.join()

        p._client.aretain_batch.assert_called_once()
        call_kwargs = p._client.aretain_batch.call_args.kwargs
        assert call_kwargs["bank_id"] == "test-bank"
        assert call_kwargs["document_id"].startswith("session-1-")
        assert call_kwargs["retain_async"] is True
        assert len(call_kwargs["items"]) == 1
        item = call_kwargs["items"][0]
        assert item["context"] == "conversation between Hermes Agent and the User"
        # Configured tags plus the auto-added session:<id> tag, in order.
        assert item["tags"] == ["conv", "session1", "session:session-1"]
        content = json.loads(item["content"])
        assert len(content) == 1
        assert content[0][0]["role"] == "user"
        assert content[0][0]["content"] == "User (fakeusername): hello"
        assert content[0][1]["role"] == "assistant"
        assert content[0][1]["content"] == "Assistant (fakeassistantname): hi there"
        assert item["metadata"]["source"] == "hermes"
        assert item["metadata"]["session_id"] == "session-1"
        assert item["metadata"]["platform"] == "discord"
        assert item["metadata"]["user_id"] == "fakeusername-123"
        assert item["metadata"]["user_name"] == "fakeusername"
        assert item["metadata"]["chat_id"] == "1485316232612941897"
        assert item["metadata"]["chat_name"] == "fakeassistantname-forums"
        assert item["metadata"]["chat_type"] == "thread"
        assert item["metadata"]["thread_id"] == "1491249007475949698"
        assert item["metadata"]["agent_identity"] == "fakeassistantname"
        assert item["metadata"]["turn_index"] == "1"
        assert item["metadata"]["message_count"] == "2"
        # Message timestamps are ISO-8601 with a +00:00 offset; retained_at is
        # millisecond-precision Zulu time.
        assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?\+00:00", content[0][0]["timestamp"])
        assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", item["metadata"]["retained_at"])
|
|
|
|
def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config):
|
|
p = provider_with_config(auto_retain=False)
|
|
p.sync_turn("hello", "hi")
|
|
assert p._sync_thread is None
|
|
p._client.aretain_batch.assert_not_called()
|
|
|
|
def test_sync_turn_with_tags(self, provider_with_config):
|
|
p = provider_with_config(retain_tags=["conv", "session1"])
|
|
p.sync_turn("hello", "hi")
|
|
p._retain_queue.join()
|
|
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
|
assert "conv" in item["tags"]
|
|
assert "session1" in item["tags"]
|
|
assert "session:test-session" in item["tags"]
|
|
|
|
def test_sync_turn_uses_aretain_batch(self, provider):
|
|
"""sync_turn should use aretain_batch with retain_async."""
|
|
provider.sync_turn("hello", "hi")
|
|
provider._retain_queue.join()
|
|
provider._client.aretain_batch.assert_called_once()
|
|
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
|
assert call_kwargs["document_id"].startswith("test-session-")
|
|
assert call_kwargs["retain_async"] is True
|
|
assert len(call_kwargs["items"]) == 1
|
|
assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User"
|
|
|
|
def test_sync_turn_custom_context(self, provider_with_config):
|
|
p = provider_with_config(retain_context="my-agent")
|
|
p.sync_turn("hello", "hi")
|
|
p._retain_queue.join()
|
|
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
|
assert item["context"] == "my-agent"
|
|
|
|
    def test_sync_turn_every_n_turns(self, provider_with_config):
        # With retain_every_n_turns=3, the first two turns only buffer
        # (no writer work queued — _sync_thread stays None) and the third
        # triggers exactly one batch retain. retain_async=False here also
        # flips the retain_async flag on that call.
        p = provider_with_config(retain_every_n_turns=3, retain_async=False)
        p.sync_turn("turn1-user", "turn1-asst")
        assert p._sync_thread is None
        p.sync_turn("turn2-user", "turn2-asst")
        assert p._sync_thread is None
        p.sync_turn("turn3-user", "turn3-asst")
        p._retain_queue.join()
        p._client.aretain_batch.assert_called_once()
        call_kwargs = p._client.aretain_batch.call_args.kwargs
        assert call_kwargs["document_id"].startswith("test-session-")
        assert call_kwargs["retain_async"] is False
        item = call_kwargs["items"][0]
        content = json.loads(item["content"])
        # All three buffered turns ship together, with the default
        # "User:"/"Assistant:" prefixes (no custom prefixes configured).
        assert len(content) == 3
        assert content[-1][0]["role"] == "user"
        assert content[-1][0]["content"] == "User: turn3-user"
        assert content[-1][1]["role"] == "assistant"
        assert content[-1][1]["content"] == "Assistant: turn3-asst"
        assert item["metadata"]["turn_index"] == "3"
        assert item["metadata"]["message_count"] == "6"
|
|
|
|
def test_sync_turn_accumulates_full_session(self, provider_with_config):
|
|
"""Each retain sends the ENTIRE session, not just the latest batch."""
|
|
p = provider_with_config(retain_every_n_turns=2)
|
|
|
|
p.sync_turn("turn1-user", "turn1-asst")
|
|
p.sync_turn("turn2-user", "turn2-asst")
|
|
p._retain_queue.join()
|
|
|
|
p._client.aretain_batch.reset_mock()
|
|
|
|
p.sync_turn("turn3-user", "turn3-asst")
|
|
p.sync_turn("turn4-user", "turn4-asst")
|
|
p._retain_queue.join()
|
|
|
|
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
|
|
# Should contain ALL turns from the session
|
|
assert "turn1-user" in content
|
|
assert "turn2-user" in content
|
|
assert "turn3-user" in content
|
|
assert "turn4-user" in content
|
|
|
|
def test_sync_turn_passes_document_id(self, provider):
|
|
"""sync_turn should pass document_id (session_id + per-startup ts)."""
|
|
provider.sync_turn("hello", "hi")
|
|
provider._retain_queue.join()
|
|
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
|
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
|
|
assert call_kwargs["document_id"].startswith("test-session-")
|
|
assert call_kwargs["document_id"] == provider._document_id
|
|
|
|
def test_resume_creates_new_document(self, tmp_path, monkeypatch):
    """Resuming a session (re-initializing) gets a new document_id
    so previously stored content is not overwritten."""
    import time

    cfg_file = tmp_path / "hindsight" / "config.json"
    cfg_file.parent.mkdir(parents=True, exist_ok=True)
    cfg_file.write_text(json.dumps(
        {"mode": "cloud", "apiKey": "k", "api_url": "http://x", "bank_id": "b"}
    ))
    monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)

    # Initialize the same logical session twice, with a tiny pause in
    # between so the microsecond timestamp component differs.
    providers = []
    for _ in range(2):
        p = HindsightMemoryProvider()
        p.initialize(session_id="resumed-session", hermes_home=str(tmp_path), platform="cli")
        providers.append(p)
        time.sleep(0.001)

    first, second = providers
    # Same session, but each process gets its own document_id.
    assert first._document_id != second._document_id
    assert first._document_id.startswith("resumed-session-")
    assert second._document_id.startswith("resumed-session-")
|
|
|
|
def test_sync_turn_session_tag(self, provider):
    """Each retain should be tagged with session:<id> for filtering."""
    provider.sync_turn("hello", "hi")
    provider._retain_queue.join()

    tags = provider._client.aretain_batch.call_args.kwargs["items"][0]["tags"]
    assert "session:test-session" in tags
|
|
|
|
def test_sync_turn_parent_session_tag(self, tmp_path, monkeypatch):
    """When initialized with parent_session_id, parent tag is added."""
    cfg_file = tmp_path / "hindsight" / "config.json"
    cfg_file.parent.mkdir(parents=True, exist_ok=True)
    cfg_file.write_text(json.dumps(
        {"mode": "cloud", "apiKey": "k", "api_url": "http://x", "bank_id": "b"}
    ))
    monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)

    p = HindsightMemoryProvider()
    p.initialize(
        session_id="child-session",
        hermes_home=str(tmp_path),
        platform="cli",
        parent_session_id="parent-session",
    )
    p._client = _make_mock_client()
    p.sync_turn("hello", "hi")
    p._retain_queue.join()

    tags = p._client.aretain_batch.call_args.kwargs["items"][0]["tags"]
    # Both lineage tags must be present on the retained item.
    assert "session:child-session" in tags
    assert "parent:parent-session" in tags
|
|
|
|
def test_sync_turn_error_does_not_raise(self, provider):
|
|
provider._client.aretain_batch.side_effect = RuntimeError("network error")
|
|
provider.sync_turn("hello", "hi")
|
|
provider._retain_queue.join()
|
|
|
|
def test_sync_turn_preserves_unicode(self, provider_with_config):
    """Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
    p = provider_with_config()
    # Replace the fixture's client with a fresh mock so call_args is clean.
    p._client = _make_mock_client()
    # Korean / Japanese / Chinese user text plus an emoji assistant reply.
    p.sync_turn("안녕 こんにちは 你好", "👨👩👧👦 family")
    # Drain the writer queue so the retain has definitely happened.
    p._retain_queue.join()
    p._client.aretain_batch.assert_called_once()
    item = p._client.aretain_batch.call_args.kwargs["items"][0]
    # ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
    # not as \uXXXX escape sequences.
    raw_json = item["content"]
    assert "안녕" in raw_json
    assert "こんにちは" in raw_json
    assert "你好" in raw_json
    assert "👨👩👧👦" in raw_json
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shutdown / writer tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestShutdownRace:
    def test_sync_turn_uses_single_writer_thread(self, provider):
        """All retains run through one long-lived writer thread."""
        provider.sync_turn("a", "b")
        provider._retain_queue.join()

        writer = provider._writer_thread
        assert writer is not None
        assert writer.is_alive()

        provider.sync_turn("c", "d")
        provider._retain_queue.join()

        # Exactly the same thread object serviced both retains — the
        # plugin never spins up an ad-hoc thread per call.
        assert provider._writer_thread is writer
        assert provider._client.aretain_batch.call_count == 2

    def test_sync_turn_after_shutdown_is_dropped(self, provider):
        """Once shutdown has fired, new sync_turn() calls are no-ops.

        This is the core of the fix: the plugin must not enqueue a retain
        during interpreter teardown — that's what causes the
        'cannot schedule new futures' RuntimeError + unclosed aiohttp
        sessions on CLI exit.
        """
        client_handle = provider._client
        provider.shutdown()
        calls_at_shutdown = client_handle.aretain_batch.call_count

        provider.sync_turn("late", "turn")

        # Nothing was enqueued after shutdown — the retain queue stays empty.
        assert provider._retain_queue.empty()
        # And the client saw no new call (shutdown nulled self._client, so
        # the check goes through the handle captured before shutdown).
        assert client_handle.aretain_batch.call_count == calls_at_shutdown

    def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
        """queue_prefetch() after shutdown must not spawn a worker thread."""
        provider.shutdown()
        provider.queue_prefetch("late query")
        assert provider._prefetch_thread is None

    def test_shutdown_drains_pending_retains(self, provider):
        """Shutdown must wait for queued retains to complete, not abandon them.

        Otherwise the LAST in-flight turn — typically the most important —
        is silently lost.
        """
        client_handle = provider._client
        provider.sync_turn("a", "b")
        provider.sync_turn("c", "d")

        provider.shutdown()

        # Both retains drained before shutdown returned.
        assert client_handle.aretain_batch.call_count == 2
        assert provider._retain_queue.empty()

    def test_shutdown_is_idempotent(self, provider):
        """Repeated shutdown() calls must be harmless."""
        provider.sync_turn("a", "b")
        provider.shutdown()
        # Second shutdown shouldn't blow up or re-close the client.
        provider.shutdown()
        assert provider._shutting_down.is_set()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# on_session_switch — flush + prefetch reset behavior
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSessionSwitchBufferFlush:
    """Behavior of on_session_switch(): buffered-turn flush via the writer
    queue, prefetch-state reset, and FIFO ordering behind pending retains.

    Fixes applied to this class: ``test_in_flight_prefetch_thread_drained_on_switch``
    previously requested the ``monkeypatch`` fixture and did
    ``import time as _time`` without using either — both removed.
    """

    def test_buffered_turns_flushed_before_clear(self, provider_with_config):
        """retain_every_n_turns > 1 must not silently drop partial buffers
        on session switch. Whatever's in _session_turns at switch time
        should land in the OLD document under the OLD session id."""
        p = provider_with_config(retain_every_n_turns=3, retain_async=False)
        old_doc = p._document_id

        # Two turns buffered, no retain yet (boundary is at turn 3). The
        # writer hasn't been started either — sync_turn's early return
        # skips _ensure_writer when no retain is due.
        p.sync_turn("turn1-user", "turn1-asst")
        p.sync_turn("turn2-user", "turn2-asst")
        assert p._sync_thread is None
        p._client.aretain_batch.assert_not_called()

        # Switch — flush should fire under OLD document_id via the writer queue.
        p.on_session_switch("new-sid", parent_session_id="test-session", reset=True)
        p._retain_queue.join()

        p._client.aretain_batch.assert_called_once()
        kw = p._client.aretain_batch.call_args.kwargs
        assert kw["document_id"] == old_doc
        item = kw["items"][0]
        # Both buffered turns must be present in the flushed payload.
        content = json.loads(item["content"])
        flat = json.dumps(content)
        assert "turn1-user" in flat
        assert "turn2-user" in flat
        # Old session id must appear in lineage tags / metadata.
        assert "session:test-session" in item["tags"]
        assert item["metadata"]["session_id"] == "test-session"

        # And the new session must start with a clean slate.
        assert p._session_id == "new-sid"
        assert p._session_turns == []
        assert p._turn_counter == 0
        assert p._document_id != old_doc
        assert p._document_id.startswith("new-sid-")

    def test_no_flush_when_buffer_empty(self, provider):
        """Switch with no buffered turns must not fire a spurious retain."""
        provider.on_session_switch("new-sid")
        # Nothing enqueued — join is immediate.
        provider._retain_queue.join()
        provider._client.aretain_batch.assert_not_called()
        assert provider._session_id == "new-sid"

    def test_prefetch_result_cleared_on_switch(self, provider):
        """Stale recall text from the old session must not leak into the
        next session's first prefetch read."""
        provider._prefetch_result = "old-session recall: User likes Rust"
        provider.on_session_switch("new-sid")
        assert provider._prefetch_result == ""
        # And subsequent prefetch() should now report empty, not the leftover.
        assert provider.prefetch("anything") == ""

    def test_in_flight_prefetch_thread_drained_on_switch(self, provider):
        """on_session_switch must wait for an in-flight prefetch from the
        old session to settle before clearing _prefetch_result, otherwise
        the thread can race and re-populate the field after the clear."""
        import threading

        gate = threading.Event()
        finished = threading.Event()

        def _slow_prefetch():
            # Stand-in for a real prefetch worker: blocks on `gate`, then
            # writes its result under the provider's prefetch lock.
            gate.wait(timeout=5.0)
            with provider._prefetch_lock:
                provider._prefetch_result = "old-session recall"
            finished.set()

        provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True)
        provider._prefetch_thread.start()

        # Release the prefetch worker so it writes _prefetch_result, then
        # call on_session_switch — it must join the thread before clearing.
        gate.set()
        provider.on_session_switch("new-sid")

        assert finished.is_set(), "switch returned before prefetch thread settled"
        assert provider._prefetch_result == ""

    def test_flush_serializes_behind_pending_retains_via_writer_queue(
        self, provider_with_config
    ):
        """The flush closure must ride the same _retain_queue sync_turn
        uses, so it lands FIFO behind any still-queued old-session
        retains rather than racing them on a separate thread.

        Regression guard: an earlier draft spawned a raw threading.Thread
        for flush, overwriting _sync_thread and racing the writer against
        the same document_id.
        """
        import threading as _threading

        p = provider_with_config(retain_every_n_turns=2, retain_async=False)

        # Block the first writer job until we've enqueued the flush
        # behind it. This proves ordering — the flush MUST wait.
        gate = _threading.Event()
        call_order: list[str] = []

        def _aretain_batch_tracking(**kw):
            idx = kw["items"][0]["metadata"].get("turn_index", "")
            call_order.append(str(idx))
            if idx == "2":
                # First retain blocks until we've enqueued the flush.
                gate.wait(timeout=5.0)

        p._client.aretain_batch = AsyncMock(side_effect=_aretain_batch_tracking)

        # Turn 1+2 → boundary hit → retain enqueued (will block).
        p.sync_turn("turn1-user", "turn1-asst")
        p.sync_turn("turn2-user", "turn2-asst")

        # One more buffered turn so flush has something to land.
        p.sync_turn("turn3-user", "turn3-asst")

        # Switch while the first retain is still blocked on `gate`.
        p.on_session_switch("new-sid", parent_session_id="test-session")

        # Release the first retain. Flush must have been enqueued
        # BEHIND it, and run second.
        gate.set()
        p._retain_queue.join()

        # The flush carries all buffered turns; sync_turn's retain #2
        # carried the batch at boundary time. Two distinct calls.
        assert p._client.aretain_batch.call_count == 2
        # First call landed while buffer was [t1, t2]; flush landed
        # after we added t3. So the second call must be strictly after.
        assert call_order[0] == "2"
        # Flush retain has turn_index matching the buffered count at
        # switch time (3 turns accumulated, _turn_index was set to 3
        # by the last sync_turn).
        assert call_order[1] == "3"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# System prompt tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSystemPrompt:
    def test_hybrid_mode_prompt(self, provider):
        """Default (hybrid) prompt advertises both auto-injection and the tool."""
        text = provider.system_prompt_block()
        assert "Hindsight Memory" in text
        assert "hindsight_recall" in text
        assert "automatically injected" in text

    def test_context_mode_prompt(self, provider_with_config):
        """Context mode: recall is injected only; no tool is advertised."""
        text = provider_with_config(memory_mode="context").system_prompt_block()
        assert "context mode" in text
        assert "hindsight_recall" not in text

    def test_tools_mode_prompt(self, provider_with_config):
        """Tools mode: the recall tool IS advertised."""
        text = provider_with_config(memory_mode="tools").system_prompt_block()
        assert "tools mode" in text
        assert "hindsight_recall" in text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config schema tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfigSchema:
    def test_schema_has_all_new_fields(self, provider):
        """Every config key the plugin honors must be declared in its schema."""
        declared = {field["key"] for field in provider.get_config_schema()}
        required = {
            "mode", "api_url", "api_key", "llm_provider", "llm_api_key",
            "llm_model", "bank_id", "bank_id_template", "bank_mission", "bank_retain_mission",
            "recall_budget", "memory_mode", "recall_prefetch_method",
            "retain_tags", "retain_source",
            "retain_user_prefix", "retain_assistant_prefix",
            "recall_tags", "recall_tags_match",
            "auto_recall", "auto_retain",
            "retain_every_n_turns", "retain_async", "retain_context",
            "recall_max_tokens", "recall_max_input_chars",
            "recall_prompt_preamble",
        }
        missing = required - declared
        assert not missing, f"Missing: {missing}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# bank_id_template tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBankIdTemplate:
    def test_sanitize_bank_segment_passthrough(self):
        """Already-safe segments come back unchanged."""
        for safe in ("hermes", "my-agent_1"):
            assert _sanitize_bank_segment(safe) == safe

    def test_sanitize_bank_segment_strips_unsafe(self):
        """Unsafe characters are replaced with dashes; whitespace trimmed."""
        cases = {
            "josh@example.com": "josh-example-com",
            "chat:#general": "chat-general",
            " spaces ": "spaces",
        }
        for raw, expected in cases.items():
            assert _sanitize_bank_segment(raw) == expected

    def test_sanitize_bank_segment_empty(self):
        """Empty string and None both normalize to the empty string."""
        assert _sanitize_bank_segment("") == ""
        assert _sanitize_bank_segment(None) == ""

    def test_resolve_empty_template_uses_fallback(self):
        """An empty template string resolves to the fallback bank id."""
        resolved = _resolve_bank_id_template(
            "", fallback="hermes", profile="coder"
        )
        assert resolved == "hermes"

    def test_resolve_with_profile(self):
        """A single {profile} placeholder is substituted."""
        resolved = _resolve_bank_id_template(
            "hermes-{profile}", fallback="hermes",
            profile="coder", workspace="", platform="", user="", session="",
        )
        assert resolved == "hermes-coder"

    def test_resolve_with_multiple_placeholders(self):
        """Several placeholders substitute independently."""
        resolved = _resolve_bank_id_template(
            "{workspace}-{profile}-{platform}",
            fallback="hermes",
            profile="coder", workspace="myorg", platform="cli",
            user="", session="",
        )
        assert resolved == "myorg-coder-cli"

    def test_resolve_collapses_empty_placeholders(self):
        # When user is empty, "hermes-{user}" becomes "hermes-" -> trimmed to "hermes"
        resolved = _resolve_bank_id_template(
            "hermes-{user}", fallback="default",
            profile="", workspace="", platform="", user="", session="",
        )
        assert resolved == "hermes"

    def test_resolve_collapses_double_dashes(self):
        # Two empty placeholders with a dash between them should collapse
        resolved = _resolve_bank_id_template(
            "{workspace}-{profile}-{user}", fallback="fallback",
            profile="coder", workspace="", platform="", user="", session="",
        )
        assert resolved == "coder"

    def test_resolve_empty_rendered_falls_back(self):
        """If every placeholder renders empty, the fallback wins."""
        resolved = _resolve_bank_id_template(
            "{user}-{profile}", fallback="fallback",
            profile="", workspace="", platform="", user="", session="",
        )
        assert resolved == "fallback"

    def test_resolve_sanitizes_placeholder_values(self):
        """Placeholder values pass through segment sanitization."""
        resolved = _resolve_bank_id_template(
            "user-{user}", fallback="hermes",
            profile="", workspace="", platform="",
            user="josh@example.com", session="",
        )
        assert resolved == "user-josh-example-com"

    def test_resolve_invalid_template_returns_fallback(self):
        # Unknown placeholder should fall back without raising
        resolved = _resolve_bank_id_template(
            "hermes-{unknown}", fallback="hermes",
            profile="", workspace="", platform="", user="", session="",
        )
        assert resolved == "hermes"

    def test_provider_uses_bank_id_template_from_config(self, tmp_path, monkeypatch):
        """A configured template takes precedence over the static bank_id."""
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps({
            "mode": "cloud",
            "apiKey": "k",
            "api_url": "http://x",
            "bank_id": "fallback-bank",
            "bank_id_template": "hermes-{profile}",
        }))
        monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)

        p = HindsightMemoryProvider()
        p.initialize(
            session_id="s1",
            hermes_home=str(tmp_path),
            platform="cli",
            agent_identity="coder",
            agent_workspace="hermes",
        )
        assert p._bank_id == "hermes-coder"
        assert p._bank_id_template == "hermes-{profile}"

    def test_provider_without_template_uses_static_bank_id(self, tmp_path, monkeypatch):
        """Without a template, the literal bank_id is used as-is."""
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps({
            "mode": "cloud",
            "apiKey": "k",
            "api_url": "http://x",
            "bank_id": "my-static-bank",
        }))
        monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)

        p = HindsightMemoryProvider()
        p.initialize(
            session_id="s1",
            hermes_home=str(tmp_path),
            platform="cli",
            agent_identity="coder",
        )
        assert p._bank_id == "my-static-bank"

    def test_provider_template_with_missing_profile_falls_back(self, tmp_path, monkeypatch):
        """Template with an unresolvable placeholder collapses gracefully."""
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps({
            "mode": "cloud",
            "apiKey": "k",
            "api_url": "http://x",
            "bank_id": "hermes-fallback",
            "bank_id_template": "hermes-{profile}",
        }))
        monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)

        p = HindsightMemoryProvider()
        # No agent_identity passed — template renders to "hermes-" which collapses to "hermes"
        p.initialize(session_id="s1", hermes_home=str(tmp_path), platform="cli")
        assert p._bank_id == "hermes"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Availability tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAvailability:
    def test_available_with_api_key(self, tmp_path, monkeypatch):
        """An API key in the environment alone makes the plugin available."""
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        monkeypatch.setenv("HINDSIGHT_API_KEY", "test-key")
        assert HindsightMemoryProvider().is_available()

    def test_not_available_without_config(self, tmp_path, monkeypatch):
        """No config file and no env key means the plugin is unavailable."""
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        assert not HindsightMemoryProvider().is_available()

    def test_available_in_local_mode(self, tmp_path, monkeypatch):
        """Local mode is available when the local runtime imports cleanly."""
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        monkeypatch.setenv("HINDSIGHT_MODE", "local")
        # Pretend the local runtime module imports fine.
        monkeypatch.setattr(
            "plugins.memory.hindsight.importlib.import_module",
            lambda name: object(),
        )
        assert HindsightMemoryProvider().is_available()

    def test_available_with_snake_case_api_key_in_config(self, tmp_path, monkeypatch):
        """config.json using snake_case `api_key` is honored too."""
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps({
            "mode": "cloud",
            "api_key": "***",
        }))
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path,
        )

        assert HindsightMemoryProvider().is_available()

    def test_local_mode_unavailable_when_runtime_import_fails(self, tmp_path, monkeypatch):
        """A broken local runtime (e.g. NumPy ISA mismatch) disables local mode."""
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home",
            lambda: tmp_path / "nonexistent",
        )
        monkeypatch.setenv("HINDSIGHT_MODE", "local")

        def _broken_import(_name):
            raise RuntimeError(
                "NumPy was built with baseline optimizations: (x86_64-v2)"
            )

        monkeypatch.setattr(
            "plugins.memory.hindsight.importlib.import_module",
            _broken_import,
        )
        assert not HindsightMemoryProvider().is_available()

    def test_initialize_disables_local_mode_when_runtime_import_fails(self, tmp_path, monkeypatch):
        """initialize() falls back to disabled mode on a broken local runtime."""
        cfg_file = tmp_path / "hindsight" / "config.json"
        cfg_file.parent.mkdir(parents=True, exist_ok=True)
        cfg_file.write_text(json.dumps({"mode": "local_embedded"}))
        monkeypatch.setattr(
            "plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
        )

        def _broken_import(_name):
            raise RuntimeError("x86_64-v2 unsupported")

        monkeypatch.setattr(
            "plugins.memory.hindsight.importlib.import_module",
            _broken_import,
        )

        p = HindsightMemoryProvider()
        p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
        assert p._mode == "disabled"
|
|
|
|
|
|
class TestSharedEventLoopLifecycle:
    """Regression tests for #11923 — Hindsight leaking aiohttp ClientSession /
    TCPConnector objects in long-running gateway processes.

    Root cause: the module-global ``_loop`` / ``_loop_thread`` pair is shared
    across every HindsightMemoryProvider instance in the process (the plugin
    loader builds one provider per AIAgent, and the gateway builds one AIAgent
    per concurrent chat session). When a session ended, ``shutdown()`` stopped
    the shared loop, which orphaned every *other* live provider's aiohttp
    ClientSession on a dead loop. Those sessions were never closed and surfaced
    as ``Unclosed client session`` / ``Unclosed connector`` errors.
    """

    def test_shutdown_does_not_stop_shared_event_loop(self, provider_with_config):
        # Import the module itself so its globals (_loop, _loop_thread,
        # _run_sync) can be inspected by identity after shutdown.
        from plugins.memory import hindsight as hindsight_mod

        async def _noop():
            return 1

        # Prime the shared loop by scheduling a trivial coroutine — mirrors
        # the first time any real async call (arecall/aretain/areflect) runs.
        assert hindsight_mod._run_sync(_noop()) == 1

        # Snapshot the loop/thread identities BEFORE any shutdown happens.
        loop_before = hindsight_mod._loop
        thread_before = hindsight_mod._loop_thread
        assert loop_before is not None and loop_before.is_running()
        assert thread_before is not None and thread_before.is_alive()

        # Build two independent providers (two concurrent chat sessions).
        provider_a = provider_with_config()
        provider_b = provider_with_config()

        # End session A.
        provider_a.shutdown()

        # Module-global loop/thread must still be the same live objects —
        # provider B (and any other sibling provider) is still relying on them.
        assert hindsight_mod._loop is loop_before, (
            "shutdown() swapped out the shared event loop — sibling providers "
            "would have their aiohttp ClientSession orphaned (#11923)"
        )
        assert hindsight_mod._loop.is_running(), (
            "shutdown() stopped the shared event loop — sibling providers' "
            "aiohttp sessions would leak (#11923)"
        )
        assert hindsight_mod._loop_thread is thread_before
        assert hindsight_mod._loop_thread.is_alive()

        # Provider B can still dispatch async work on the shared loop.
        async def _still_working():
            return 42

        assert hindsight_mod._run_sync(_still_working()) == 42

        # Clean up the second session too so this test leaves no live provider.
        provider_b.shutdown()

    def test_client_aclose_called_on_cloud_mode_shutdown(self, provider):
        """Per-provider session cleanup still runs even though the shared
        loop is preserved. Each provider's own aiohttp session is closed
        via ``self._client.aclose()``; only the (empty) shared loop survives.
        """
        assert provider._client is not None
        # Capture the handle — shutdown() nulls self._client afterwards.
        mock_client = provider._client

        provider.shutdown()

        mock_client.aclose.assert_called_once()
        assert provider._client is None
|
|
|
|
|
|
class TestShutdown:
|
|
def test_local_embedded_shutdown_closes_inner_async_client_on_shared_loop(self, provider):
    """Embedded mode: shutdown closes the inner async client and the wrapper."""
    inner = _make_mock_client()
    embedded = MagicMock()
    embedded._client = inner
    embedded.close = MagicMock()

    # Reconfigure the fixture provider to look like local_embedded mode.
    provider._mode = "local_embedded"
    provider._client = embedded

    provider.shutdown()

    # The inner async client is awaited closed, the embedded wrapper's
    # close() runs, and both references are dropped.
    inner.aclose.assert_awaited_once()
    embedded.close.assert_called_once()
    assert embedded._client is None
    assert provider._client is None
|
|
|