mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 15:01:34 +08:00
192 lines
6.8 KiB
Python
192 lines
6.8 KiB
Python
|
|
"""Tests for the gateway /steer command handler.
|
||
|
|
|
||
|
|
/steer injects a user message into the agent's next tool result without
interrupting. The gateway runner must:

  1. When an agent IS running → call ``agent.steer(text)``, do NOT set
     ``_interrupt_requested``, do NOT touch ``_pending_messages``.
  2. When the agent is the PENDING sentinel → fall back to /queue
     semantics (store in ``adapter._pending_messages``).
  3. When no agent is active → strip the slash prefix and let the normal
     prompt pipeline handle it as a regular user message.
|
||
|
|
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from datetime import datetime
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from unittest.mock import AsyncMock, MagicMock
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||
|
|
from gateway.platforms.base import MessageEvent
|
||
|
|
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||
|
|
|
||
|
|
|
||
|
|
def _make_source() -> SessionSource:
    """Build the fixed Telegram DM ``SessionSource`` used by these tests."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**source_fields)
|
||
|
|
|
||
|
|
|
||
|
|
def _make_event(text: str) -> MessageEvent:
    """Wrap *text* in a ``MessageEvent`` coming from the shared test source.

    Args:
        text: Raw message text, including any slash-command prefix.

    Returns:
        A ``MessageEvent`` with message id ``"m1"`` and the source produced
        by ``_make_source()``.
    """
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
|
||
|
|
|
||
|
|
|
||
|
|
def _make_runner(session_entry: SessionEntry):
    """Assemble a ``GatewayRunner`` populated with mocks instead of real services.

    The runner is allocated with ``object.__new__`` so ``__init__`` never
    runs; only the attributes assigned below exist on the instance.

    Args:
        session_entry: Entry returned by the mocked
            ``session_store.get_or_create_session``.

    Returns:
        A ``(runner, adapter)`` tuple: the stubbed runner and the Telegram
        adapter mock registered in ``runner.adapters``.
    """
    from gateway.run import GatewayRunner

    # Telegram adapter double: awaitable ``send`` and a real dict for the
    # pending queue so tests can inspect what was stored in it.
    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()
    telegram_adapter._pending_messages = {}

    # Session store double returning canned values.
    store = MagicMock()
    store.get_or_create_session.return_value = session_entry
    store.load_transcript.return_value = []
    store.has_any_sessions.return_value = True

    # Session DB double that reports no stored title.
    session_db = MagicMock()
    session_db.get_session_title.return_value = None

    gateway_config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )

    stubbed_attrs = {
        "config": gateway_config,
        "adapters": {Platform.TELEGRAM: telegram_adapter},
        "_voice_mode": {},
        "hooks": SimpleNamespace(emit=AsyncMock(), loaded_hooks=False),
        "session_store": store,
        "_running_agents": {},
        "_running_agents_ts": {},
        "_pending_messages": {},
        "_pending_approvals": {},
        "_session_db": session_db,
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_show_reasoning": False,
        "_is_user_authorized": lambda _source: True,
        "_set_session_env": lambda _context: None,
        "_should_send_voice_reply": lambda *_args, **_kwargs: False,
        "_send_voice_reply": AsyncMock(),
        "_capture_gateway_honcho_if_configured": lambda *args, **kwargs: None,
        "_emit_gateway_run_progress": AsyncMock(),
    }

    runner = object.__new__(GatewayRunner)
    for attr_name, stub in stubbed_attrs.items():
        setattr(runner, attr_name, stub)

    return runner, telegram_adapter
|
||
|
|
|
||
|
|
|
||
|
|
def _session_entry() -> SessionEntry:
    """Return a fresh ``SessionEntry`` keyed to the shared test source.

    The entry uses session id ``"sess-1"``, the Telegram platform, a DM chat
    type, zero tokens, and the current local time for both timestamps.
    """
    key = build_session_key(_make_source())
    return SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
        total_tokens=0,
    )
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_steer_calls_agent_steer_and_does_not_interrupt():
    """With a running agent, /steer forwards the payload to ``agent.steer``.

    The interrupt path and both pending-message queues must stay untouched.
    """
    runner, adapter = _make_runner(_session_entry())
    session_key = build_session_key(_make_source())

    agent = MagicMock()
    agent.steer.return_value = True
    runner._running_agents[session_key] = agent

    event = _make_event("/steer also check auth.log")
    reply = await runner._handle_message(event)

    # The handler answered with some confirmation text.
    assert reply is not None
    lowered = reply.lower()
    assert "steer" in lowered or "queued" in lowered

    # steer() received the payload with the slash prefix stripped.
    agent.steer.assert_called_once_with("also check auth.log")

    # Critically, the agent was never interrupted.
    agent.interrupt.assert_not_called()

    # Nothing was queued: landing in _pending_messages would mean
    # turn-boundary /queue semantics rather than a steer.
    assert runner._pending_messages == {}
    assert adapter._pending_messages == {}
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_steer_without_payload_returns_usage():
    """A bare ``/steer`` with a running agent replies with usage text.

    Neither ``steer()`` nor ``interrupt()`` may be called on the agent.
    """
    runner, _ = _make_runner(_session_entry())
    session_key = build_session_key(_make_source())
    agent = MagicMock()
    runner._running_agents[session_key] = agent

    event = _make_event("/steer")
    reply = await runner._handle_message(event)

    assert reply is not None
    # Accept either capitalisation of the usage hint, exactly as spelled.
    assert any(word in reply for word in ("Usage", "usage"))
    agent.steer.assert_not_called()
    agent.interrupt.assert_not_called()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_steer_with_pending_sentinel_falls_back_to_queue():
    """While the agent slot holds the pending sentinel, /steer is queued.

    The payload must end up in the adapter's pending queue as a
    turn-boundary follow-up instead of the handler crashing.
    """
    from gateway.run import _AGENT_PENDING_SENTINEL

    runner, adapter = _make_runner(_session_entry())
    session_key = build_session_key(_make_source())
    runner._running_agents[session_key] = _AGENT_PENDING_SENTINEL

    event = _make_event("/steer wait up")
    reply = await runner._handle_message(event)

    assert reply is not None
    lowered = reply.lower()
    assert "queued" in lowered or "starting" in lowered

    # The fallback stored the prefix-stripped text in the adapter's queue.
    pending = adapter._pending_messages
    assert session_key in pending
    assert pending[session_key].text == "wait up"
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_steer_agent_without_steer_method_falls_back():
    """An agent lacking ``steer()`` must make the handler fall back to /queue.

    This covers older builds or test stubs: the handler must not raise and
    must store the payload in the adapter's pending queue.
    """
    runner, adapter = _make_runner(_session_entry())
    session_key = build_session_key(_make_source())

    # spec=[] yields a mock with no attributes at all, so
    # hasattr(agent, "steer") is False.
    steerless_agent = MagicMock(spec=[])
    runner._running_agents[session_key] = steerless_agent

    event = _make_event("/steer fallback")
    reply = await runner._handle_message(event)

    assert reply is not None
    # The reply has to mention queueing because steer was unavailable.
    assert "queued" in reply.lower()

    pending = adapter._pending_messages
    assert session_key in pending
    assert pending[session_key].text == "fallback"
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_steer_rejected_payload_returns_rejection_message():
    """When ``agent.steer()`` returns False the reply reports the rejection.

    (For example, text that is empty after stripping, although the gateway
    already guards that case.)
    """
    runner, _ = _make_runner(_session_entry())
    session_key = build_session_key(_make_source())

    rejecting_agent = MagicMock()
    rejecting_agent.steer.return_value = False
    runner._running_agents[session_key] = rejecting_agent

    event = _make_event("/steer hello")
    reply = await runner._handle_message(event)

    assert reply is not None
    lowered = reply.lower()
    assert "rejected" in lowered or "empty" in lowered
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":  # pragma: no cover
    # Allow running this test module directly, in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
|