Compare commits

...

1 Commits

Author SHA1 Message Date
Teknium
3439c2958e fix(agent): include tool tokens in preflight estimate, guard context probe persistence
Two improvements salvaged from PR #2600 (paraddox):

1. Preflight compression now counts tool schema tokens alongside system
   prompt and messages.  With 50+ tools enabled, schemas can add 20-30K
   tokens that were previously invisible to the estimator, delaying
   compression until the API rejected the request.

2. Context probe persistence guard: when the agent steps down context
   tiers after a context-length error, only provider-confirmed numeric
   limits (parsed from the error message) are cached to disk.  Guessed
   fallback tiers from get_next_probe_tier() stay in-memory only,
   preventing wrong values from polluting the persistent cache.

Co-authored-by: paraddox <paraddox@users.noreply.github.com>
2026-03-26 02:00:20 -07:00
2 changed files with 52 additions and 10 deletions

View File

@@ -895,3 +895,26 @@ def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
"""Rough token estimate for a message list (pre-flight only).""" """Rough token estimate for a message list (pre-flight only)."""
total_chars = sum(len(str(msg)) for msg in messages) total_chars = sum(len(str(msg)) for msg in messages)
return total_chars // 4 return total_chars // 4
def estimate_request_tokens_rough(
    messages: List[Dict[str, Any]],
    *,
    system_prompt: str = "",
    tools: Optional[List[Dict[str, Any]]] = None,
) -> int:
    """Rough token estimate for a full chat-completions request.

    Covers the major payload buckets sent to the provider: the system
    prompt, the conversation messages, and the tool schemas.  Tool
    schemas matter here — with many tools enabled they can contribute
    tens of thousands of tokens that a messages-only estimate misses.

    Uses the same chars-div-4 heuristic as the other rough estimators;
    suitable for pre-flight checks only, not billing-grade counts.
    """
    # Falsy buckets are skipped entirely: str([]) would otherwise
    # contribute a couple of spurious characters for empty tools.
    char_total = len(system_prompt) if system_prompt else 0
    if messages:
        char_total += sum(len(str(message)) for message in messages)
    if tools:
        char_total += len(str(tools))
    # ~4 characters per token is the standing rough conversion.
    return char_total // 4

View File

@@ -77,7 +77,7 @@ from agent.prompt_builder import (
) )
from agent.model_metadata import ( from agent.model_metadata import (
fetch_model_metadata, fetch_model_metadata,
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error, get_next_probe_tier, parse_context_limit_from_error,
save_context_length, save_context_length,
) )
@@ -1133,6 +1133,7 @@ class AIAgent:
self.context_compressor.last_total_tokens = 0 self.context_compressor.last_total_tokens = 0
self.context_compressor.compression_count = 0 self.context_compressor.compression_count = 0
self.context_compressor._context_probed = False self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
# Iterative summary from previous session must not bleed into new one (#2635) # Iterative summary from previous session must not bleed into new one (#2635)
self.context_compressor._previous_summary = None self.context_compressor._previous_summary = None
@@ -5820,9 +5821,13 @@ class AIAgent:
and len(messages) > self.context_compressor.protect_first_n and len(messages) > self.context_compressor.protect_first_n
+ self.context_compressor.protect_last_n + 1 + self.context_compressor.protect_last_n + 1
): ):
_sys_tok_est = estimate_tokens_rough(active_system_prompt or "") # Include tool schema tokens — with many tools these can add
_msg_tok_est = estimate_messages_tokens_rough(messages) # 20-30K+ tokens that the old sys+msg estimate missed entirely.
_preflight_tokens = _sys_tok_est + _msg_tok_est _preflight_tokens = estimate_request_tokens_rough(
messages,
system_prompt=active_system_prompt or "",
tools=self.tools or None,
)
if _preflight_tokens >= self.context_compressor.threshold_tokens: if _preflight_tokens >= self.context_compressor.threshold_tokens:
logger.info( logger.info(
@@ -5848,9 +5853,11 @@ class AIAgent:
if len(messages) >= _orig_len: if len(messages) >= _orig_len:
break # Cannot compress further break # Cannot compress further
# Re-estimate after compression # Re-estimate after compression
_sys_tok_est = estimate_tokens_rough(active_system_prompt or "") _preflight_tokens = estimate_request_tokens_rough(
_msg_tok_est = estimate_messages_tokens_rough(messages) messages,
_preflight_tokens = _sys_tok_est + _msg_tok_est system_prompt=active_system_prompt or "",
tools=self.tools or None,
)
if _preflight_tokens < self.context_compressor.threshold_tokens: if _preflight_tokens < self.context_compressor.threshold_tokens:
break # Under threshold break # Under threshold
@@ -6313,12 +6320,16 @@ class AIAgent:
} }
self.context_compressor.update_from_response(usage_dict) self.context_compressor.update_from_response(usage_dict)
# Cache discovered context length after successful call # Cache discovered context length after successful call.
# Only persist limits confirmed by the provider (parsed
# from the error message), not guessed probe tiers.
if self.context_compressor._context_probed: if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length ctx = self.context_compressor.context_length
if getattr(self.context_compressor, "_context_probe_persistable", False):
save_context_length(self.model, self.base_url, ctx) save_context_length(self.model, self.base_url, ctx)
self._safe_print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}") self._safe_print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
self.context_compressor._context_probed = False self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
self.session_prompt_tokens += prompt_tokens self.session_prompt_tokens += prompt_tokens
self.session_completion_tokens += completion_tokens self.session_completion_tokens += completion_tokens
@@ -6619,6 +6630,14 @@ class AIAgent:
compressor.context_length = new_ctx compressor.context_length = new_ctx
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
compressor._context_probed = True compressor._context_probed = True
# Only persist limits parsed from the provider's
# error message (a real number). Guessed fallback
# tiers from get_next_probe_tier() should stay
# in-memory only — persisting them pollutes the
# cache with wrong values.
compressor._context_probe_persistable = bool(
parsed_limit and parsed_limit == new_ctx
)
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens", force=True) self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens", force=True)
else: else:
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True) self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True)