fix(hindsight): flush buffered turns and drop stale prefetch on session switch

Two data-loss / leak gaps in HindsightMemoryProvider.on_session_switch
introduced by #17409.

1. Buffered turns silently lost when retain_every_n_turns > 1.
   on_session_switch unconditionally cleared _session_turns without
   flushing. Users who batched every N>1 turns and switched mid-batch
   (/reset, /new, /resume, /branch, or context compression) had those
   buffered turns disappear. Same data-loss class as the shutdown race,
   different lifecycle event.

   Note: on /reset, commit_memory_session() -> on_session_end() runs
   *before* on_session_switch, but Hindsight doesn't implement
   on_session_end, so the buffer survives that step and is dropped when
   on_session_switch clears it. /resume, /branch, and compression skip
   commit_memory_session entirely, so an on_session_end impl wouldn't
   help them anyway.

   Fix: snapshot the old _session_id, _document_id, _parent_session_id,
   _turn_index, and _session_turns; spawn one final retain that lands
   under the OLD document_id; then rotate state. Metadata is built
   synchronously against the old self._* so session_id / lineage tags
   on the flushed item all reference the prior session consistently.

2. Stale _prefetch_result leaks across switch.
   If queue_prefetch ran in the old session and the result hadn't been
   consumed by prefetch() yet, on_session_switch left the cached recall
   text in place. The next session's first prefetch() call would return
   text mined from the prior session's bank/query.

   Fix: join any in-flight _prefetch_thread (3s bounded — matches
   shutdown()), then clear _prefetch_result under _prefetch_lock before
   rotating session_id.
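
A minimal sketch of the ordering both fixes rely on (snapshot, flush, drain,
then rotate). ToyProvider, its retain_fn callback, and the "doc-<session>"
naming are hypothetical stand-ins for illustration, not the real
HindsightMemoryProvider API:

```python
import threading


class ToyProvider:
    """Hypothetical stand-in for a memory provider (illustration only)."""

    def __init__(self, session_id, retain_fn):
        self._session_id = session_id
        self._document_id = f"doc-{session_id}"  # assumed naming scheme
        self._session_turns = []                 # turns buffered for the next batch
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread = None
        self._sync_thread = None
        self._retain = retain_fn                 # hypothetical callable(document_id, turns)

    def on_session_switch(self, new_session_id):
        # 1. Snapshot old identifiers before mutating any state, so the
        #    flush thread never observes the rotated values.
        if self._session_turns:
            old_turns = list(self._session_turns)
            old_document_id = self._document_id

            def _flush():
                self._retain(old_document_id, old_turns)

            # Serialize with any earlier write so ordering is preserved.
            if self._sync_thread and self._sync_thread.is_alive():
                self._sync_thread.join(timeout=5.0)
            self._sync_thread = threading.Thread(target=_flush, daemon=True)
            self._sync_thread.start()

        # 2. Bounded wait for an in-flight prefetch, then drop its result.
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
        with self._prefetch_lock:
            self._prefetch_result = ""

        # 3. Only now rotate to the new session.
        self._session_id = new_session_id
        self._document_id = f"doc-{new_session_id}"
        self._session_turns = []


def demo():
    retained = []
    provider = ToyProvider("old", lambda doc, turns: retained.append((doc, turns)))
    provider._session_turns.extend(["t1", "t2"])
    provider._prefetch_result = "stale"
    provider.on_session_switch("new")
    provider._sync_thread.join(timeout=5.0)
    return retained, provider._prefetch_result, provider._document_id, provider._session_turns
```

Because the closure captures old_turns and old_document_id rather than
reading self._* at run time, the flush lands under the old document even
if the thread runs after rotation.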

Tests
-----
- tests/plugins/memory/test_hindsight_provider.py (TestSessionSwitchBufferFlush):
    - buffered turns flushed under OLD document_id with OLD lineage tags
    - empty buffer => no spurious retain
    - _prefetch_result cleared on switch
    - in-flight prefetch thread is awaited before clear (no race)
- tests/agent/test_memory_session_switch.py: factory extended to seed the
  attrs the new flush path reads (_retain_source, _platform, _bank_id,
  prefetch state, etc.) and stub _run_hindsight_operation so existing
  switch-state assertions keep passing without network setup.
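
The stubbing idea can be sketched roughly as follows. This is an
illustration under assumptions, not the actual test code: ToyProvider, its
synchronous flush, and client.retain are hypothetical; only the
_run_hindsight_operation name comes from the commit.

```python
from unittest.mock import MagicMock


class ToyProvider:
    """Hypothetical provider that flushes synchronously to keep the sketch small."""

    def __init__(self):
        self._session_turns = []
        self._document_id = "doc-old"

    def _run_hindsight_operation(self, op):
        # Stand-in for the real network call path.
        raise RuntimeError("would hit the network")

    def on_session_switch(self, new_document_id):
        if self._session_turns:
            old_turns, old_doc = list(self._session_turns), self._document_id
            # client.retain is a hypothetical method; the stub never invokes it.
            self._run_hindsight_operation(lambda client: client.retain(old_doc, old_turns))
        self._session_turns = []
        self._document_id = new_document_id


def switch_with_stub(turns):
    provider = ToyProvider()
    # Replace the network-bound method with a recording stub.
    provider._run_hindsight_operation = MagicMock()
    provider._session_turns.extend(turns)
    provider.on_session_switch("doc-new")
    return provider._run_hindsight_operation.call_count
```

With the stub in place, an empty buffer yields zero calls and a non-empty
buffer yields exactly one, which mirrors the "no spurious retain" case.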
This commit is contained in:
Nicolò Boschi
2026-04-29 14:58:34 +02:00
committed by Teknium
parent 1bedc836b5
commit c38dac742b
3 changed files with 191 additions and 0 deletions

@@ -1447,6 +1447,16 @@ class HindsightMemoryProvider(MemoryProvider):
        batching must start from zero so an in-flight retain doesn't flush
        under the wrong ``_document_id``.

        Before clearing, flush any buffered turns under the *old*
        ``_document_id``. Users who set ``retain_every_n_turns > 1`` would
        otherwise silently lose whatever's in ``_session_turns`` at the
        moment of switch — the same data-loss class as the shutdown race,
        just at a different lifecycle event.

        Also wait for any in-flight prefetch from the old session and drop
        its cached result; otherwise the new session's first ``prefetch()``
        could read stale recall text from before the switch.

        ``parent_session_id`` is recorded for lineage tags on future retains.
        ``reset`` is accepted but not needed for Hindsight's state model —
        buffer clearing is correct for every session switch, not only /reset.
@@ -1454,6 +1464,70 @@ class HindsightMemoryProvider(MemoryProvider):
        new_id = str(new_session_id or "").strip()
        if not new_id:
            return

        # 1. Flush any buffered turns under the OLD identifiers. Snapshot
        # everything before mutating self._* so metadata + tags + doc_id
        # all reference the old session consistently.
        if self._session_turns:
            old_turns = list(self._session_turns)
            old_session_id = self._session_id
            old_document_id = self._document_id
            old_parent_session_id = self._parent_session_id
            old_turn_index = self._turn_index
            old_metadata = self._build_metadata(
                message_count=len(old_turns) * 2,
                turn_index=old_turn_index,
            )
            old_lineage_tags: list[str] = []
            if old_session_id:
                old_lineage_tags.append(f"session:{old_session_id}")
            if old_parent_session_id:
                old_lineage_tags.append(f"parent:{old_parent_session_id}")
            old_content = "[" + ",".join(old_turns) + "]"

            def _flush():
                try:
                    item = self._build_retain_kwargs(
                        old_content,
                        context=self._retain_context,
                        metadata=old_metadata,
                        tags=old_lineage_tags or None,
                    )
                    item.pop("bank_id", None)
                    item.pop("retain_async", None)
                    logger.debug(
                        "Hindsight flush-on-switch: bank=%s, doc=%s, num_turns=%d",
                        self._bank_id, old_document_id, len(old_turns),
                    )
                    self._run_hindsight_operation(
                        lambda client: client.aretain_batch(
                            bank_id=self._bank_id,
                            items=[item],
                            document_id=old_document_id,
                            retain_async=self._retain_async,
                        )
                    )
                except Exception as e:
                    logger.warning("Hindsight flush-on-switch failed: %s", e, exc_info=True)

            # Match sync_turn's serialization — wait for any prior retain
            # thread to finish before spawning the flush, so writes
            # against the old document arrive in order.
            if self._sync_thread and self._sync_thread.is_alive():
                self._sync_thread.join(timeout=5.0)
            self._sync_thread = threading.Thread(
                target=_flush, daemon=True, name="hindsight-flush-on-switch"
            )
            self._sync_thread.start()

        # 2. Drain any in-flight prefetch from the old session and drop
        # its cached result so the new session doesn't see stale recall.
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
        with self._prefetch_lock:
            self._prefetch_result = ""

        # 3. Now rotate to the new session.
        if parent_session_id:
            self._parent_session_id = str(parent_session_id).strip()
        self._session_id = new_id