fix(flush_memories): always deduct headroom + resolve flush aux model + trim defence

Three fixes for flush_memories / compression context window overflow:

1. ALWAYS deduct headroom before comparing aux_context vs threshold.
   #15631 only deducted inside 'if aux_context < threshold' — which
   never fires in the common same-model case (threshold = context × 0.50
   means aux_context > threshold always). Now headroom is computed
   unconditionally and effective_limit = aux_context - headroom is
   compared against threshold.

2. Also resolve flush_memories auxiliary model in the feasibility check.
   If the user configures separate auxiliary.flush_memories provider,
   the flush model's smaller context was unchecked.

3. Defence-in-depth trimming in flush_memories() for CLI /new and
   gateway resets that bypass preflight compression entirely.
This commit is contained in:
kshitijk4poor
2026-04-25 19:31:27 +05:30
parent d635e2df3f
commit 25d9fc8094
3 changed files with 356 additions and 52 deletions

View File

@@ -2402,6 +2402,34 @@ class AIAgent:
provider=getattr(self, "provider", ""),
)
# Also resolve the flush_memories auxiliary model — it may differ
# from the compression model when the user configures separate
# auxiliary.flush_memories.provider/model, or when the fallback
# chain lands on a different provider. flush_memories runs with
# the FULL pre-compression conversation, so its model's context
# must also be respected.
try:
flush_client, flush_model = get_text_auxiliary_client(
"flush_memories",
main_runtime=self._current_main_runtime(),
)
if flush_client and flush_model:
_flush_ctx = get_model_context_length(
flush_model,
base_url=str(getattr(flush_client, "base_url", "") or ""),
api_key=str(getattr(flush_client, "api_key", "") or ""),
provider=getattr(self, "provider", ""),
)
if _flush_ctx and _flush_ctx < aux_context:
logger.info(
"flush_memories model %s context (%d) < compression "
"model %s context (%d) — using the smaller value",
flush_model, _flush_ctx, aux_model, aux_context,
)
aux_context = _flush_ctx
except Exception:
pass # Non-fatal — fall through with compression model's context
# Hard floor: the auxiliary compression model must have at least
# MINIMUM_CONTEXT_LENGTH (64K) tokens of context. The main model
# is already required to meet this floor (checked earlier in
@@ -2421,29 +2449,25 @@ class AIAgent:
)
threshold = self.context_compressor.threshold_tokens
if aux_context < threshold:
# Auto-correct: lower the live session threshold so
# compression actually works this session. The hard floor
# above guarantees aux_context >= MINIMUM_CONTEXT_LENGTH,
# so the new threshold is always >= 64K.
#
# Headroom: the threshold budgets RAW MESSAGES only, but the
# actual request auxiliary callers send also includes the
# system prompt and every tool schema. With 50+ tools that
# overhead can be 25-30K tokens; setting new_threshold =
# aux_context directly would let messages grow right to the
# aux limit and the first compression/flush request would
# overflow with HTTP 400. Subtract a dynamic headroom
# estimate so the full request still fits.
from agent.model_metadata import estimate_request_tokens_rough
tool_overhead = estimate_request_tokens_rough([], tools=self.tools)
# System prompt is not yet built at __init__ time; allow a
# conservative 10K budget (SOUL/AGENTS.md + memory snapshot +
# skills guidance) plus 2K for the flush instruction and a
# small safety margin.
headroom = tool_overhead + 12_000
# Headroom: the threshold budgets RAW MESSAGES only, but the
# actual request auxiliary callers (compression summariser and
# flush_memories) send also includes the system prompt and every
# tool schema. We must ensure threshold + headroom <= aux_context
# or the first compression/flush request will overflow.
#
# This applies even when aux_context > threshold (the common
# same-model case after a155b4a1) — e.g. 128K context, 85%
# threshold = 108K, 20K overhead → 108K + 20K = 128K exactly
# at the limit, and any token-estimate variance causes a 400.
from agent.model_metadata import estimate_request_tokens_rough
tool_overhead = estimate_request_tokens_rough([], tools=self.tools)
headroom = tool_overhead + 12_000
effective_limit = max(aux_context - headroom, MINIMUM_CONTEXT_LENGTH)
if effective_limit < threshold:
old_threshold = threshold
new_threshold = max(aux_context - headroom, MINIMUM_CONTEXT_LENGTH)
new_threshold = effective_limit
self.context_compressor.threshold_tokens = new_threshold
# Keep threshold_percent in sync so future main-model
# context_length changes (update_model) re-derive from a
@@ -7992,6 +8016,67 @@ class AIAgent:
messages.pop() # remove flush msg
return
# ── Defence-in-depth: trim messages to fit auxiliary context ──
#
# _check_compression_model_feasibility already lowers the
# compression threshold so conversations *triggered by preflight
# compression* should fit. But flush_memories is also called
# from CLI /new and gateway session resets — paths that bypass
# the preflight check entirely. Trim here as a safety net.
try:
from agent.auxiliary_client import get_text_auxiliary_client
from agent.model_metadata import (
get_model_context_length,
estimate_messages_tokens_rough,
)
_fc, _fm = get_text_auxiliary_client(
"flush_memories",
main_runtime=self._current_main_runtime(),
)
_fctx = 0
if _fc and _fm:
_fctx = get_model_context_length(
_fm,
base_url=str(getattr(_fc, "base_url", "") or ""),
api_key=str(getattr(_fc, "api_key", "") or ""),
provider=getattr(self, "provider", ""),
)
if not _fctx:
_fctx = getattr(
getattr(self, "context_compressor", None),
"context_length", 0,
)
if _fctx:
_budget = _fctx - 5120 - 500 # output + tool schema
if _budget > 0:
_est = estimate_messages_tokens_rough(api_messages)
if _est > _budget:
_sys = []
_conv = api_messages
if api_messages and api_messages[0].get("role") == "system":
_sys = [api_messages[0]]
_conv = api_messages[1:]
_rem = _budget - estimate_messages_tokens_rough(_sys)
_kept: list = []
_acc = 0
for _m in reversed(_conv):
_mt = estimate_messages_tokens_rough([_m])
if _acc + _mt > _rem:
break
_kept.append(_m)
_acc += _mt
_kept.reverse()
if len(_kept) < 3 and len(_conv) >= 3:
_kept = _conv[-3:]
api_messages = _sys + _kept
logger.info(
"flush_memories: trimmed %d%d msgs to fit "
"%d-token aux context",
len(_sys) + len(_conv), len(api_messages), _fctx,
)
except Exception as _te:
logger.debug("flush_memories: context trim failed: %s", _te)
# Use auxiliary client for the flush call when available --
# it's cheaper and avoids Codex Responses API incompatibility.
from agent.auxiliary_client import (

View File

@@ -151,15 +151,14 @@ def test_feasibility_check_passes_live_main_runtime():
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_get_client.assert_called_once_with(
"compression",
main_runtime={
"model": "gpt-5.4",
"provider": "openai-codex",
# Called for both compression + flush_memories; verify compression call present
assert any(
c == (("compression",), {"main_runtime": {
"model": "gpt-5.4", "provider": "openai-codex",
"base_url": "https://chatgpt.com/backend-api/codex",
"api_key": "codex-token",
"api_mode": "codex_responses",
},
"api_key": "codex-token", "api_mode": "codex_responses",
}})
for c in mock_get_client.call_args_list
)
@@ -179,12 +178,12 @@ def test_feasibility_check_passes_config_context_length(mock_get_client, mock_ct
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_ctx_len.assert_called_once_with(
"custom/big-model",
base_url="http://custom-endpoint:8080/v1",
api_key="sk-custom",
config_context_length=1_000_000,
provider="openrouter",
# First call is the compression model
assert mock_ctx_len.call_args_list[0] == (
("custom/big-model",),
{"base_url": "http://custom-endpoint:8080/v1",
"api_key": "sk-custom", "config_context_length": 1_000_000,
"provider": "openrouter"},
)
@@ -202,12 +201,11 @@ def test_feasibility_check_ignores_invalid_context_length(mock_get_client, mock_
agent._emit_status = lambda msg: None
agent._check_compression_model_feasibility()
mock_ctx_len.assert_called_once_with(
"custom/model",
base_url="http://custom:8080/v1",
api_key="sk-test",
config_context_length=None,
provider="openrouter",
assert mock_ctx_len.call_args_list[0] == (
("custom/model",),
{"base_url": "http://custom:8080/v1",
"api_key": "sk-test", "config_context_length": None,
"provider": "openrouter"},
)
@@ -255,13 +253,10 @@ def test_init_feasibility_check_uses_aux_context_override_from_config():
)
assert agent._aux_compression_context_length_config == 1_000_000
mock_ctx_len.assert_called_once_with(
"custom/big-model",
base_url="http://custom-endpoint:8080/v1",
api_key="sk-custom",
config_context_length=1_000_000,
provider="",
)
c0 = mock_ctx_len.call_args_list[0]
assert c0.args == ("custom/big-model",)
assert c0.kwargs["base_url"] == "http://custom-endpoint:8080/v1"
assert c0.kwargs["config_context_length"] == 1_000_000
@patch("agent.auxiliary_client.get_text_auxiliary_client")
@@ -311,8 +306,10 @@ def test_exception_does_not_crash(mock_get_client):
@patch("agent.model_metadata.get_model_context_length", return_value=100_000)
@patch("agent.auxiliary_client.get_text_auxiliary_client")
def test_exact_threshold_boundary_no_warning(mock_get_client, mock_ctx_len):
"""No warning when aux context exactly equals the threshold."""
def test_exact_threshold_boundary_triggers_headroom_correction(mock_get_client, mock_ctx_len):
"""When aux context exactly equals the threshold, headroom deduction
still fires — flush_memories adds system prompt + tool schema on top
of the conversation messages, so threshold must be lowered."""
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
mock_client = MagicMock()
mock_client.base_url = "https://openrouter.ai/api/v1"
@@ -324,7 +321,10 @@ def test_exact_threshold_boundary_no_warning(mock_get_client, mock_ctx_len):
agent._check_compression_model_feasibility()
assert len(messages) == 0
# 100K - headroom < 100K → auto-corrects
assert len(messages) == 1
assert "Auto-lowered" in messages[0]
assert agent.context_compressor.threshold_tokens < 100_000
@patch("agent.model_metadata.get_model_context_length", return_value=99_999)

View File

@@ -0,0 +1,219 @@
"""Tests for flush_memories context-overflow prevention.
1. _check_compression_model_feasibility now also resolves the
flush_memories auxiliary model and uses min(compression, flush) as the
effective aux context.
2. Headroom is always deducted before comparing aux_context vs threshold
(not only when aux_context < threshold).
3. flush_memories() trims oversized conversations before the LLM call as
defence-in-depth for paths that bypass preflight compression.
"""
import sys
import types
from types import SimpleNamespace
from unittest.mock import patch, MagicMock
sys.modules.setdefault("fire", types.SimpleNamespace(Fire=lambda *a, **k: None))
sys.modules.setdefault("firecrawl", types.SimpleNamespace(Firecrawl=object))
sys.modules.setdefault("fal_client", types.SimpleNamespace())
import run_agent
# ── Helpers ──────────────────────────────────────────────────────────────
class _FakeOpenAI:
def __init__(self, **kw):
self.api_key = kw.get("api_key", "test")
self.base_url = kw.get("base_url", "http://test")
def close(self):
pass
def _make_agent(monkeypatch, **kw):
monkeypatch.setattr(run_agent, "get_tool_definitions", lambda **k: [
{"type": "function", "function": {
"name": "memory", "description": "m",
"parameters": {"type": "object", "properties": {
"action": {"type": "string"},
"target": {"type": "string"},
"content": {"type": "string"},
}},
}},
])
monkeypatch.setattr(run_agent, "check_toolset_requirements", lambda: {})
monkeypatch.setattr(run_agent, "OpenAI", _FakeOpenAI)
agent = run_agent.AIAgent(
api_key="test-key", base_url="https://test.example.com/v1",
provider=kw.get("provider", "openrouter"),
api_mode=kw.get("api_mode", "chat_completions"),
max_iterations=4, quiet_mode=True,
skip_context_files=True, skip_memory=True,
)
agent._memory_store = MagicMock()
agent._memory_flush_min_turns = 1
agent._user_turn_count = 5
return agent
def _make_msgs(n, chars=400):
return [{"role": "user" if i % 2 == 0 else "assistant",
"content": f"M{i}: " + "x" * max(0, chars - 6)}
for i in range(n)]
def _noop_response():
return SimpleNamespace(
choices=[SimpleNamespace(
finish_reason="stop",
message=SimpleNamespace(content="Nothing.", tool_calls=None),
)],
usage=SimpleNamespace(prompt_tokens=50, completion_tokens=10, total_tokens=60),
)
# ── Feasibility: flush model + always-deduct headroom ────────────────────
class TestFeasibilityFixes:
def test_smaller_flush_model_lowers_effective_context(self, monkeypatch):
"""flush_memories model with smaller context drives the threshold."""
agent = _make_agent(monkeypatch)
agent.context_compressor.context_length = 200_000
agent.context_compressor.threshold_tokens = 100_000
fc = SimpleNamespace(base_url="http://test", api_key="k")
def _aux(task, **kw):
if task == "compression":
return fc, "big-model"
return fc, "small-flush-model"
def _ctx(model, **kw):
return 200_000 if model == "big-model" else 80_000
with patch("agent.auxiliary_client.get_text_auxiliary_client", side_effect=_aux), \
patch("agent.model_metadata.get_model_context_length", side_effect=_ctx):
agent._check_compression_model_feasibility()
assert agent.context_compressor.threshold_tokens < 100_000
def test_same_model_overhead_still_triggers_correction(self, monkeypatch):
"""The primary bug: aux == main model, aux_context > threshold, but
threshold + overhead > aux_context. Headroom must fire even when
aux_context >= threshold."""
agent = _make_agent(monkeypatch)
agent.context_compressor.context_length = 128_000
agent.context_compressor.threshold_tokens = 120_000
fc = SimpleNamespace(base_url="http://test", api_key="k")
with patch("agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fc, "same-model")), \
patch("agent.model_metadata.get_model_context_length",
return_value=128_000):
agent._check_compression_model_feasibility()
# 128K - headroom (~12.1K) ≈ 115.9K < 120K → threshold lowered
assert agent.context_compressor.threshold_tokens < 120_000
def test_flush_resolution_failure_is_non_fatal(self, monkeypatch):
"""If flush model resolution raises, check proceeds with compression model."""
agent = _make_agent(monkeypatch)
agent.context_compressor.context_length = 200_000
agent.context_compressor.threshold_tokens = 100_000
fc = SimpleNamespace(base_url="http://test", api_key="k")
n = [0]
def _aux(task, **kw):
n[0] += 1
if task == "flush_memories":
raise RuntimeError("boom")
return fc, "model"
with patch("agent.auxiliary_client.get_text_auxiliary_client", side_effect=_aux), \
patch("agent.model_metadata.get_model_context_length", return_value=200_000):
agent._check_compression_model_feasibility()
assert n[0] == 2 # both tasks attempted
# ── flush_memories trimming ──────────────────────────────────────────────
class TestFlushMemoriesTrimming:
def test_oversized_conversation_trimmed(self, monkeypatch):
agent = _make_agent(monkeypatch)
agent._cached_system_prompt = "System."
messages = _make_msgs(200, chars=500)
fc = SimpleNamespace(base_url="http://test", api_key="k")
with patch("agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fc, "small")), \
patch("agent.model_metadata.get_model_context_length",
return_value=8_000), \
patch("agent.auxiliary_client.call_llm",
return_value=_noop_response()) as mock:
agent.flush_memories(messages)
sent = mock.call_args.kwargs.get("messages", [])
assert len(sent) < 100
def test_small_conversation_untouched(self, monkeypatch):
agent = _make_agent(monkeypatch)
agent._cached_system_prompt = "System."
messages = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hey"},
{"role": "user", "content": "Save"},
]
fc = SimpleNamespace(base_url="http://test", api_key="k")
with patch("agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fc, "big")), \
patch("agent.model_metadata.get_model_context_length",
return_value=200_000), \
patch("agent.auxiliary_client.call_llm",
return_value=_noop_response()) as mock:
agent.flush_memories(messages)
sent = mock.call_args.kwargs.get("messages", [])
assert len(sent) == 5 # sys + 3 conv + flush
def test_trim_failure_does_not_block_flush(self, monkeypatch):
agent = _make_agent(monkeypatch)
messages = _make_msgs(10, chars=100)
with patch("agent.auxiliary_client.get_text_auxiliary_client",
side_effect=RuntimeError("no provider")), \
patch("agent.auxiliary_client.call_llm",
return_value=_noop_response()) as mock:
agent.flush_memories(messages)
assert mock.called
def test_sentinel_cleaned_after_trim(self, monkeypatch):
agent = _make_agent(monkeypatch)
messages = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hey"},
{"role": "user", "content": "Save"},
]
n = len(messages)
fc = SimpleNamespace(base_url="http://test", api_key="k")
with patch("agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fc, "m")), \
patch("agent.model_metadata.get_model_context_length",
return_value=128_000), \
patch("agent.auxiliary_client.call_llm",
return_value=_noop_response()):
agent.flush_memories(messages)
assert len(messages) == n
assert not any(m.get("_flush_sentinel") for m in messages)