fix(agent): tiered context pressure warnings + gateway dedup (#6411)

Combines the approaches from PR #6309 (duan78) and PR #5963 (KUSH42):

Tiered warnings (from #5963):
- Replaces boolean _context_pressure_warned with float _context_pressure_warned_at
- Fires at 85% (orange) and re-fires at 95% (red/critical)
- Adds 'compacting context...' status message before compression

Gateway dedup (from #6309):
- Class-level dict _context_pressure_last_warned survives across AIAgent
  instances (gateway creates a new instance per message)
- 5-minute cooldown per session prevents warning spam
- Higher-tier warnings bypass the cooldown (85% → 95% always fires)
- Compression reset clears the dedup entry for the session
- Stale entries evicted (older than 2x cooldown) to prevent memory leak

Does NOT inject into messages — purely user-facing via _safe_print (CLI)
and status_callback (gateway). Zero prompt cache impact.

Fixes #6309. Fixes #5963.
This commit is contained in:
Teknium
2026-04-08 21:31:44 -07:00
committed by GitHub
parent ffeaf6ffae
commit 54db7cbbe1
2 changed files with 155 additions and 10 deletions

View File

@@ -442,6 +442,13 @@ class AIAgent:
for AI models that support function calling.
"""
# ── Class-level context pressure dedup (survives across instances) ──
# The gateway creates a new AIAgent per message, so instance-level flags
# reset every time. This dict tracks {session_id: (warn_level, timestamp)}
# to suppress duplicate warnings within a cooldown window.
_context_pressure_last_warned: dict = {}
_CONTEXT_PRESSURE_COOLDOWN = 300 # seconds between re-warning same session
@property
def base_url(self) -> str:
return self._base_url
@@ -673,7 +680,8 @@ class AIAgent:
# Context pressure warnings: notify the USER (not the LLM) as context
# fills up. Purely informational — displayed in CLI output and sent via
# status_callback for gateway platforms. Does NOT inject into messages.
self._context_pressure_warned = False
# Tiered: fires at 85% and again at 95% of compaction threshold.
self._context_pressure_warned_at = 0.0 # highest tier already shown
# Activity tracking — updated on each API call, tool execution, and
# stream chunk. Used by the gateway timeout handler to report what the
@@ -6034,12 +6042,16 @@ class AIAgent:
# Only reset the pressure warning if compression actually brought
# us below the warning level (85% of threshold). When compression
# can't reduce enough (e.g. threshold is very low, or system prompt
# alone exceeds the warning level), keep the flag set to prevent
# alone exceeds the warning level), keep the tier set to prevent
# spamming the user with repeated warnings every loop iteration.
if self.context_compressor.threshold_tokens > 0:
_post_progress = _compressed_est / self.context_compressor.threshold_tokens
if _post_progress < 0.85:
self._context_pressure_warned = False
self._context_pressure_warned_at = 0.0
# Clear class-level dedup for this session so a fresh
# warning cycle can start if context grows again.
_sid = self.session_id or "default"
AIAgent._context_pressure_last_warned.pop(_sid, None)
# Clear the file-read dedup cache. After compression the original
# read content is summarised away — if the model re-reads the same
@@ -8979,13 +8991,34 @@ class AIAgent:
# compaction fires, not the raw context window.
# Does not inject into messages — just prints to CLI output
# and fires status_callback for gateway platforms.
# Tiered: 85% (orange) and 95% (red/critical).
if _compressor.threshold_tokens > 0:
_compaction_progress = _real_tokens / _compressor.threshold_tokens
if _compaction_progress >= 0.85 and not self._context_pressure_warned:
self._context_pressure_warned = True
self._emit_context_pressure(_compaction_progress, _compressor)
# Determine the warning tier for this progress level
_warn_tier = 0.0
if _compaction_progress >= 0.95:
_warn_tier = 0.95
elif _compaction_progress >= 0.85:
_warn_tier = 0.85
if _warn_tier > self._context_pressure_warned_at:
# Class-level dedup: check if this session was already
# warned at this tier within the cooldown window.
_sid = self.session_id or "default"
_last = AIAgent._context_pressure_last_warned.get(_sid)
_now = time.time()
if _last is None or _last[0] < _warn_tier or (_now - _last[1]) >= self._CONTEXT_PRESSURE_COOLDOWN:
self._context_pressure_warned_at = _warn_tier
AIAgent._context_pressure_last_warned[_sid] = (_warn_tier, _now)
self._emit_context_pressure(_compaction_progress, _compressor)
# Evict stale entries (older than 2x cooldown)
_cutoff = _now - self._CONTEXT_PRESSURE_COOLDOWN * 2
AIAgent._context_pressure_last_warned = {
k: v for k, v in AIAgent._context_pressure_last_warned.items()
if v[1] > _cutoff
}
if self.compression_enabled and _compressor.should_compress(_real_tokens):
self._safe_print(" ⟳ compacting context…")
messages, active_system_prompt = self._compress_context(
messages, system_message,
approx_tokens=self.context_compressor.last_prompt_tokens,

View File

@@ -150,8 +150,8 @@ def agent():
class TestContextPressureFlags:
"""Context pressure warning flag tracking on AIAgent."""
def test_flag_initialized_false(self, agent):
assert agent._context_pressure_warned is False
def test_flag_initialized_zero(self, agent):
assert agent._context_pressure_warned_at == 0.0
def test_emit_calls_status_callback(self, agent):
"""status_callback should be invoked with event type and message."""
@@ -210,7 +210,7 @@ class TestContextPressureFlags:
def test_flag_reset_on_compression(self, agent):
"""After _compress_context, context pressure flag should reset."""
agent._context_pressure_warned = True
agent._context_pressure_warned_at = 0.85
agent.compression_enabled = True
agent.context_compressor = MagicMock()
@@ -234,7 +234,7 @@ class TestContextPressureFlags:
]
agent._compress_context(messages, "system prompt")
assert agent._context_pressure_warned is False
assert agent._context_pressure_warned_at == 0.0
def test_emit_callback_error_handled(self, agent):
"""If status_callback raises, it should be caught gracefully."""
@@ -247,3 +247,115 @@ class TestContextPressureFlags:
# Should not raise
agent._emit_context_pressure(0.85, compressor)
def test_tiered_reemits_at_95(self, agent):
"""Warning fires at 85%, then fires again when crossing 95%."""
agent._context_pressure_warned_at = 0.85
# Simulate crossing 95%: the tier (0.95) > warned_at (0.85)
assert 0.95 > agent._context_pressure_warned_at
# After emission at 95%, the tier should update
agent._context_pressure_warned_at = 0.95
assert agent._context_pressure_warned_at == 0.95
def test_tiered_no_double_emit_at_same_level(self, agent):
"""Once warned at 85%, further 85%+ readings don't re-warn."""
agent._context_pressure_warned_at = 0.85
# At 88%, tier is 0.85, which is NOT > warned_at (0.85)
_warn_tier = 0.85 if 0.88 >= 0.85 else 0.0
assert not (_warn_tier > agent._context_pressure_warned_at)
def test_flag_not_reset_when_compression_insufficient(self, agent):
"""When compression can't drop below 85%, keep the flag set."""
agent._context_pressure_warned_at = 0.85
agent.compression_enabled = True
agent.context_compressor = MagicMock()
agent.context_compressor.compress.return_value = [
{"role": "user", "content": "Summary of conversation so far."}
]
agent.context_compressor.context_length = 200
# Use a small threshold so the tiny compressed output still
# represents >= 85% of it (prevents flag reset).
agent.context_compressor.threshold_tokens = 10
agent.context_compressor.compression_count = 1
agent.context_compressor.last_prompt_tokens = 0
agent._todo_store = MagicMock()
agent._todo_store.format_for_injection.return_value = None
agent._build_system_prompt = MagicMock(return_value="system prompt")
agent._cached_system_prompt = "old system prompt"
agent._session_db = None
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
]
agent._compress_context(messages, "system prompt")
# Post-compression is ~90% of threshold — flag should NOT reset
assert agent._context_pressure_warned_at == 0.85
class TestContextPressureGatewayDedup:
"""Class-level dedup prevents warning spam across AIAgent instances."""
def setup_method(self):
"""Clear class-level dedup state between tests."""
AIAgent._context_pressure_last_warned.clear()
def test_second_instance_within_cooldown_suppressed(self):
"""Same session, same tier, within cooldown — should be suppressed."""
import time
sid = "test_session_dedup"
# Simulate first warning
AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
# Second instance checking same tier within cooldown
_last = AIAgent._context_pressure_last_warned.get(sid)
_should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN
assert not _should_warn
def test_higher_tier_fires_despite_cooldown(self):
"""Same session, higher tier — should fire even within cooldown."""
import time
sid = "test_session_tier"
AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
_last = AIAgent._context_pressure_last_warned.get(sid)
# 0.95 > 0.85 stored tier → should warn
_should_warn = _last is None or _last[0] < 0.95 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN
assert _should_warn
def test_warning_fires_after_cooldown_expires(self):
"""Same session, same tier, after cooldown — should fire again."""
import time
sid = "test_session_expired"
# Set a timestamp far in the past
AIAgent._context_pressure_last_warned[sid] = (0.85, time.time() - AIAgent._CONTEXT_PRESSURE_COOLDOWN - 1)
_last = AIAgent._context_pressure_last_warned.get(sid)
_should_warn = _last is None or _last[0] < 0.85 or (time.time() - _last[1]) >= AIAgent._CONTEXT_PRESSURE_COOLDOWN
assert _should_warn
def test_compression_clears_dedup(self):
"""After compression drops below 85%, dedup entry should be cleared."""
import time
sid = "test_session_clear"
AIAgent._context_pressure_last_warned[sid] = (0.85, time.time())
assert sid in AIAgent._context_pressure_last_warned
# Simulate what _compress_context does on reset
AIAgent._context_pressure_last_warned.pop(sid, None)
assert sid not in AIAgent._context_pressure_last_warned
def test_eviction_removes_stale_entries(self):
"""Stale entries older than 2x cooldown should be evicted."""
import time
_now = time.time()
AIAgent._context_pressure_last_warned = {
"fresh": (0.85, _now),
"stale": (0.85, _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 3),
}
_cutoff = _now - AIAgent._CONTEXT_PRESSURE_COOLDOWN * 2
AIAgent._context_pressure_last_warned = {
k: v for k, v in AIAgent._context_pressure_last_warned.items()
if v[1] > _cutoff
}
assert "fresh" in AIAgent._context_pressure_last_warned
assert "stale" not in AIAgent._context_pressure_last_warned