Compare commits

...

1 Commits

Author SHA1 Message Date
teknium1
885c5dc5e6 feat: context window usage warnings at 80% and 95%
Adds one-time warnings when context usage crosses critical thresholds:
- 80%: suggests /compress or /new if responses degrade
- 95%: warns of imminent errors/truncation, suggests /new

Each threshold fires at most once per session to avoid spam.
Warnings show actual token counts and percentage. Suppressed for
subagents (_delegate_depth > 0) where the user can't act on them.
Always shown in CLI mode regardless of quiet_mode setting.

Inspired by OpenCode PR #152 (context window warning).

Bug fix found during live testing:
- Anthropic prompt caching reports input tokens across three fields
  (input_tokens, cache_read_input_tokens, cache_creation_input_tokens).
  The existing code only counted input_tokens, causing the context
  compressor to see ~0 tokens when caching was active. Fixed by summing
  all three fields. This also fixes context % display in the status bar
  for Anthropic users.

Changes:
- agent/context_compressor.py: add check_context_warning() with
  _warned_80/_warned_95 state tracking
- run_agent.py: call check_context_warning() after each API response,
  fix Anthropic cached token counting
- tests/test_context_warning.py: 8 tests covering thresholds,
  one-shot behavior, escalation, edge cases

Live tested with:
- Nous Portal (chat_completions mode) ✔
- Anthropic direct (anthropic_messages mode) ✔
- Interactive CLI session ✔
2026-03-16 06:30:04 -07:00
3 changed files with 112 additions and 0 deletions

View File

@@ -65,12 +65,41 @@ class ContextCompressor:
self.summary_model = summary_model_override or ""
# Context usage warning thresholds (fire once each per session)
self._warned_80 = False
self._warned_95 = False
def update_from_response(self, usage: Dict[str, Any]):
    """Record the token counts reported by the most recent API response.

    Missing fields default to 0 so a partial usage dict never raises.
    """
    prompt = usage.get("prompt_tokens", 0)
    completion = usage.get("completion_tokens", 0)
    total = usage.get("total_tokens", 0)
    self.last_prompt_tokens = prompt
    self.last_completion_tokens = completion
    self.last_total_tokens = total
def check_context_warning(self) -> str | None:
"""Return a warning string if context usage crossed a threshold, else None.
Each threshold fires at most once per session to avoid spam.
"""
if not self.context_length or not self.last_prompt_tokens:
return None
pct = self.last_prompt_tokens / self.context_length
if pct >= 0.95 and not self._warned_95:
self._warned_95 = True
used = f"{self.last_prompt_tokens:,}"
total = f"{self.context_length:,}"
return (
f"⚠ Context nearly exhausted ({used}/{total} tokens, {pct:.0%}). "
f"Risk of errors or truncation. Use /new to start fresh."
)
if pct >= 0.80 and not self._warned_80:
self._warned_80 = True
used = f"{self.last_prompt_tokens:,}"
total = f"{self.context_length:,}"
return (
f"⚠ Context window {pct:.0%} full ({used}/{total} tokens). "
f"Consider /compress or /new if responses degrade."
)
def should_compress(self, prompt_tokens: int = None) -> bool:
"""Check if context exceeds the compression threshold."""
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens

View File

@@ -5102,6 +5102,10 @@ class AIAgent:
if hasattr(response, 'usage') and response.usage:
if self.api_mode in ("codex_responses", "anthropic_messages"):
prompt_tokens = getattr(response.usage, 'input_tokens', 0) or 0
# Include cached input tokens for accurate context tracking
# (Anthropic reports non-cached, cache-read, and cache-creation separately)
prompt_tokens += getattr(response.usage, 'cache_read_input_tokens', 0) or 0
prompt_tokens += getattr(response.usage, 'cache_creation_input_tokens', 0) or 0
completion_tokens = getattr(response.usage, 'output_tokens', 0) or 0
total_tokens = (
getattr(response.usage, 'total_tokens', None)
@@ -5118,6 +5122,15 @@ class AIAgent:
}
self.context_compressor.update_from_response(usage_dict)
# Emit one-time warnings when context crosses 80% or 95%.
# Always show these (even in quiet_mode) — they're critical
# user-facing alerts, not debug noise. Only suppress for
# subagents (delegate_depth > 0) where the user can't act.
if getattr(self, '_delegate_depth', 0) == 0:
_ctx_warning = self.context_compressor.check_context_warning()
if _ctx_warning:
print(f"\n{_ctx_warning}\n")
# Cache discovered context length after successful call
if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length

View File

@@ -0,0 +1,70 @@
"""Tests for context window usage warnings."""
from agent.context_compressor import ContextCompressor
class TestContextWarning:
    """Exercise the one-shot 80%/95% context-usage warnings."""

    def _make_compressor(self, context_length=200_000):
        # Pin a 50% compression threshold so these tests don't depend
        # on whatever default the constructor ships with.
        compressor = ContextCompressor(model="test/model", threshold_percent=0.50)
        compressor.context_length = context_length
        compressor.threshold_tokens = int(context_length * 0.50)
        return compressor

    def test_no_warning_below_80_percent(self):
        compressor = self._make_compressor()
        compressor.update_from_response({"prompt_tokens": 100_000})  # 50% used
        assert compressor.check_context_warning() is None

    def test_warning_at_80_percent(self):
        compressor = self._make_compressor()
        compressor.update_from_response({"prompt_tokens": 160_000})  # exactly 80%
        message = compressor.check_context_warning()
        assert message is not None
        assert "80%" in message
        assert "/compress" in message

    def test_warning_at_95_percent(self):
        compressor = self._make_compressor()
        compressor.update_from_response({"prompt_tokens": 190_000})  # exactly 95%
        message = compressor.check_context_warning()
        assert message is not None
        assert "95%" in message
        assert "/new" in message

    def test_warning_fires_only_once_per_threshold(self):
        compressor = self._make_compressor()
        compressor.update_from_response({"prompt_tokens": 170_000})  # 85%
        first = compressor.check_context_warning()
        assert first is not None  # crosses 80% for the first time
        compressor.update_from_response({"prompt_tokens": 175_000})  # still >80%
        second = compressor.check_context_warning()
        assert second is None  # threshold already consumed

    def test_95_fires_after_80_already_warned(self):
        compressor = self._make_compressor()
        compressor.update_from_response({"prompt_tokens": 165_000})  # 82.5%
        first = compressor.check_context_warning()
        assert first is not None
        assert "82%" in first or "Context window" in first
        compressor.update_from_response({"prompt_tokens": 195_000})  # 97.5%
        second = compressor.check_context_warning()
        assert second is not None
        assert "nearly exhausted" in second  # stronger, escalated message

    def test_no_warning_when_context_length_zero(self):
        compressor = self._make_compressor(context_length=0)
        compressor.update_from_response({"prompt_tokens": 100_000})
        assert compressor.check_context_warning() is None

    def test_no_warning_when_no_tokens(self):
        compressor = self._make_compressor()
        assert compressor.check_context_warning() is None

    def test_warning_includes_token_counts(self):
        compressor = self._make_compressor(context_length=100_000)
        compressor.update_from_response({"prompt_tokens": 85_000})
        message = compressor.check_context_warning()
        assert "85,000" in message
        assert "100,000" in message