Compare commits

..

1 Commits

Author SHA1 Message Date
teknium1
67cf37fc26 fix: head+tail truncation for execute_code stdout (inspired by openclaw context-pruning)
Previously, _drain() only captured the first MAX_STDOUT_BYTES (50KB) of
stdout, silently dropping all tail output. Scripts that print() their
final results at the end would have those results lost.

Now uses a two-buffer approach: 40% head + 60% tail (rolling window).
This matches the pattern already used in terminal_tool.py (line 1042-1051)
but gives the tail more space since execute_code scripts typically
print() their final results at the end.

Inspired by openclaw's softTrim context-pruning (headChars/tailChars).
2026-03-09 02:15:48 -07:00
5 changed files with 120 additions and 100 deletions

View File

@@ -89,8 +89,6 @@ DEFAULT_CONFIG = {
"threshold": 0.85,
"summary_model": "google/gemini-3-flash-preview",
"summary_provider": "auto",
"protect_first_n": 3, # Number of initial turns to always preserve during compression
"protect_last_n": 4, # Number of recent turns to always preserve during compression
},
# Auxiliary model overrides (advanced). By default Hermes auto-selects
@@ -168,7 +166,7 @@ DEFAULT_CONFIG = {
"command_allowlist": [],
# Config schema version - bump this when adding new required fields
"_config_version": 6,
"_config_version": 5,
}
# =============================================================================

View File

@@ -582,32 +582,16 @@ class AIAgent:
# Initialize context compressor for automatic context management
# Compresses conversation when approaching model's context limit
# Configuration via config.yaml (compression section) or environment variables
compression_config = {}
try:
from hermes_cli.config import load_config as _load_compression_config
compression_config = _load_compression_config().get("compression", {})
except Exception:
pass
compression_threshold = float(os.getenv(
"CONTEXT_COMPRESSION_THRESHOLD",
str(compression_config.get("threshold", 0.85))
))
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
if not compression_enabled:
compression_enabled = compression_config.get("enabled", True)
compression_summary_model = os.getenv("CONTEXT_COMPRESSION_MODEL") or compression_config.get("summary_model") or None
# Configurable turn protection (clamped 0-12, inspired by openclaw recentTurnsPreserve)
protect_first_n = max(0, min(12, int(compression_config.get("protect_first_n", 3))))
protect_last_n = max(0, min(12, int(compression_config.get("protect_last_n", 4))))
compression_summary_model = os.getenv("CONTEXT_COMPRESSION_MODEL") or None
self.context_compressor = ContextCompressor(
model=self.model,
threshold_percent=compression_threshold,
protect_first_n=protect_first_n,
protect_last_n=protect_last_n,
summary_target_tokens=2500,
protect_first_n=3,
protect_last_n=4,
summary_target_tokens=500,
summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode,
base_url=self.base_url,

View File

@@ -1,68 +0,0 @@
"""Tests for configurable compaction protection turns."""
import unittest
from unittest.mock import patch, MagicMock
class TestCompressionConfigDefaults(unittest.TestCase):
"""Verify DEFAULT_CONFIG includes protect_first_n / protect_last_n."""
def test_default_config_has_protection_fields(self):
from hermes_cli.config import DEFAULT_CONFIG
compression = DEFAULT_CONFIG["compression"]
self.assertIn("protect_first_n", compression)
self.assertIn("protect_last_n", compression)
def test_default_values(self):
from hermes_cli.config import DEFAULT_CONFIG
compression = DEFAULT_CONFIG["compression"]
self.assertEqual(compression["protect_first_n"], 3)
self.assertEqual(compression["protect_last_n"], 4)
def test_config_version_bumped(self):
from hermes_cli.config import DEFAULT_CONFIG
self.assertGreaterEqual(DEFAULT_CONFIG["_config_version"], 6)
class TestContextCompressorAcceptsConfig(unittest.TestCase):
"""Verify ContextCompressor properly receives custom protection values."""
@patch("agent.context_compressor.get_text_auxiliary_client")
def test_custom_protection_values(self, mock_aux):
mock_aux.return_value = (None, "test-model")
from agent.context_compressor import ContextCompressor
compressor = ContextCompressor(
model="test/model",
protect_first_n=5,
protect_last_n=8,
)
self.assertEqual(compressor.protect_first_n, 5)
self.assertEqual(compressor.protect_last_n, 8)
@patch("agent.context_compressor.get_text_auxiliary_client")
def test_default_protection_values(self, mock_aux):
mock_aux.return_value = (None, "test-model")
from agent.context_compressor import ContextCompressor
compressor = ContextCompressor(model="test/model")
self.assertEqual(compressor.protect_first_n, 3)
self.assertEqual(compressor.protect_last_n, 4)
class TestProtectionClamping(unittest.TestCase):
"""Verify protection values are clamped to 0-12 range."""
def test_clamp_negative_to_zero(self):
val = max(0, min(12, -5))
self.assertEqual(val, 0)
def test_clamp_over_max_to_twelve(self):
val = max(0, min(12, 50))
self.assertEqual(val, 12)
def test_valid_value_unchanged(self):
val = max(0, min(12, 7))
self.assertEqual(val, 7)
if __name__ == "__main__":
unittest.main()

View File

@@ -393,5 +393,56 @@ class TestStubSchemaDrift(unittest.TestCase):
self.assertIn("mode", src)
class TestHeadTailTruncation(unittest.TestCase):
"""Tests for head+tail truncation of large stdout in execute_code."""
def _run(self, code):
with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
result = execute_code(
code=code,
task_id="test-task",
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
)
return json.loads(result)
def test_short_output_not_truncated(self):
"""Output under MAX_STDOUT_BYTES should not be truncated."""
result = self._run('print("small output")')
self.assertEqual(result["status"], "success")
self.assertIn("small output", result["output"])
self.assertNotIn("TRUNCATED", result["output"])
def test_large_output_preserves_head_and_tail(self):
"""Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
result = self._run(code)
self.assertEqual(result["status"], "success")
output = result["output"]
# Head should be preserved
self.assertIn("HEAD_MARKER_START", output)
# Tail should be preserved (this is the key improvement)
self.assertIn("TAIL_MARKER_END", output)
# Truncation notice should be present
self.assertIn("TRUNCATED", output)
def test_truncation_notice_format(self):
"""Truncation notice includes character counts."""
code = '''
for i in range(15000):
print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
'''
result = self._run(code)
output = result["output"]
if "TRUNCATED" in output:
self.assertIn("chars omitted", output)
self.assertIn("total", output)
if __name__ == "__main__":
unittest.main()

View File

@@ -457,11 +457,17 @@ def execute_code(
# --- Poll loop: watch for exit, timeout, and interrupt ---
deadline = time.monotonic() + timeout
stdout_chunks: list = []
stderr_chunks: list = []
# Background readers to avoid pipe buffer deadlocks
# Background readers to avoid pipe buffer deadlocks.
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
# and a rolling window of the last TAIL_BYTES so the final print()
# output is never lost. Stderr keeps head-only (errors appear early).
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
def _drain(pipe, chunks, max_bytes):
"""Simple head-only drain (used for stderr)."""
total = 0
try:
while True:
@@ -475,8 +481,48 @@ def execute_code(
except (ValueError, OSError):
pass
stdout_total_bytes = [0] # mutable ref for total bytes seen
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
"""Drain stdout keeping both head and tail data."""
head_collected = 0
from collections import deque
tail_buf = deque()
tail_collected = 0
try:
while True:
data = pipe.read(4096)
if not data:
break
total_ref[0] += len(data)
# Fill head buffer first
if head_collected < head_bytes:
keep = min(len(data), head_bytes - head_collected)
head_chunks.append(data[:keep])
head_collected += keep
data = data[keep:] # remaining goes to tail
if not data:
continue
# Everything past head goes into rolling tail buffer
tail_buf.append(data)
tail_collected += len(data)
# Evict old tail data to stay within tail_bytes budget
while tail_collected > tail_bytes and tail_buf:
oldest = tail_buf.popleft()
tail_collected -= len(oldest)
except (ValueError, OSError):
pass
# Transfer final tail to output list
tail_chunks.extend(tail_buf)
stdout_head_chunks: list = []
stdout_tail_chunks: list = []
stdout_reader = threading.Thread(
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
target=_drain_head_tail,
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
daemon=True
)
stderr_reader = threading.Thread(
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
@@ -500,12 +546,21 @@ def execute_code(
stdout_reader.join(timeout=3)
stderr_reader.join(timeout=3)
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
# Truncation notice
if len(stdout_text) >= MAX_STDOUT_BYTES:
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
# Assemble stdout with head+tail truncation
total_stdout = stdout_total_bytes[0]
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
truncated_notice = (
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
f"out of {total_stdout:,} total] ...\n\n"
)
stdout_text = stdout_head + truncated_notice + stdout_tail
else:
stdout_text = stdout_head + stdout_tail
exit_code = proc.returncode if proc.returncode is not None else -1
duration = round(time.monotonic() - exec_start, 2)