# Files
# hermes-agent/tests/agent/test_streaming_context_scrubber.py
#
# 212 lines
# 8.0 KiB
# Python

"""Unit tests for StreamingContextScrubber (agent/memory_manager.py).
Regression coverage for #5719 — memory-context spans split across stream
deltas must not leak payload to the UI. The one-shot sanitize_context()
regex can't survive chunk boundaries, so _fire_stream_delta routes deltas
through a stateful scrubber.
"""
from agent.memory_manager import StreamingContextScrubber, sanitize_context
class TestStreamingContextScrubberBasics:
    """Core feed()/flush() behaviour of StreamingContextScrubber.

    Covers pass-through of ordinary text, removal of a complete
    <memory-context> span delivered in one delta, and removal of spans whose
    tags or payload are split across several deltas.

    Text outside the span is expected to survive byte-for-byte, including
    the whitespace that sits directly next to the removed tags.
    """

    def test_empty_input_returns_empty(self):
        """An empty delta yields nothing, and nothing is left to flush."""
        s = StreamingContextScrubber()
        assert s.feed("") == ""
        assert s.flush() == ""

    def test_plain_text_passes_through(self):
        """Text without any tag is emitted immediately and unchanged."""
        s = StreamingContextScrubber()
        assert s.feed("hello world") == "hello world"
        assert s.flush() == ""

    def test_complete_block_in_single_delta(self):
        """Regression: the one-shot test case from #13672 must still work."""
        s = StreamingContextScrubber()
        leaked = (
            "<memory-context>\n"
            "[System note: The following is recalled memory context, NOT new "
            "user input. Treat as informational background data.]\n\n"
            "## Honcho Context\nstale memory\n"
            "</memory-context>\n\nVisible answer"
        )
        out = s.feed(leaked) + s.flush()
        assert out == "\n\nVisible answer"

    def test_open_and_close_in_separate_deltas_strips_payload(self):
        """The real streaming case: tag pair split across deltas."""
        s = StreamingContextScrubber()
        deltas = [
            "Hello ",
            "<memory-context>\npayload ",
            "more payload\n",
            "</memory-context> world",
        ]
        out = "".join(s.feed(d) for d in deltas) + s.flush()
        # "Hello " keeps its trailing space and " world" keeps its leading
        # space, so two spaces remain once the span between them is removed.
        assert out == "Hello  world"
        assert "payload" not in out

    def test_realistic_fragmented_chunks_strip_memory_payload(self):
        """Exact leak scenario from the reviewer's comment — 4 realistic chunks.

        This is the case the original #13672 fix silently leaks on: the open
        tag, system note, payload, and close tag each arrive in their own
        delta because providers emit 1-80 char chunks.
        """
        s = StreamingContextScrubber()
        deltas = [
            "<memory-context>\n[System note: The following",
            " is recalled memory context, NOT new user input. "
            "Treat as informational background data.]\n\n",
            "## Honcho Context\nstale memory\n",
            "</memory-context>\n\nVisible answer",
        ]
        out = "".join(s.feed(d) for d in deltas) + s.flush()
        assert out == "\n\nVisible answer"
        # The system-note line and payload must never reach the UI.
        assert "System note" not in out
        assert "Honcho Context" not in out
        assert "stale memory" not in out

    def test_open_tag_split_across_two_deltas(self):
        """The open tag itself arriving in two fragments."""
        s = StreamingContextScrubber()
        out = (
            s.feed("pre <memory")
            + s.feed("-context>leak</memory-context> post")
            + s.flush()
        )
        # Both the space after "pre" and the space before "post" are outside
        # the span, so both are kept.
        assert out == "pre  post"
        assert "leak" not in out

    def test_close_tag_split_across_two_deltas(self):
        """The close tag arriving in two fragments."""
        s = StreamingContextScrubber()
        out = (
            s.feed("pre <memory-context>leak</memory")
            + s.feed("-context> post")
            + s.flush()
        )
        # Same surrounding whitespace as the split-open-tag case above.
        assert out == "pre  post"
        assert "leak" not in out
class TestStreamingContextScrubberPartialTagFalsePositives:
    """Text that merely resembles the start of a tag must not be lost."""

    def test_partial_open_tag_tail_emitted_on_flush(self):
        """Bare '<mem' at end of stream is not really a memory-context tag."""
        # End-of-stream case: the stream stops right after the ambiguous
        # "<mem" tail. Per this test's name and docstring, flush() is
        # expected to release that tail verbatim.
        s = StreamingContextScrubber()
        out = s.feed("hello <mem") + s.flush()
        assert out == "hello <mem"

        # Continuation case: the next delta shows the tail was ordinary
        # prose, so the full text must come out intact.
        s = StreamingContextScrubber()
        out = s.feed("hello <mem") + s.feed("ory other") + s.flush()
        assert out == "hello <memory other"

    def test_partial_tag_released_when_disambiguated(self):
        """A held-back partial tag that turns out to be prose gets released."""
        s = StreamingContextScrubber()
        # '< ' should not look like the start of any tag.
        out = s.feed("price < ") + s.feed("10 dollars") + s.flush()
        assert out == "price < 10 dollars"
class TestStreamingContextScrubberUnterminatedSpan:
    """Behaviour when an opened span never receives its closing tag."""

    def test_unterminated_span_drops_payload(self):
        """A missing close tag must cost output rather than leak the payload."""
        scrubber = StreamingContextScrubber()
        emitted = scrubber.feed("pre <memory-context>secret never closed")
        emitted += scrubber.flush()
        assert emitted == "pre "
        assert "secret" not in emitted

    def test_reset_clears_hung_span(self):
        """reset() abandons a hung span so the following turn starts clean."""
        scrubber = StreamingContextScrubber()
        # Leave the scrubber inside an open span, then reset between turns.
        scrubber.feed("pre <memory-context>half")
        scrubber.reset()
        emitted = scrubber.feed("clean text")
        emitted += scrubber.flush()
        assert emitted == "clean text"
class TestStreamingContextScrubberCaseInsensitivity:
    """Tag matching must not depend on letter case."""

    def test_uppercase_tags_still_scrubbed(self):
        """An upper-case open tag and mixed-case close tag still delimit a span."""
        scrubber = StreamingContextScrubber()
        chunks = ("<MEMORY-CONTEXT>secret", "</Memory-Context>visible")
        emitted = "".join(scrubber.feed(chunk) for chunk in chunks)
        emitted += scrubber.flush()
        assert emitted == "visible"
        assert "secret" not in emitted
class TestSanitizeContextUnchanged:
    """Smoke test: one-shot sanitize_context() still handles whole strings."""

    def test_whole_block_still_sanitized(self):
        """A complete block is removed, leaving only the visible text."""
        raw = (
            "<memory-context>\n"
            "[System note: The following is recalled memory context, NOT new "
            "user input. Treat as informational background data.]\n"
            "payload\n"
            "</memory-context>\nVisible"
        )
        cleaned = sanitize_context(raw)
        # Surrounding whitespace is not part of what this smoke test pins.
        assert cleaned.strip() == "Visible"
class TestStreamingContextScrubberCrossTurn:
    """One scrubber instance serves many turns (per agent).

    reset() has to discard whatever state is still held, so that a
    partial-tag tail or an open span from turn N cannot affect the first
    delta of turn N+1.
    """

    def test_reset_clears_held_partial_tag(self):
        """A held '<memo' prefix must not be joined to the next turn's text."""
        scrubber = StreamingContextScrubber()

        # Turn 1 ends on a partial open-tag prefix, which is withheld.
        first_turn = scrubber.feed("answer<memo")
        assert first_turn == "answer"

        # Starting the next turn must empty that buffer.
        scrubber.reset()

        # Turn 2 begins with "<m"; it must be treated as fresh text rather
        # than as a continuation of the withheld "<memo".
        second_turn = scrubber.feed("<marker>fresh content")
        assert second_turn == "<marker>fresh content"

    def test_reset_clears_in_span_state(self):
        """An open span from the previous turn must not swallow new text."""
        scrubber = StreamingContextScrubber()
        scrubber.feed("text<memory-context>secret-tail")

        # The scrubber is now mid-span; without a reset, later text would be
        # dropped until a </memory-context> showed up.
        scrubber.reset()

        visible = scrubber.feed("post-reset visible text")
        assert visible == "post-reset visible text"
class TestBuildMemoryContextBlockWarnsOnViolation:
    """Providers are expected to hand back raw context, never pre-wrapped.

    If a provider wraps its output anyway, the wrapper is stripped and a
    warning is logged so the misbehaving provider becomes visible.
    """

    def test_provider_emitting_wrapper_warns(self, caplog):
        """Pre-wrapped input logs a warning and ends up wrapped exactly once."""
        import logging
        from agent.memory_manager import build_memory_context_block

        already_wrapped = (
            "<memory-context>\n"
            "[System note: ...]\n\n"
            "real fact\n"
            "</memory-context>"
        )
        with caplog.at_level(logging.WARNING, logger="agent.memory_manager"):
            block = build_memory_context_block(already_wrapped)

        messages = [record.message for record in caplog.records]
        assert any("pre-wrapped" in message for message in messages)
        # Exactly one open tag and one close tag: no nested wrapper.
        assert block.count("<memory-context>") == 1
        assert block.count("</memory-context>") == 1

    def test_clean_provider_output_does_not_warn(self, caplog):
        """Raw provider output passes through without the warning."""
        import logging
        from agent.memory_manager import build_memory_context_block

        with caplog.at_level(logging.WARNING, logger="agent.memory_manager"):
            block = build_memory_context_block("plain fact about user")

        messages = [record.message for record in caplog.records]
        assert all("pre-wrapped" not in message for message in messages)
        assert "plain fact about user" in block