fix(hindsight): flush buffered turns and drop stale prefetch on session switch

Two data-loss / leak gaps in HindsightMemoryProvider.on_session_switch
introduced by #17409.

1. Buffered turns silently lost when retain_every_n_turns > 1.
   on_session_switch unconditionally cleared _session_turns without
   flushing. Users who batched every N>1 turns and switched mid-batch
   (/reset, /new, /resume, /branch, or context compression) had those
   buffered turns disappear. Same data-loss class as the shutdown race,
   different lifecycle event.

   Note commit_memory_session() -> on_session_end() runs *before*
   on_session_switch on /reset, but Hindsight doesn't implement
   on_session_end so the buffer survives that step and dies at clear
   time. /resume, /branch, and compression skip commit_memory_session
   entirely so an on_session_end impl wouldn't help them anyway.

   Fix: snapshot the old _session_id, _document_id, _parent_session_id,
   _turn_index, and _session_turns; spawn one final retain that lands
   under the OLD document_id; then rotate state. Metadata is built
   synchronously against the old self._* so session_id / lineage tags
   on the flushed item all reference the prior session consistently.

2. Stale _prefetch_result leaks across switch.
   If queue_prefetch ran in the old session and the result hadn't been
   consumed by prefetch() yet, on_session_switch left the cached recall
   text in place. The next session's first prefetch() call would return
   text mined from the prior session's bank/query.

   Fix: join any in-flight _prefetch_thread (3s bounded — matches
   shutdown()), then clear _prefetch_result under _prefetch_lock before
   rotating session_id.

Tests
-----
- tests/plugins/memory/test_hindsight_provider.py (TestSessionSwitchBufferFlush):
    - buffered turns flushed under OLD document_id with OLD lineage tags
    - empty buffer => no spurious retain
    - _prefetch_result cleared on switch
    - in-flight prefetch thread is awaited before clear (no race)
- tests/agent/test_memory_session_switch.py: factory extended to seed the
  attrs the new flush path reads (_retain_source, _platform, _bank_id,
  prefetch state, etc.) and stub _run_hindsight_operation so existing
  switch-state assertions keep passing without network setup.
This commit is contained in:
Nicolò Boschi
2026-04-29 14:58:34 +02:00
committed by Teknium
parent 1bedc836b5
commit c38dac742b
3 changed files with 191 additions and 0 deletions

View File

@@ -927,6 +927,95 @@ class TestShutdownRace:
assert provider._shutting_down.is_set()
# ---------------------------------------------------------------------------
# on_session_switch — flush + prefetch reset behavior
# ---------------------------------------------------------------------------
class TestSessionSwitchBufferFlush:
def test_buffered_turns_flushed_before_clear(self, provider_with_config):
"""retain_every_n_turns > 1 must not silently drop partial buffers
on session switch. Whatever's in _session_turns at switch time
should land in the OLD document under the OLD session id."""
p = provider_with_config(retain_every_n_turns=3, retain_async=False)
old_doc = p._document_id
# Two turns buffered, no retain yet (boundary is at turn 3).
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None
p._client.aretain_batch.assert_not_called()
# Switch — flush should fire under OLD document_id.
p.on_session_switch("new-sid", parent_session_id="test-session", reset=True)
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.assert_called_once()
kw = p._client.aretain_batch.call_args.kwargs
assert kw["document_id"] == old_doc
item = kw["items"][0]
# Both buffered turns must be present in the flushed payload.
content = json.loads(item["content"])
flat = json.dumps(content)
assert "turn1-user" in flat
assert "turn2-user" in flat
# Old session id must appear in lineage tags / metadata.
assert "session:test-session" in item["tags"]
assert item["metadata"]["session_id"] == "test-session"
# And the new session must start with a clean slate.
assert p._session_id == "new-sid"
assert p._session_turns == []
assert p._turn_counter == 0
assert p._document_id != old_doc
assert p._document_id.startswith("new-sid-")
def test_no_flush_when_buffer_empty(self, provider):
"""Switch with no buffered turns must not fire a spurious retain."""
provider.on_session_switch("new-sid")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_not_called()
assert provider._session_id == "new-sid"
def test_prefetch_result_cleared_on_switch(self, provider):
"""Stale recall text from the old session must not leak into the
next session's first prefetch read."""
provider._prefetch_result = "old-session recall: User likes Rust"
provider.on_session_switch("new-sid")
assert provider._prefetch_result == ""
# And subsequent prefetch() should now report empty, not the leftover.
assert provider.prefetch("anything") == ""
def test_in_flight_prefetch_thread_drained_on_switch(self, provider, monkeypatch):
"""on_session_switch must wait for an in-flight prefetch from the
old session to settle before clearing _prefetch_result, otherwise
the thread can race and re-populate the field after the clear."""
import threading
import time as _time
gate = threading.Event()
finished = threading.Event()
def _slow_prefetch():
gate.wait(timeout=5.0)
with provider._prefetch_lock:
provider._prefetch_result = "old-session recall"
finished.set()
provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True)
provider._prefetch_thread.start()
# Release the prefetch worker so it writes _prefetch_result, then
# call on_session_switch — it must join the thread before clearing.
gate.set()
provider.on_session_switch("new-sid")
assert finished.is_set(), "switch returned before prefetch thread settled"
assert provider._prefetch_result == ""
# ---------------------------------------------------------------------------
# System prompt tests
# ---------------------------------------------------------------------------