mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 07:51:45 +08:00
Compare commits
2 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf2b74cef0 | ||
|
|
346b947fc5 |
81
cli.py
81
cli.py
@@ -1171,6 +1171,45 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
|||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def _format_process_notification(evt: dict) -> "str | None":
|
||||||
|
"""Format a process notification event into a [SYSTEM: ...] message.
|
||||||
|
|
||||||
|
Handles both completion events (notify_on_complete) and watch pattern
|
||||||
|
match events from the unified completion_queue.
|
||||||
|
"""
|
||||||
|
evt_type = evt.get("type", "completion")
|
||||||
|
_sid = evt.get("session_id", "unknown")
|
||||||
|
_cmd = evt.get("command", "unknown")
|
||||||
|
|
||||||
|
if evt_type == "watch_disabled":
|
||||||
|
return f"[SYSTEM: {evt.get('message', '')}]"
|
||||||
|
|
||||||
|
if evt_type == "watch_match":
|
||||||
|
_pat = evt.get("pattern", "?")
|
||||||
|
_out = evt.get("output", "")
|
||||||
|
_sup = evt.get("suppressed", 0)
|
||||||
|
text = (
|
||||||
|
f"[SYSTEM: Background process {_sid} matched "
|
||||||
|
f"watch pattern \"{_pat}\".\n"
|
||||||
|
f"Command: {_cmd}\n"
|
||||||
|
f"Matched output:\n{_out}"
|
||||||
|
)
|
||||||
|
if _sup:
|
||||||
|
text += f"\n({_sup} earlier matches were suppressed by rate limit)"
|
||||||
|
text += "]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Default: completion event
|
||||||
|
_exit = evt.get("exit_code", "?")
|
||||||
|
_out = evt.get("output", "")
|
||||||
|
return (
|
||||||
|
f"[SYSTEM: Background process {_sid} completed "
|
||||||
|
f"(exit code {_exit}).\n"
|
||||||
|
f"Command: {_cmd}\n"
|
||||||
|
f"Output:\n{_out}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _detect_file_drop(user_input: str) -> "dict | None":
|
def _detect_file_drop(user_input: str) -> "dict | None":
|
||||||
"""Detect if *user_input* starts with a real local file path.
|
"""Detect if *user_input* starts with a real local file path.
|
||||||
|
|
||||||
@@ -8870,23 +8909,15 @@ class HermesCLI:
|
|||||||
# Periodic config watcher — auto-reload MCP on mcp_servers change
|
# Periodic config watcher — auto-reload MCP on mcp_servers change
|
||||||
if not self._agent_running:
|
if not self._agent_running:
|
||||||
self._check_config_mcp_changes()
|
self._check_config_mcp_changes()
|
||||||
# Check for background process completion notifications
|
# Check for background process notifications (completions
|
||||||
# while the agent is idle (user hasn't typed anything yet).
|
# and watch pattern matches) while agent is idle.
|
||||||
try:
|
try:
|
||||||
from tools.process_registry import process_registry
|
from tools.process_registry import process_registry
|
||||||
if not process_registry.completion_queue.empty():
|
if not process_registry.completion_queue.empty():
|
||||||
completion = process_registry.completion_queue.get_nowait()
|
evt = process_registry.completion_queue.get_nowait()
|
||||||
_exit = completion.get("exit_code", "?")
|
_synth = _format_process_notification(evt)
|
||||||
_cmd = completion.get("command", "unknown")
|
if _synth:
|
||||||
_sid = completion.get("session_id", "unknown")
|
self._pending_input.put(_synth)
|
||||||
_out = completion.get("output", "")
|
|
||||||
_synth = (
|
|
||||||
f"[SYSTEM: Background process {_sid} completed "
|
|
||||||
f"(exit code {_exit}).\n"
|
|
||||||
f"Command: {_cmd}\n"
|
|
||||||
f"Output:\n{_out}]"
|
|
||||||
)
|
|
||||||
self._pending_input.put(_synth)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
continue
|
continue
|
||||||
@@ -9004,25 +9035,15 @@ class HermesCLI:
|
|||||||
_cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}")
|
_cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}")
|
||||||
threading.Thread(target=_restart_recording, daemon=True).start()
|
threading.Thread(target=_restart_recording, daemon=True).start()
|
||||||
|
|
||||||
# Drain process completion notifications — any background
|
# Drain process notifications (completions + watch matches)
|
||||||
# process that finished with notify_on_complete while the
|
# that arrived while the agent was running.
|
||||||
# agent was running (or before) gets auto-injected as a
|
|
||||||
# new user message so the agent can react to it.
|
|
||||||
try:
|
try:
|
||||||
from tools.process_registry import process_registry
|
from tools.process_registry import process_registry
|
||||||
while not process_registry.completion_queue.empty():
|
while not process_registry.completion_queue.empty():
|
||||||
completion = process_registry.completion_queue.get_nowait()
|
evt = process_registry.completion_queue.get_nowait()
|
||||||
_exit = completion.get("exit_code", "?")
|
_synth = _format_process_notification(evt)
|
||||||
_cmd = completion.get("command", "unknown")
|
if _synth:
|
||||||
_sid = completion.get("session_id", "unknown")
|
self._pending_input.put(_synth)
|
||||||
_out = completion.get("output", "")
|
|
||||||
_synth = (
|
|
||||||
f"[SYSTEM: Background process {_sid} completed "
|
|
||||||
f"(exit code {_exit}).\n"
|
|
||||||
f"Command: {_cmd}\n"
|
|
||||||
f"Output:\n{_out}]"
|
|
||||||
)
|
|
||||||
self._pending_input.put(_synth)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Non-fatal — don't break the main loop
|
pass # Non-fatal — don't break the main loop
|
||||||
|
|
||||||
|
|||||||
@@ -476,6 +476,33 @@ def _resolve_hermes_bin() -> Optional[list[str]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _format_gateway_process_notification(evt: dict) -> "str | None":
|
||||||
|
"""Format a watch pattern event from completion_queue into a [SYSTEM:] message."""
|
||||||
|
evt_type = evt.get("type", "completion")
|
||||||
|
_sid = evt.get("session_id", "unknown")
|
||||||
|
_cmd = evt.get("command", "unknown")
|
||||||
|
|
||||||
|
if evt_type == "watch_disabled":
|
||||||
|
return f"[SYSTEM: {evt.get('message', '')}]"
|
||||||
|
|
||||||
|
if evt_type == "watch_match":
|
||||||
|
_pat = evt.get("pattern", "?")
|
||||||
|
_out = evt.get("output", "")
|
||||||
|
_sup = evt.get("suppressed", 0)
|
||||||
|
text = (
|
||||||
|
f"[SYSTEM: Background process {_sid} matched "
|
||||||
|
f"watch pattern \"{_pat}\".\n"
|
||||||
|
f"Command: {_cmd}\n"
|
||||||
|
f"Matched output:\n{_out}"
|
||||||
|
)
|
||||||
|
if _sup:
|
||||||
|
text += f"\n({_sup} earlier matches were suppressed by rate limit)"
|
||||||
|
text += "]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GatewayRunner:
|
class GatewayRunner:
|
||||||
"""
|
"""
|
||||||
Main gateway controller.
|
Main gateway controller.
|
||||||
@@ -3430,6 +3457,29 @@ class GatewayRunner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Process watcher setup error: %s", e)
|
logger.error("Process watcher setup error: %s", e)
|
||||||
|
|
||||||
|
# Drain watch pattern notifications that arrived during the agent run.
|
||||||
|
# Watch events and completions share the same queue; completions are
|
||||||
|
# already handled by the per-process watcher task above, so we only
|
||||||
|
# inject watch-type events here.
|
||||||
|
try:
|
||||||
|
from tools.process_registry import process_registry as _pr
|
||||||
|
_watch_events = []
|
||||||
|
while not _pr.completion_queue.empty():
|
||||||
|
evt = _pr.completion_queue.get_nowait()
|
||||||
|
evt_type = evt.get("type", "completion")
|
||||||
|
if evt_type in ("watch_match", "watch_disabled"):
|
||||||
|
_watch_events.append(evt)
|
||||||
|
# else: completion events are handled by the watcher task
|
||||||
|
for evt in _watch_events:
|
||||||
|
synth_text = _format_gateway_process_notification(evt)
|
||||||
|
if synth_text:
|
||||||
|
try:
|
||||||
|
await self._inject_watch_notification(synth_text, event)
|
||||||
|
except Exception as e2:
|
||||||
|
logger.error("Watch notification injection error: %s", e2)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Watch queue drain error: %s", e)
|
||||||
|
|
||||||
# NOTE: Dangerous command approvals are now handled inline by the
|
# NOTE: Dangerous command approvals are now handled inline by the
|
||||||
# blocking gateway approval mechanism in tools/approval.py. The agent
|
# blocking gateway approval mechanism in tools/approval.py. The agent
|
||||||
# thread blocks until the user responds with /approve or /deny, so by
|
# thread blocks until the user responds with /approve or /deny, so by
|
||||||
@@ -6708,6 +6758,36 @@ class GatewayRunner:
|
|||||||
return prefix
|
return prefix
|
||||||
return user_text
|
return user_text
|
||||||
|
|
||||||
|
async def _inject_watch_notification(self, synth_text: str, original_event) -> None:
|
||||||
|
"""Inject a watch-pattern notification as a synthetic message event.
|
||||||
|
|
||||||
|
Uses the source from the original user event to route the notification
|
||||||
|
back to the correct chat/adapter.
|
||||||
|
"""
|
||||||
|
source = getattr(original_event, "source", None)
|
||||||
|
if not source:
|
||||||
|
return
|
||||||
|
platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
||||||
|
adapter = None
|
||||||
|
for p, a in self.adapters.items():
|
||||||
|
if p.value == platform_name:
|
||||||
|
adapter = a
|
||||||
|
break
|
||||||
|
if not adapter:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
|
synth_event = MessageEvent(
|
||||||
|
text=synth_text,
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=source,
|
||||||
|
internal=True,
|
||||||
|
)
|
||||||
|
logger.info("Watch pattern notification — injecting for %s", platform_name)
|
||||||
|
await adapter.handle_message(synth_event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Watch notification injection error: %s", e)
|
||||||
|
|
||||||
async def _run_process_watcher(self, watcher: dict) -> None:
|
async def _run_process_watcher(self, watcher: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Periodically check a background process and push updates to the user.
|
Periodically check a background process and push updates to the user.
|
||||||
|
|||||||
304
tests/tools/test_watch_patterns.py
Normal file
304
tests/tools/test_watch_patterns.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""Tests for watch_patterns background process monitoring feature.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- ProcessSession.watch_patterns field
|
||||||
|
- ProcessRegistry._check_watch_patterns() matching + notification
|
||||||
|
- Rate limiting (WATCH_MAX_PER_WINDOW) and overload kill switch
|
||||||
|
- completion_queue population (watch events share the unified notification queue)
|
||||||
|
- Checkpoint persistence of watch_patterns
|
||||||
|
- Terminal tool schema includes watch_patterns
|
||||||
|
- Terminal tool handler passes watch_patterns through
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from tools.process_registry import (
|
||||||
|
ProcessRegistry,
|
||||||
|
ProcessSession,
|
||||||
|
WATCH_MAX_PER_WINDOW,
|
||||||
|
WATCH_WINDOW_SECONDS,
|
||||||
|
WATCH_OVERLOAD_KILL_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def registry():
    """Provide a newly constructed ProcessRegistry to the requesting test."""
    fresh_registry = ProcessRegistry()
    return fresh_registry
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(
    sid="proc_test_watch",
    command="tail -f app.log",
    task_id="t1",
    watch_patterns=None,
) -> ProcessSession:
    """Build a ProcessSession for the tests, started "now".

    ``watch_patterns`` defaults to an empty list when omitted or falsy.
    """
    patterns = watch_patterns or []
    return ProcessSession(
        id=sid,
        command=command,
        task_id=task_id,
        started_at=time.time(),
        watch_patterns=patterns,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# ProcessSession field defaults
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestProcessSessionField:
    """Defaults and explicit values of the watch-related session fields."""

    def test_default_empty(self):
        """A session created without watch settings starts with inert watch state."""
        session = ProcessSession(id="proc_1", command="echo hi")

        assert session.watch_patterns == []
        assert session._watch_disabled is False
        assert session._watch_hits == 0
        assert session._watch_suppressed == 0

    def test_can_set_patterns(self):
        """Patterns passed at construction are stored unchanged."""
        expected = ["ERROR", "WARN"]
        session = _make_session(watch_patterns=["ERROR", "WARN"])

        assert session.watch_patterns == expected
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Pattern matching + queue population
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestCheckWatchPatterns:
    """Pattern matching and queue population in _check_watch_patterns()."""

    def test_no_patterns_no_notification(self, registry):
        """No watch_patterns → no notifications."""
        session = _make_session(watch_patterns=[])
        registry._check_watch_patterns(session, "ERROR: something broke\n")
        assert registry.completion_queue.empty()

    def test_no_match_no_notification(self, registry):
        """Output that doesn't match any pattern → no notification."""
        session = _make_session(watch_patterns=["ERROR", "FAIL"])
        registry._check_watch_patterns(session, "INFO: all good\nDEBUG: fine\n")
        assert registry.completion_queue.empty()

    def test_basic_match(self, registry):
        """Single matching line triggers a notification."""
        session = _make_session(watch_patterns=["ERROR"])
        registry._check_watch_patterns(session, "INFO: ok\nERROR: disk full\n")
        assert not registry.completion_queue.empty()
        evt = registry.completion_queue.get_nowait()
        assert evt["type"] == "watch_match"
        assert evt["pattern"] == "ERROR"
        assert "disk full" in evt["output"]
        # Only matching lines belong in the notification payload.
        assert "INFO: ok" not in evt["output"]
        assert evt["session_id"] == "proc_test_watch"

    def test_multiple_patterns(self, registry):
        """The reported pattern is the one that matched the first matching line."""
        session = _make_session(watch_patterns=["WARN", "ERROR"])
        registry._check_watch_patterns(session, "ERROR: bad\nWARN: hmm\n")
        evt = registry.completion_queue.get_nowait()
        # "ERROR: bad" is the first line of the chunk; "WARN" does not occur
        # in it, so "ERROR" is the first pattern recorded.
        assert evt["pattern"] == "ERROR"
        assert "bad" in evt["output"]
        # Lines matched by other patterns are still part of the output.
        assert "hmm" in evt["output"]

    def test_disabled_skips(self, registry):
        """Disabled watch produces no notifications."""
        session = _make_session(watch_patterns=["ERROR"])
        session._watch_disabled = True
        registry._check_watch_patterns(session, "ERROR: boom\n")
        assert registry.completion_queue.empty()

    def test_hit_counter_increments(self, registry):
        """Each delivered notification increments _watch_hits."""
        session = _make_session(watch_patterns=["X"])
        registry._check_watch_patterns(session, "X\n")
        assert session._watch_hits == 1
        registry._check_watch_patterns(session, "X\n")
        assert session._watch_hits == 2

    def test_output_truncation(self, registry):
        """Matched output is capped at the first 20 matching lines."""
        session = _make_session(watch_patterns=["X"])
        # Generate 30 matching lines (more than the 20-line cap)
        text = "\n".join(f"X line {i}" for i in range(30)) + "\n"
        registry._check_watch_patterns(session, text)
        evt = registry.completion_queue.get_nowait()
        # Counting newlines with "<= 20" would also accept 21 lines; pin the
        # exact lines kept instead. The payload is far below the 2000-char
        # limit, so no truncation marker is appended.
        assert evt["output"].splitlines() == [f"X line {i}" for i in range(20)]
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Rate limiting
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestRateLimiting:
    """Per-window delivery cap and bookkeeping of suppressed matches."""

    def test_within_window_limit(self, registry):
        """Up to WATCH_MAX_PER_WINDOW matches in one window are all queued."""
        watched = _make_session(watch_patterns=["E"])

        for n in range(WATCH_MAX_PER_WINDOW):
            registry._check_watch_patterns(watched, f"E {n}\n")

        assert registry.completion_queue.qsize() == WATCH_MAX_PER_WINDOW

    def test_exceeds_window_limit(self, registry):
        """Matches past the cap are counted as suppressed, not queued."""
        watched = _make_session(watch_patterns=["E"])
        total_calls = WATCH_MAX_PER_WINDOW + 5

        for n in range(total_calls):
            registry._check_watch_patterns(watched, f"E {n}\n")

        # The queue holds the capped amount; the rest were dropped.
        assert registry.completion_queue.qsize() == WATCH_MAX_PER_WINDOW
        assert watched._watch_suppressed == 5

    def test_window_resets(self, registry):
        """Once the window has expired, the next match is delivered again."""
        watched = _make_session(watch_patterns=["E"])

        # Use up the whole window.
        for n in range(WATCH_MAX_PER_WINDOW):
            registry._check_watch_patterns(watched, f"E {n}\n")

        # The next match in the same window is dropped.
        registry._check_watch_patterns(watched, "E extra\n")
        assert watched._watch_suppressed == 1

        # Pretend the window started more than WATCH_WINDOW_SECONDS ago.
        expired_start = time.time() - WATCH_WINDOW_SECONDS - 1
        watched._watch_window_start = expired_start
        registry._check_watch_patterns(watched, "E after reset\n")

        # One more event than the cap: the post-reset match got through.
        assert registry.completion_queue.qsize() == WATCH_MAX_PER_WINDOW + 1

    def test_suppressed_count_in_next_delivery(self, registry):
        """The next delivered event reports how many matches were dropped."""
        watched = _make_session(watch_patterns=["E"])

        for n in range(WATCH_MAX_PER_WINDOW):
            registry._check_watch_patterns(watched, f"E {n}\n")

        # Three further matches land in the saturated window.
        for n in range(3):
            registry._check_watch_patterns(watched, f"E suppressed {n}\n")
        assert watched._watch_suppressed == 3

        # Expire the window so the next match is delivered.
        watched._watch_window_start = time.time() - WATCH_WINDOW_SECONDS - 1
        registry._check_watch_patterns(watched, "E back\n")

        # Walk the queue; the most recent event is the one of interest.
        final_event = None
        while not registry.completion_queue.empty():
            final_event = registry.completion_queue.get_nowait()

        assert final_event["suppressed"] == 3
        assert watched._watch_suppressed == 0  # reset after delivery
|
||||||
|
assert session._watch_suppressed == 0 # reset after delivery
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Overload kill switch
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestOverloadKillSwitch:
    """Disabling of watching after overload lasts past the kill threshold."""

    def test_sustained_overload_disables(self, registry):
        """Overload older than the threshold disables watching and emits one event."""
        watched = _make_session(watch_patterns=["E"])

        # Saturate the current window so further matches are rate limited.
        for n in range(WATCH_MAX_PER_WINDOW):
            registry._check_watch_patterns(watched, f"E {n}\n")

        # Backdate the overload start beyond the kill threshold.
        watched._watch_overload_since = time.time() - WATCH_OVERLOAD_KILL_SECONDS - 1

        # Two more rate-limited matches; the second arrives after disabling.
        registry._check_watch_patterns(watched, "E overload\n")
        registry._check_watch_patterns(watched, "E overload2\n")

        assert watched._watch_disabled is True

        # Exactly one watch_disabled event must have been queued.
        drained = []
        while not registry.completion_queue.empty():
            drained.append(registry.completion_queue.get_nowait())
        disabled_events = [e for e in drained if e.get("type") == "watch_disabled"]

        assert len(disabled_events) == 1
        assert "too many matches" in disabled_events[0]["message"]

    def test_overload_resets_on_delivery(self, registry):
        """A delivered notification clears the overload timer."""
        watched = _make_session(watch_patterns=["E"])

        # Overload tracking is active, but the window still has room.
        watched._watch_overload_since = time.time() - 10
        registry._check_watch_patterns(watched, "E ok\n")

        assert watched._watch_overload_since == 0.0
        assert watched._watch_disabled is False
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Checkpoint persistence
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestCheckpointPersistence:
    """watch_patterns handling when writing and reading the checkpoint."""

    def test_watch_patterns_in_checkpoint(self, registry):
        """A running session's watch_patterns appear in the written checkpoint."""
        watched = _make_session(watch_patterns=["ERROR", "FAIL"])
        with registry._lock:
            registry._running[watched.id] = watched

        with patch("utils.atomic_json_write") as write_mock:
            registry._write_checkpoint()
            positional, _ = write_mock.call_args
            entries = positional[1]  # second positional arg
            assert len(entries) == 1
            assert entries[0]["watch_patterns"] == ["ERROR", "FAIL"]

    def test_watch_patterns_recovery(self, registry, tmp_path, monkeypatch):
        """A checkpoint entry carrying watch_patterns is read without crashing.

        The entry's PID does not exist, so nothing is recovered; the test
        only checks that the recovery path runs and reports zero sessions.
        """
        import tools.process_registry as pr_mod

        entry = {
            "session_id": "proc_recovered",
            "command": "tail -f log",
            "pid": 99999999,  # non-existent
            "pid_scope": "host",
            "started_at": time.time(),
            "task_id": "",
            "session_key": "",
            "watcher_platform": "",
            "watcher_chat_id": "",
            "watcher_thread_id": "",
            "watcher_interval": 0,
            "notify_on_complete": False,
            "watch_patterns": ["PANIC", "OOM"],
        }
        checkpoint = tmp_path / "processes.json"
        checkpoint.write_text(json.dumps([entry]))
        monkeypatch.setattr(pr_mod, "CHECKPOINT_PATH", checkpoint)

        recovered = registry.recover_from_checkpoint()

        # The fake PID means no session comes back.
        assert recovered == 0
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Terminal tool schema + handler
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestTerminalToolSchema:
    """watch_patterns in the terminal tool schema and its handler."""

    def test_schema_includes_watch_patterns(self):
        """The schema declares watch_patterns as an array of strings."""
        from tools.terminal_tool import TERMINAL_SCHEMA

        schema_props = TERMINAL_SCHEMA["parameters"]["properties"]
        assert "watch_patterns" in schema_props

        watch_spec = schema_props["watch_patterns"]
        assert watch_spec["type"] == "array"
        assert watch_spec["items"] == {"type": "string"}

    def test_handler_passes_watch_patterns(self):
        """_handle_terminal forwards watch_patterns to terminal_tool as a keyword."""
        from tools.terminal_tool import _handle_terminal

        tool_args = {"command": "echo hi", "watch_patterns": ["ERR"]}
        with patch("tools.terminal_tool.terminal_tool") as terminal_mock:
            terminal_mock.return_value = json.dumps({"output": "ok", "exit_code": 0})
            _handle_terminal(tool_args, task_id="t1")
            forwarded = terminal_mock.call_args[1]
            assert forwarded.get("watch_patterns") == ["ERR"]
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Code execution tool blocked params
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
class TestCodeExecutionBlocked:
    """watch_patterns in the code-execution tool's blocked terminal parameters."""

    def test_watch_patterns_blocked(self):
        """watch_patterns is listed in _TERMINAL_BLOCKED_PARAMS."""
        from tools.code_execution_tool import _TERMINAL_BLOCKED_PARAMS as blocked_params

        assert "watch_patterns" in blocked_params
|
||||||
@@ -301,7 +301,7 @@ def _call(tool_name, args):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Terminal parameters that must not be used from ephemeral sandbox scripts
|
# Terminal parameters that must not be used from ephemeral sandbox scripts
|
||||||
_TERMINAL_BLOCKED_PARAMS = {"background", "check_interval", "pty", "notify_on_complete"}
|
_TERMINAL_BLOCKED_PARAMS = {"background", "check_interval", "pty", "notify_on_complete", "watch_patterns"}
|
||||||
|
|
||||||
|
|
||||||
def _rpc_server_loop(
|
def _rpc_server_loop(
|
||||||
|
|||||||
@@ -58,6 +58,11 @@ MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
|||||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||||
|
|
||||||
|
# Watch pattern rate limiting
|
||||||
|
WATCH_MAX_PER_WINDOW = 8 # Max notifications delivered per window
|
||||||
|
WATCH_WINDOW_SECONDS = 10 # Length in seconds of the fixed rate-limit window (hit counter restarts once it expires)
|
||||||
|
WATCH_OVERLOAD_KILL_SECONDS = 45 # Sustained overload duration before disabling watch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessSession:
|
class ProcessSession:
|
||||||
@@ -83,6 +88,14 @@ class ProcessSession:
|
|||||||
watcher_thread_id: str = ""
|
watcher_thread_id: str = ""
|
||||||
watcher_interval: int = 0 # 0 = no watcher configured
|
watcher_interval: int = 0 # 0 = no watcher configured
|
||||||
notify_on_complete: bool = False # Queue agent notification on exit
|
notify_on_complete: bool = False # Queue agent notification on exit
|
||||||
|
# Watch patterns — trigger agent notification when output matches any pattern
|
||||||
|
watch_patterns: List[str] = field(default_factory=list)
|
||||||
|
_watch_hits: int = field(default=0, repr=False) # total matches delivered
|
||||||
|
_watch_suppressed: int = field(default=0, repr=False) # matches dropped by rate limit
|
||||||
|
_watch_overload_since: float = field(default=0.0, repr=False) # when sustained overload began
|
||||||
|
_watch_disabled: bool = field(default=False, repr=False) # permanently killed by overload
|
||||||
|
_watch_window_hits: int = field(default=0, repr=False) # hits in current rate window
|
||||||
|
_watch_window_start: float = field(default=0.0, repr=False)
|
||||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||||
@@ -114,9 +127,10 @@ class ProcessRegistry:
|
|||||||
# Side-channel for check_interval watchers (gateway reads after agent run)
|
# Side-channel for check_interval watchers (gateway reads after agent run)
|
||||||
self.pending_watchers: List[Dict[str, Any]] = []
|
self.pending_watchers: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
# Completion notifications — processes with notify_on_complete push here
|
# Notification queue — unified queue for all background process events.
|
||||||
# on exit. CLI process_loop and gateway drain this after each agent turn
|
# Completion notifications (notify_on_complete) and watch pattern matches
|
||||||
# to auto-trigger a new agent turn with the process results.
|
# both land here, distinguished by "type" field. CLI process_loop and
|
||||||
|
# gateway drain this after each agent turn to auto-trigger new turns.
|
||||||
import queue as _queue_mod
|
import queue as _queue_mod
|
||||||
self.completion_queue: _queue_mod.Queue = _queue_mod.Queue()
|
self.completion_queue: _queue_mod.Queue = _queue_mod.Queue()
|
||||||
|
|
||||||
@@ -128,6 +142,84 @@ class ProcessRegistry:
|
|||||||
lines.pop(0)
|
lines.pop(0)
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _check_watch_patterns(self, session: ProcessSession, new_text: str) -> None:
|
||||||
|
"""Scan new output for watch patterns and queue notifications.
|
||||||
|
|
||||||
|
Called from reader threads with new_text being the freshly-read chunk.
|
||||||
|
Rate-limited: max WATCH_MAX_PER_WINDOW notifications per WATCH_WINDOW_SECONDS.
|
||||||
|
If sustained overload exceeds WATCH_OVERLOAD_KILL_SECONDS, watching is
|
||||||
|
disabled permanently for this process.
|
||||||
|
"""
|
||||||
|
if not session.watch_patterns or session._watch_disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Scan new text line-by-line for pattern matches
|
||||||
|
matched_lines = []
|
||||||
|
matched_pattern = None
|
||||||
|
for line in new_text.splitlines():
|
||||||
|
for pat in session.watch_patterns:
|
||||||
|
if pat in line:
|
||||||
|
matched_lines.append(line.rstrip())
|
||||||
|
if matched_pattern is None:
|
||||||
|
matched_pattern = pat
|
||||||
|
break # one match per line is enough
|
||||||
|
|
||||||
|
if not matched_lines:
|
||||||
|
return
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
with session._lock:
|
||||||
|
# Reset window if it's expired
|
||||||
|
if now - session._watch_window_start >= WATCH_WINDOW_SECONDS:
|
||||||
|
session._watch_window_hits = 0
|
||||||
|
session._watch_window_start = now
|
||||||
|
|
||||||
|
# Check rate limit
|
||||||
|
if session._watch_window_hits >= WATCH_MAX_PER_WINDOW:
|
||||||
|
session._watch_suppressed += len(matched_lines)
|
||||||
|
|
||||||
|
# Track sustained overload for kill switch
|
||||||
|
if session._watch_overload_since == 0.0:
|
||||||
|
session._watch_overload_since = now
|
||||||
|
elif now - session._watch_overload_since > WATCH_OVERLOAD_KILL_SECONDS:
|
||||||
|
session._watch_disabled = True
|
||||||
|
self.completion_queue.put({
|
||||||
|
"session_id": session.id,
|
||||||
|
"command": session.command,
|
||||||
|
"type": "watch_disabled",
|
||||||
|
"suppressed": session._watch_suppressed,
|
||||||
|
"message": (
|
||||||
|
f"Watch patterns disabled for process {session.id} — "
|
||||||
|
f"too many matches ({session._watch_suppressed} suppressed). "
|
||||||
|
f"Use process(action='poll') to check output manually."
|
||||||
|
),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Under the rate limit — deliver notification
|
||||||
|
session._watch_window_hits += 1
|
||||||
|
session._watch_hits += 1
|
||||||
|
# Clear overload tracker since we got a delivery through
|
||||||
|
session._watch_overload_since = 0.0
|
||||||
|
|
||||||
|
# Include suppressed count if any events were dropped
|
||||||
|
suppressed = session._watch_suppressed
|
||||||
|
session._watch_suppressed = 0
|
||||||
|
|
||||||
|
# Trim matched output to a reasonable size
|
||||||
|
output = "\n".join(matched_lines[:20])
|
||||||
|
if len(output) > 2000:
|
||||||
|
output = output[:2000] + "\n...(truncated)"
|
||||||
|
|
||||||
|
self.completion_queue.put({
|
||||||
|
"session_id": session.id,
|
||||||
|
"command": session.command,
|
||||||
|
"type": "watch_match",
|
||||||
|
"pattern": matched_pattern,
|
||||||
|
"output": output,
|
||||||
|
"suppressed": suppressed,
|
||||||
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_host_pid_alive(pid: Optional[int]) -> bool:
|
def _is_host_pid_alive(pid: Optional[int]) -> bool:
|
||||||
"""Best-effort liveness check for host-visible PIDs."""
|
"""Best-effort liveness check for host-visible PIDs."""
|
||||||
@@ -394,6 +486,7 @@ class ProcessRegistry:
|
|||||||
session.output_buffer += chunk
|
session.output_buffer += chunk
|
||||||
if len(session.output_buffer) > session.max_output_chars:
|
if len(session.output_buffer) > session.max_output_chars:
|
||||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||||
|
self._check_watch_patterns(session, chunk)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Process stdout reader ended: %s", e)
|
logger.debug("Process stdout reader ended: %s", e)
|
||||||
finally:
|
finally:
|
||||||
@@ -413,6 +506,7 @@ class ProcessRegistry:
|
|||||||
quoted_log_path = shlex.quote(log_path)
|
quoted_log_path = shlex.quote(log_path)
|
||||||
quoted_pid_path = shlex.quote(pid_path)
|
quoted_pid_path = shlex.quote(pid_path)
|
||||||
quoted_exit_path = shlex.quote(exit_path)
|
quoted_exit_path = shlex.quote(exit_path)
|
||||||
|
prev_output_len = 0 # track delta for watch pattern scanning
|
||||||
while not session.exited:
|
while not session.exited:
|
||||||
time.sleep(2) # Poll every 2 seconds
|
time.sleep(2) # Poll every 2 seconds
|
||||||
try:
|
try:
|
||||||
@@ -420,10 +514,15 @@ class ProcessRegistry:
|
|||||||
result = env.execute(f"cat {quoted_log_path} 2>/dev/null", timeout=10)
|
result = env.execute(f"cat {quoted_log_path} 2>/dev/null", timeout=10)
|
||||||
new_output = result.get("output", "")
|
new_output = result.get("output", "")
|
||||||
if new_output:
|
if new_output:
|
||||||
|
# Compute delta for watch pattern scanning
|
||||||
|
delta = new_output[prev_output_len:] if len(new_output) > prev_output_len else ""
|
||||||
|
prev_output_len = len(new_output)
|
||||||
with session._lock:
|
with session._lock:
|
||||||
session.output_buffer = new_output
|
session.output_buffer = new_output
|
||||||
if len(session.output_buffer) > session.max_output_chars:
|
if len(session.output_buffer) > session.max_output_chars:
|
||||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||||
|
if delta:
|
||||||
|
self._check_watch_patterns(session, delta)
|
||||||
|
|
||||||
# Check if process is still running
|
# Check if process is still running
|
||||||
check = env.execute(
|
check = env.execute(
|
||||||
@@ -467,6 +566,7 @@ class ProcessRegistry:
|
|||||||
session.output_buffer += text
|
session.output_buffer += text
|
||||||
if len(session.output_buffer) > session.max_output_chars:
|
if len(session.output_buffer) > session.max_output_chars:
|
||||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||||
|
self._check_watch_patterns(session, text)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -502,6 +602,7 @@ class ProcessRegistry:
|
|||||||
from tools.ansi_strip import strip_ansi
|
from tools.ansi_strip import strip_ansi
|
||||||
output_tail = strip_ansi(session.output_buffer[-2000:]) if session.output_buffer else ""
|
output_tail = strip_ansi(session.output_buffer[-2000:]) if session.output_buffer else ""
|
||||||
self.completion_queue.put({
|
self.completion_queue.put({
|
||||||
|
"type": "completion",
|
||||||
"session_id": session.id,
|
"session_id": session.id,
|
||||||
"command": session.command,
|
"command": session.command,
|
||||||
"exit_code": session.exit_code,
|
"exit_code": session.exit_code,
|
||||||
@@ -872,6 +973,7 @@ class ProcessRegistry:
|
|||||||
"watcher_thread_id": s.watcher_thread_id,
|
"watcher_thread_id": s.watcher_thread_id,
|
||||||
"watcher_interval": s.watcher_interval,
|
"watcher_interval": s.watcher_interval,
|
||||||
"notify_on_complete": s.notify_on_complete,
|
"notify_on_complete": s.notify_on_complete,
|
||||||
|
"watch_patterns": s.watch_patterns,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Atomic write to avoid corruption on crash
|
# Atomic write to avoid corruption on crash
|
||||||
@@ -932,6 +1034,7 @@ class ProcessRegistry:
|
|||||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||||
watcher_interval=entry.get("watcher_interval", 0),
|
watcher_interval=entry.get("watcher_interval", 0),
|
||||||
notify_on_complete=entry.get("notify_on_complete", False),
|
notify_on_complete=entry.get("notify_on_complete", False),
|
||||||
|
watch_patterns=entry.get("watch_patterns", []),
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._running[session.id] = session
|
self._running[session.id] = session
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ import atexit
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any, List
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1140,6 +1140,7 @@ def terminal_tool(
|
|||||||
check_interval: Optional[int] = None,
|
check_interval: Optional[int] = None,
|
||||||
pty: bool = False,
|
pty: bool = False,
|
||||||
notify_on_complete: bool = False,
|
notify_on_complete: bool = False,
|
||||||
|
watch_patterns: Optional[List[str]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Execute a command in the configured terminal environment.
|
Execute a command in the configured terminal environment.
|
||||||
@@ -1154,6 +1155,7 @@ def terminal_tool(
|
|||||||
check_interval: Seconds between auto-checks for background processes (gateway only, min 30)
|
check_interval: Seconds between auto-checks for background processes (gateway only, min 30)
|
||||||
pty: If True, use pseudo-terminal for interactive CLI tools (local backend only)
|
pty: If True, use pseudo-terminal for interactive CLI tools (local backend only)
|
||||||
notify_on_complete: If True and background=True, auto-notify the agent when the process exits
|
notify_on_complete: If True and background=True, auto-notify the agent when the process exits
|
||||||
|
watch_patterns: List of strings to watch for in background output; triggers notification on match
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: JSON string with output, exit_code, and error fields
|
str: JSON string with output, exit_code, and error fields
|
||||||
@@ -1439,6 +1441,11 @@ def terminal_tool(
|
|||||||
"notify_on_complete": True,
|
"notify_on_complete": True,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Set watch patterns for output monitoring
|
||||||
|
if watch_patterns and background:
|
||||||
|
proc_session.watch_patterns = list(watch_patterns)
|
||||||
|
result_data["watch_patterns"] = proc_session.watch_patterns
|
||||||
|
|
||||||
# Register check_interval watcher (gateway picks this up after agent run)
|
# Register check_interval watcher (gateway picks this up after agent run)
|
||||||
if check_interval and background:
|
if check_interval and background:
|
||||||
effective_interval = max(30, check_interval)
|
effective_interval = max(30, check_interval)
|
||||||
@@ -1762,6 +1769,11 @@ TERMINAL_SCHEMA = {
|
|||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "When true (and background=true), you'll be automatically notified when the process finishes — no polling needed. Use this for tasks that take a while (tests, builds, deployments) so you can keep working on other things in the meantime.",
|
"description": "When true (and background=true), you'll be automatically notified when the process finishes — no polling needed. Use this for tasks that take a while (tests, builds, deployments) so you can keep working on other things in the meantime.",
|
||||||
"default": False
|
"default": False
|
||||||
|
},
|
||||||
|
"watch_patterns": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "List of strings to watch for in background process output. When any pattern matches a line of output, you'll be notified with the matching text — like notify_on_complete but triggers mid-process on specific output. Use for monitoring logs, watching for errors, or waiting for specific events (e.g. [\"ERROR\", \"FAIL\", \"listening on port\"])."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["command"]
|
"required": ["command"]
|
||||||
@@ -1779,6 +1791,7 @@ def _handle_terminal(args, **kw):
|
|||||||
check_interval=args.get("check_interval"),
|
check_interval=args.get("check_interval"),
|
||||||
pty=args.get("pty", False),
|
pty=args.get("pty", False),
|
||||||
notify_on_complete=args.get("notify_on_complete", False),
|
notify_on_complete=args.get("notify_on_complete", False),
|
||||||
|
watch_patterns=args.get("watch_patterns"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user