mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-01 08:21:50 +08:00
fix(hindsight): flush buffered turns and drop stale prefetch on session switch
Two data-loss / leak gaps in HindsightMemoryProvider.on_session_switch introduced by #17409. 1. Buffered turns silently lost when retain_every_n_turns > 1. on_session_switch unconditionally cleared _session_turns without flushing. Users who batched every N>1 turns and switched mid-batch (/reset, /new, /resume, /branch, or context compression) had those buffered turns disappear. Same data-loss class as the shutdown race, different lifecycle event. Note commit_memory_session() -> on_session_end() runs *before* on_session_switch on /reset, but Hindsight doesn't implement on_session_end so the buffer survives that step and dies at clear time. /resume, /branch, and compression skip commit_memory_session entirely so an on_session_end impl wouldn't help them anyway. Fix: snapshot the old _session_id, _document_id, _parent_session_id, _turn_index, and _session_turns; spawn one final retain that lands under the OLD document_id; then rotate state. Metadata is built synchronously against the old self._* so session_id / lineage tags on the flushed item all reference the prior session consistently. 2. Stale _prefetch_result leaks across switch. If queue_prefetch ran in the old session and the result hadn't been consumed by prefetch() yet, on_session_switch left the cached recall text in place. The next session's first prefetch() call would return text mined from the prior session's bank/query. Fix: join any in-flight _prefetch_thread (3s bounded — matches shutdown()), then clear _prefetch_result under _prefetch_lock before rotating session_id. Tests ----- - tests/plugins/memory/test_hindsight_provider.py (TestSessionSwitchBufferFlush): - buffered turns flushed under OLD document_id with OLD lineage tags - empty buffer => no spurious retain - _prefetch_result cleared on switch - in-flight prefetch thread is awaited before clear (no race) - tests/agent/test_memory_session_switch.py: factory extended to seed the attrs the new flush path reads (_retain_source, _platform, _bank_id, prefetch state, etc.) and stub _run_hindsight_operation so existing switch-state assertions keep passing without network setup.
This commit is contained in:
@@ -1447,6 +1447,16 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
batching must start from zero so an in-flight retain doesn't flush
|
||||
under the wrong ``_document_id``.
|
||||
|
||||
Before clearing, flush any buffered turns under the *old*
|
||||
``_document_id``. Users who set ``retain_every_n_turns > 1`` would
|
||||
otherwise silently lose whatever's in ``_session_turns`` at the
|
||||
moment of switch — the same data-loss class as the shutdown race,
|
||||
just at a different lifecycle event.
|
||||
|
||||
Also wait for any in-flight prefetch from the old session and drop
|
||||
its cached result; otherwise the new session's first ``prefetch()``
|
||||
could read stale recall text from before the switch.
|
||||
|
||||
``parent_session_id`` is recorded for lineage tags on future retains.
|
||||
``reset`` is accepted but not needed for Hindsight's state model —
|
||||
buffer clearing is correct for every session switch, not only /reset.
|
||||
@@ -1454,6 +1464,70 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
new_id = str(new_session_id or "").strip()
|
||||
if not new_id:
|
||||
return
|
||||
|
||||
# 1. Flush any buffered turns under the OLD identifiers. Snapshot
|
||||
# everything before mutating self._* so metadata + tags + doc_id
|
||||
# all reference the old session consistently.
|
||||
if self._session_turns:
|
||||
old_turns = list(self._session_turns)
|
||||
old_session_id = self._session_id
|
||||
old_document_id = self._document_id
|
||||
old_parent_session_id = self._parent_session_id
|
||||
old_turn_index = self._turn_index
|
||||
old_metadata = self._build_metadata(
|
||||
message_count=len(old_turns) * 2,
|
||||
turn_index=old_turn_index,
|
||||
)
|
||||
old_lineage_tags: list[str] = []
|
||||
if old_session_id:
|
||||
old_lineage_tags.append(f"session:{old_session_id}")
|
||||
if old_parent_session_id:
|
||||
old_lineage_tags.append(f"parent:{old_parent_session_id}")
|
||||
old_content = "[" + ",".join(old_turns) + "]"
|
||||
|
||||
def _flush():
|
||||
try:
|
||||
item = self._build_retain_kwargs(
|
||||
old_content,
|
||||
context=self._retain_context,
|
||||
metadata=old_metadata,
|
||||
tags=old_lineage_tags or None,
|
||||
)
|
||||
item.pop("bank_id", None)
|
||||
item.pop("retain_async", None)
|
||||
logger.debug(
|
||||
"Hindsight flush-on-switch: bank=%s, doc=%s, num_turns=%d",
|
||||
self._bank_id, old_document_id, len(old_turns),
|
||||
)
|
||||
self._run_hindsight_operation(
|
||||
lambda client: client.aretain_batch(
|
||||
bank_id=self._bank_id,
|
||||
items=[item],
|
||||
document_id=old_document_id,
|
||||
retain_async=self._retain_async,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Hindsight flush-on-switch failed: %s", e, exc_info=True)
|
||||
|
||||
# Match sync_turn's serialization — wait for any prior retain
|
||||
# thread to finish before spawning the flush, so writes
|
||||
# against the old document arrive in order.
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
self._sync_thread = threading.Thread(
|
||||
target=_flush, daemon=True, name="hindsight-flush-on-switch"
|
||||
)
|
||||
self._sync_thread.start()
|
||||
|
||||
# 2. Drain any in-flight prefetch from the old session and drop
|
||||
# its cached result so the new session doesn't see stale recall.
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = ""
|
||||
|
||||
# 3. Now rotate to the new session.
|
||||
if parent_session_id:
|
||||
self._parent_session_id = str(parent_session_id).strip()
|
||||
self._session_id = new_id
|
||||
|
||||
@@ -205,6 +205,7 @@ def _make_hindsight_provider():
|
||||
bypassing __init__ and seeding the attributes on_session_switch
|
||||
reads/writes. This keeps the test hermetic.
|
||||
"""
|
||||
import threading
|
||||
hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
|
||||
provider = object.__new__(hindsight_mod.HindsightMemoryProvider)
|
||||
provider._session_id = "old-sid"
|
||||
@@ -213,6 +214,33 @@ def _make_hindsight_provider():
|
||||
provider._session_turns = ["turn-1", "turn-2"]
|
||||
provider._turn_counter = 2
|
||||
provider._turn_index = 2
|
||||
# Attrs read by _build_metadata / _build_retain_kwargs when the
|
||||
# buffer-flush path on session switch fires. Empty strings keep the
|
||||
# metadata minimal but well-formed.
|
||||
provider._retain_source = ""
|
||||
provider._platform = ""
|
||||
provider._user_id = ""
|
||||
provider._user_name = ""
|
||||
provider._chat_id = ""
|
||||
provider._chat_name = ""
|
||||
provider._chat_type = ""
|
||||
provider._thread_id = ""
|
||||
provider._agent_identity = ""
|
||||
provider._agent_workspace = ""
|
||||
provider._retain_tags = []
|
||||
provider._retain_context = "test-context"
|
||||
provider._retain_async = False
|
||||
provider._bank_id = "test-bank"
|
||||
# Prefetch state the switch path drains/clears.
|
||||
provider._prefetch_thread = None
|
||||
provider._prefetch_lock = threading.Lock()
|
||||
provider._prefetch_result = ""
|
||||
# Sync thread tracking — flush spawn target.
|
||||
provider._sync_thread = None
|
||||
# Stub the network-touching helper so the spawned flush thread is a
|
||||
# no-op in unit tests. Real plugin behavior is covered by the
|
||||
# mock-client tests in tests/plugins/memory/test_hindsight_provider.py.
|
||||
provider._run_hindsight_operation = lambda _op: None
|
||||
return provider
|
||||
|
||||
|
||||
|
||||
@@ -927,6 +927,95 @@ class TestShutdownRace:
|
||||
assert provider._shutting_down.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_session_switch — flush + prefetch reset behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionSwitchBufferFlush:
|
||||
def test_buffered_turns_flushed_before_clear(self, provider_with_config):
|
||||
"""retain_every_n_turns > 1 must not silently drop partial buffers
|
||||
on session switch. Whatever's in _session_turns at switch time
|
||||
should land in the OLD document under the OLD session id."""
|
||||
p = provider_with_config(retain_every_n_turns=3, retain_async=False)
|
||||
old_doc = p._document_id
|
||||
|
||||
# Two turns buffered, no retain yet (boundary is at turn 3).
|
||||
p.sync_turn("turn1-user", "turn1-asst")
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
assert p._sync_thread is None
|
||||
p._client.aretain_batch.assert_not_called()
|
||||
|
||||
# Switch — flush should fire under OLD document_id.
|
||||
p.on_session_switch("new-sid", parent_session_id="test-session", reset=True)
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
kw = p._client.aretain_batch.call_args.kwargs
|
||||
assert kw["document_id"] == old_doc
|
||||
item = kw["items"][0]
|
||||
# Both buffered turns must be present in the flushed payload.
|
||||
content = json.loads(item["content"])
|
||||
flat = json.dumps(content)
|
||||
assert "turn1-user" in flat
|
||||
assert "turn2-user" in flat
|
||||
# Old session id must appear in lineage tags / metadata.
|
||||
assert "session:test-session" in item["tags"]
|
||||
assert item["metadata"]["session_id"] == "test-session"
|
||||
|
||||
# And the new session must start with a clean slate.
|
||||
assert p._session_id == "new-sid"
|
||||
assert p._session_turns == []
|
||||
assert p._turn_counter == 0
|
||||
assert p._document_id != old_doc
|
||||
assert p._document_id.startswith("new-sid-")
|
||||
|
||||
def test_no_flush_when_buffer_empty(self, provider):
|
||||
"""Switch with no buffered turns must not fire a spurious retain."""
|
||||
provider.on_session_switch("new-sid")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._client.aretain_batch.assert_not_called()
|
||||
assert provider._session_id == "new-sid"
|
||||
|
||||
def test_prefetch_result_cleared_on_switch(self, provider):
|
||||
"""Stale recall text from the old session must not leak into the
|
||||
next session's first prefetch read."""
|
||||
provider._prefetch_result = "old-session recall: User likes Rust"
|
||||
provider.on_session_switch("new-sid")
|
||||
assert provider._prefetch_result == ""
|
||||
# And subsequent prefetch() should now report empty, not the leftover.
|
||||
assert provider.prefetch("anything") == ""
|
||||
|
||||
def test_in_flight_prefetch_thread_drained_on_switch(self, provider, monkeypatch):
|
||||
"""on_session_switch must wait for an in-flight prefetch from the
|
||||
old session to settle before clearing _prefetch_result, otherwise
|
||||
the thread can race and re-populate the field after the clear."""
|
||||
import threading
|
||||
import time as _time
|
||||
|
||||
gate = threading.Event()
|
||||
finished = threading.Event()
|
||||
|
||||
def _slow_prefetch():
|
||||
gate.wait(timeout=5.0)
|
||||
with provider._prefetch_lock:
|
||||
provider._prefetch_result = "old-session recall"
|
||||
finished.set()
|
||||
|
||||
provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True)
|
||||
provider._prefetch_thread.start()
|
||||
|
||||
# Release the prefetch worker so it writes _prefetch_result, then
|
||||
# call on_session_switch — it must join the thread before clearing.
|
||||
gate.set()
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
assert finished.is_set(), "switch returned before prefetch thread settled"
|
||||
assert provider._prefetch_result == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user