mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-29 23:05:20 +08:00
Compare commits
1 Commits
feat/confi
...
feat/head-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67cf37fc26 |
@@ -89,8 +89,6 @@ DEFAULT_CONFIG = {
|
||||
"threshold": 0.85,
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"summary_provider": "auto",
|
||||
"protect_first_n": 3, # Number of initial turns to always preserve during compression
|
||||
"protect_last_n": 4, # Number of recent turns to always preserve during compression
|
||||
},
|
||||
|
||||
# Auxiliary model overrides (advanced). By default Hermes auto-selects
|
||||
@@ -168,7 +166,7 @@ DEFAULT_CONFIG = {
|
||||
"command_allowlist": [],
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 6,
|
||||
"_config_version": 5,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
|
||||
28
run_agent.py
28
run_agent.py
@@ -582,32 +582,16 @@ class AIAgent:
|
||||
# Initialize context compressor for automatic context management
|
||||
# Compresses conversation when approaching model's context limit
|
||||
# Configuration via config.yaml (compression section) or environment variables
|
||||
compression_config = {}
|
||||
try:
|
||||
from hermes_cli.config import load_config as _load_compression_config
|
||||
compression_config = _load_compression_config().get("compression", {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
compression_threshold = float(os.getenv(
|
||||
"CONTEXT_COMPRESSION_THRESHOLD",
|
||||
str(compression_config.get("threshold", 0.85))
|
||||
))
|
||||
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
|
||||
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
|
||||
if not compression_enabled:
|
||||
compression_enabled = compression_config.get("enabled", True)
|
||||
compression_summary_model = os.getenv("CONTEXT_COMPRESSION_MODEL") or compression_config.get("summary_model") or None
|
||||
|
||||
# Configurable turn protection (clamped 0-12, inspired by openclaw recentTurnsPreserve)
|
||||
protect_first_n = max(0, min(12, int(compression_config.get("protect_first_n", 3))))
|
||||
protect_last_n = max(0, min(12, int(compression_config.get("protect_last_n", 4))))
|
||||
|
||||
compression_summary_model = os.getenv("CONTEXT_COMPRESSION_MODEL") or None
|
||||
|
||||
self.context_compressor = ContextCompressor(
|
||||
model=self.model,
|
||||
threshold_percent=compression_threshold,
|
||||
protect_first_n=protect_first_n,
|
||||
protect_last_n=protect_last_n,
|
||||
summary_target_tokens=2500,
|
||||
protect_first_n=3,
|
||||
protect_last_n=4,
|
||||
summary_target_tokens=500,
|
||||
summary_model_override=compression_summary_model,
|
||||
quiet_mode=self.quiet_mode,
|
||||
base_url=self.base_url,
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Tests for configurable compaction protection turns."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestCompressionConfigDefaults(unittest.TestCase):
    """Check the turn-protection settings shipped in DEFAULT_CONFIG."""

    @staticmethod
    def _default_config():
        """Return DEFAULT_CONFIG, imported lazily at call time by each test."""
        from hermes_cli.config import DEFAULT_CONFIG

        return DEFAULT_CONFIG

    def test_default_config_has_protection_fields(self):
        section = self._default_config()["compression"]
        # Both protection keys must be present in the compression section.
        for key in ("protect_first_n", "protect_last_n"):
            self.assertIn(key, section)

    def test_default_values(self):
        section = self._default_config()["compression"]
        self.assertEqual(section["protect_first_n"], 3)
        self.assertEqual(section["protect_last_n"], 4)

    def test_config_version_bumped(self):
        version = self._default_config()["_config_version"]
        self.assertGreaterEqual(version, 6)
|
||||
|
||||
|
||||
class TestContextCompressorAcceptsConfig(unittest.TestCase):
    """Check that ContextCompressor exposes the protection counts it is given."""

    def _make_compressor(self, **kwargs):
        """Build a ContextCompressor with the auxiliary-client lookup patched.

        The patched ``get_text_auxiliary_client`` returns ``(None, "test-model")``;
        extra keyword arguments are forwarded to the constructor unchanged.
        """
        with patch("agent.context_compressor.get_text_auxiliary_client") as mock_aux:
            mock_aux.return_value = (None, "test-model")
            from agent.context_compressor import ContextCompressor

            return ContextCompressor(model="test/model", **kwargs)

    def test_custom_protection_values(self):
        compressor = self._make_compressor(protect_first_n=5, protect_last_n=8)
        self.assertEqual(compressor.protect_first_n, 5)
        self.assertEqual(compressor.protect_last_n, 8)

    def test_default_protection_values(self):
        compressor = self._make_compressor()
        self.assertEqual(compressor.protect_first_n, 3)
        self.assertEqual(compressor.protect_last_n, 4)
|
||||
|
||||
|
||||
class TestProtectionClamping(unittest.TestCase):
    """Check the ``max(0, min(12, n))`` clamp applied to protection counts."""

    def _assert_clamped(self, raw, expected):
        """Assert that clamping *raw* into the 0-12 range yields *expected*."""
        clamped = max(0, min(12, raw))
        self.assertEqual(clamped, expected)

    def test_clamp_negative_to_zero(self):
        self._assert_clamped(-5, 0)

    def test_clamp_over_max_to_twelve(self):
        self._assert_clamped(50, 12)

    def test_valid_value_unchanged(self):
        self._assert_clamped(7, 7)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -393,5 +393,56 @@ class TestStubSchemaDrift(unittest.TestCase):
|
||||
self.assertIn("mode", src)
|
||||
|
||||
|
||||
class TestHeadTailTruncation(unittest.TestCase):
    """Tests for head+tail truncation of large stdout in execute_code."""

    def _run(self, code):
        """Run *code* via execute_code with tool dispatch mocked.

        Returns the JSON result of execute_code decoded into a dict.
        """
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            result = execute_code(
                code=code,
                task_id="test-task",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(result)

    def test_short_output_not_truncated(self):
        """Output under MAX_STDOUT_BYTES should not be truncated."""
        result = self._run('print("small output")')
        self.assertEqual(result["status"], "success")
        self.assertIn("small output", result["output"])
        self.assertNotIn("TRUNCATED", result["output"])

    def test_large_output_preserves_head_and_tail(self):
        """Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
        code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
    print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        output = result["output"]
        # Head should be preserved
        self.assertIn("HEAD_MARKER_START", output)
        # Tail should be preserved (this is the key improvement)
        self.assertIn("TAIL_MARKER_END", output)
        # Truncation notice should be present
        self.assertIn("TRUNCATED", output)

    def test_truncation_notice_format(self):
        """Truncation notice includes character counts."""
        code = '''
for i in range(15000):
    print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
'''
        result = self._run(code)
        output = result["output"]
        # Assert truncation unconditionally: guarding the checks below with
        # `if "TRUNCATED" in output` would let this test pass vacuously if
        # truncation silently stopped happening.
        self.assertIn("TRUNCATED", output)
        self.assertIn("chars omitted", output)
        self.assertIn("total", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -457,11 +457,17 @@ def execute_code(
|
||||
|
||||
# --- Poll loop: watch for exit, timeout, and interrupt ---
|
||||
deadline = time.monotonic() + timeout
|
||||
stdout_chunks: list = []
|
||||
stderr_chunks: list = []
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks
|
||||
# Background readers to avoid pipe buffer deadlocks.
|
||||
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
|
||||
# and a rolling window of the last TAIL_BYTES so the final print()
|
||||
# output is never lost. Stderr keeps head-only (errors appear early).
|
||||
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
|
||||
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
|
||||
|
||||
def _drain(pipe, chunks, max_bytes):
|
||||
"""Simple head-only drain (used for stderr)."""
|
||||
total = 0
|
||||
try:
|
||||
while True:
|
||||
@@ -475,8 +481,48 @@ def execute_code(
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
stdout_total_bytes = [0] # mutable ref for total bytes seen
|
||||
|
||||
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
    """Drain *pipe* to EOF, keeping the first and last bytes of the stream.

    Data is read in 4096-byte chunks. The first ``head_bytes`` bytes are
    appended to ``head_chunks``; of everything after that, exactly the last
    ``tail_bytes`` bytes (or fewer, if the stream is shorter) are kept in a
    rolling window and appended to ``tail_chunks`` once reading stops.

    Args:
        pipe: Object whose ``read(n)`` returns bytes and an empty value at EOF.
        head_chunks: List that receives the head chunks as they are read.
        tail_chunks: List that receives the surviving tail chunks at the end.
        head_bytes: Maximum number of leading bytes to keep.
        tail_bytes: Maximum number of trailing bytes to keep.
        total_ref: One-element list; ``total_ref[0]`` is incremented by the
            size of every chunk read, kept or not.
    """
    from collections import deque

    head_collected = 0
    tail_buf = deque()
    tail_collected = 0
    try:
        while True:
            data = pipe.read(4096)
            if not data:
                break
            total_ref[0] += len(data)
            # Fill head buffer first
            if head_collected < head_bytes:
                keep = min(len(data), head_bytes - head_collected)
                head_chunks.append(data[:keep])
                head_collected += keep
                data = data[keep:]  # remaining goes to tail
                if not data:
                    continue
            if tail_bytes <= 0:
                # No tail budget: nothing past the head is retained.
                continue
            # Everything past head goes into rolling tail buffer
            tail_buf.append(data)
            tail_collected += len(data)
            # Evict from the old end until the window is exactly within
            # budget. The oldest chunk is trimmed rather than dropped whole;
            # dropping whole chunks would shrink the tail below budget and
            # lose it entirely when a single chunk exceeds tail_bytes.
            excess = tail_collected - tail_bytes
            while excess > 0:
                oldest = tail_buf[0]
                if len(oldest) <= excess:
                    tail_buf.popleft()
                    tail_collected -= len(oldest)
                    excess -= len(oldest)
                else:
                    tail_buf[0] = oldest[excess:]
                    tail_collected -= excess
                    excess = 0
    except (ValueError, OSError):
        # Best-effort drain: a read error ends the stream; whatever was
        # collected so far is still handed back below.
        pass
    # Transfer final tail to output list
    tail_chunks.extend(tail_buf)
|
||||
|
||||
stdout_head_chunks: list = []
|
||||
stdout_tail_chunks: list = []
|
||||
|
||||
stdout_reader = threading.Thread(
|
||||
target=_drain, args=(proc.stdout, stdout_chunks, MAX_STDOUT_BYTES), daemon=True
|
||||
target=_drain_head_tail,
|
||||
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
|
||||
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
|
||||
daemon=True
|
||||
)
|
||||
stderr_reader = threading.Thread(
|
||||
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
|
||||
@@ -500,12 +546,21 @@ def execute_code(
|
||||
stdout_reader.join(timeout=3)
|
||||
stderr_reader.join(timeout=3)
|
||||
|
||||
stdout_text = b"".join(stdout_chunks).decode("utf-8", errors="replace")
|
||||
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
|
||||
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
|
||||
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
|
||||
|
||||
# Truncation notice
|
||||
if len(stdout_text) >= MAX_STDOUT_BYTES:
|
||||
stdout_text = stdout_text[:MAX_STDOUT_BYTES] + "\n[output truncated at 50KB]"
|
||||
# Assemble stdout with head+tail truncation
|
||||
total_stdout = stdout_total_bytes[0]
|
||||
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
|
||||
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
|
||||
truncated_notice = (
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
|
||||
f"out of {total_stdout:,} total] ...\n\n"
|
||||
)
|
||||
stdout_text = stdout_head + truncated_notice + stdout_tail
|
||||
else:
|
||||
stdout_text = stdout_head + stdout_tail
|
||||
|
||||
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
|
||||
Reference in New Issue
Block a user