mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
Add the ability to interrupt a running agent via the runs API. Previously, /v1/runs could start a run and subscribe to its events, but there was no way to cancel it. The new POST /v1/runs/{run_id}/stop endpoint uses the agent and task references stored during execution: it calls agent.interrupt() to stop LLM calls, then cancels the asyncio task. Includes 15 tests covering start, events, and stop scenarios.
366 lines · 14 KiB · Python
"""Tests for /v1/runs endpoints: start, events, and stop.
|
|
|
|
Covers:
|
|
- POST /v1/runs — start a run (202)
|
|
- GET /v1/runs/{run_id}/events — SSE event stream
|
|
- POST /v1/runs/{run_id}/stop — interrupt a running agent
|
|
- Auth, error handling, and cleanup
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import threading
|
|
import time as _time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
from gateway.config import PlatformConfig
|
|
from gateway.platforms.api_server import (
|
|
APIServerAdapter,
|
|
cors_middleware,
|
|
security_headers_middleware,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_adapter(api_key: str = "") -> APIServerAdapter:
    """Return a fresh APIServerAdapter for a test.

    A non-empty *api_key* is placed in the PlatformConfig extras under
    ``"key"``; an empty string leaves the extras dict empty.
    """
    extra = {"key": api_key} if api_key else {}
    return APIServerAdapter(PlatformConfig(enabled=True, extra=extra))
|
|
|
|
|
|
def _create_runs_app(adapter: APIServerAdapter) -> web.Application:
    """Build an aiohttp app exposing only the /v1/runs routes of *adapter*.

    Any middleware that is ``None`` is skipped. The adapter is also stored
    on the app under the ``"api_server_adapter"`` key.
    """
    middlewares = [
        middleware
        for middleware in (cors_middleware, security_headers_middleware)
        if middleware is not None
    ]
    app = web.Application(middlewares=middlewares)
    app["api_server_adapter"] = adapter

    router = app.router
    router.add_post("/v1/runs", adapter._handle_runs)
    router.add_get("/v1/runs/{run_id}/events", adapter._handle_run_events)
    router.add_post("/v1/runs/{run_id}/stop", adapter._handle_stop_run)
    return app
|
|
|
|
|
|
def _make_slow_agent(**kwargs):
|
|
"""Create a mock agent that blocks in run_conversation until interrupted.
|
|
|
|
Returns (mock_agent, agent_ready_event, interrupt_event) where
|
|
agent_ready_event is set once run_conversation starts, and
|
|
interrupt_event is set when interrupt() is called.
|
|
"""
|
|
ready = threading.Event()
|
|
interrupted = threading.Event()
|
|
|
|
mock_agent = MagicMock()
|
|
|
|
def _do_interrupt(message=None):
|
|
interrupted.set()
|
|
|
|
mock_agent.interrupt = MagicMock(side_effect=_do_interrupt)
|
|
|
|
def _slow_run(user_message=None, conversation_history=None, task_id=None):
|
|
ready.set()
|
|
# Block until interrupt() is called
|
|
interrupted.wait(timeout=10)
|
|
return {"final_response": "interrupted"}
|
|
|
|
mock_agent.run_conversation.side_effect = _slow_run
|
|
mock_agent.session_prompt_tokens = 0
|
|
mock_agent.session_completion_tokens = 0
|
|
mock_agent.session_total_tokens = 0
|
|
|
|
return mock_agent, ready, interrupted
|
|
|
|
|
|
@pytest.fixture
def adapter():
    """Adapter built with an empty API key (no ``key`` in its config extras)."""
    return _make_adapter(api_key="")
|
|
|
|
|
|
@pytest.fixture
def auth_adapter():
    """Adapter configured with the API key ``sk-secret``."""
    secret = "sk-secret"
    return _make_adapter(api_key=secret)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /v1/runs — start a run
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStartRun:
    """POST /v1/runs: input validation, auth, and the 202 "started" reply."""

    @staticmethod
    def _finished_agent(final_response, prompt=0, completion=0, total=0):
        """Return a mock agent whose run_conversation returns immediately."""
        agent = MagicMock()
        agent.run_conversation.return_value = {"final_response": final_response}
        agent.session_prompt_tokens = prompt
        agent.session_completion_tokens = completion
        agent.session_total_tokens = total
        return agent

    @pytest.mark.asyncio
    async def test_start_returns_202(self, adapter):
        async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_create.return_value = self._finished_agent(
                    "done", prompt=10, completion=5, total=15
                )

                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                payload = await resp.json()
                assert payload["status"] == "started"
                assert payload["run_id"].startswith("run_")

    @pytest.mark.asyncio
    async def test_start_invalid_json_returns_400(self, adapter):
        async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
            resp = await cli.post(
                "/v1/runs",
                data="not json",
                headers={"Content-Type": "application/json"},
            )
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_start_missing_input_returns_400(self, adapter):
        async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
            resp = await cli.post("/v1/runs", json={"model": "test"})
            assert resp.status == 400
            error_message = (await resp.json())["error"]["message"]
            assert "input" in error_message

    @pytest.mark.asyncio
    async def test_start_empty_input_returns_400(self, adapter):
        async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
            resp = await cli.post("/v1/runs", json={"input": ""})
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_start_requires_auth(self, auth_adapter):
        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as cli:
            resp = await cli.post("/v1/runs", json={"input": "hello"})
            assert resp.status == 401

    @pytest.mark.asyncio
    async def test_start_with_valid_auth(self, auth_adapter):
        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as cli:
            with patch.object(auth_adapter, "_create_agent") as mock_create:
                mock_create.return_value = self._finished_agent("ok")

                resp = await cli.post(
                    "/v1/runs",
                    json={"input": "hello"},
                    headers={"Authorization": "Bearer sk-secret"},
                )
                assert resp.status == 202
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /v1/runs/{run_id}/events — SSE event stream
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRunEvents:
    """GET /v1/runs/{run_id}/events: SSE delivery, unknown runs, and auth."""

    @pytest.mark.asyncio
    async def test_events_stream_returns_completed(self, adapter):
        """Events stream should receive run.completed when agent finishes."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                finished = MagicMock()
                finished.run_conversation.return_value = {"final_response": "Hello!"}
                (
                    finished.session_prompt_tokens,
                    finished.session_completion_tokens,
                    finished.session_total_tokens,
                ) = (10, 5, 15)
                mock_create.return_value = finished

                start = await cli.post("/v1/runs", json={"input": "hello"})
                assert start.status == 202
                run_id = (await start.json())["run_id"]

                stream = await cli.get(f"/v1/runs/{run_id}/events")
                assert stream.status == 200
                sse_text = await stream.text()

                # Both the terminal event name and the agent's final text
                # must appear in the stream body.
                for expected in ("run.completed", "Hello!"):
                    assert expected in sse_text

    @pytest.mark.asyncio
    async def test_events_not_found_returns_404(self, adapter):
        async with TestClient(TestServer(_create_runs_app(adapter))) as cli:
            resp = await cli.get("/v1/runs/run_nonexistent/events")
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_events_requires_auth(self, auth_adapter):
        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as cli:
            resp = await cli.get("/v1/runs/run_any/events")
            assert resp.status == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /v1/runs/{run_id}/stop — interrupt a running agent
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStopRun:
    """POST /v1/runs/{run_id}/stop: interrupting, cancelling, and cleanup."""

    @staticmethod
    async def _wait_until(predicate, timeout: float = 3.0, interval: float = 0.01) -> bool:
        """Poll *predicate* until it is truthy or *timeout* seconds elapse.

        Sleeps with ``asyncio.sleep`` between checks, so the event loop keeps
        serving the test server and the run's background task while we wait.
        A bare ``threading.Event.wait`` here would block the loop thread.

        Returns:
            True if the predicate became truthy before the deadline, else False.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(interval)
        return True

    async def _start_blocked_run(self, cli, adapter, agent_ready) -> str:
        """Start a run, then wait until its agent is running and registered.

        Returns:
            The ``run_id`` reported by POST /v1/runs.
        """
        resp = await cli.post("/v1/runs", json={"input": "hello"})
        assert resp.status == 202
        run_id = (await resp.json())["run_id"]

        # Wait without blocking the loop: first for run_conversation to start
        # in its thread, then for the adapter to expose the agent reference.
        assert await self._wait_until(agent_ready.is_set), "agent never started"
        assert await self._wait_until(
            lambda: run_id in adapter._active_run_agents
        ), "agent ref was never stored"
        return run_id

    @pytest.mark.asyncio
    async def test_stop_running_agent(self, adapter):
        """Stop should interrupt the agent and cancel the task."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, _ = _make_slow_agent()
                mock_create.return_value = mock_agent

                run_id = await self._start_blocked_run(cli, adapter, agent_ready)

                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200
                stop_data = await stop_resp.json()
                assert stop_data["run_id"] == run_id
                assert stop_data["status"] == "stopping"

                mock_agent.interrupt.assert_called_once_with("Stop requested via API")

                # Ref cleanup may complete after the stop response is sent;
                # poll with a deadline instead of sleeping a fixed amount.
                await self._wait_until(
                    lambda: run_id not in adapter._active_run_agents
                    and run_id not in adapter._active_run_tasks
                )
                assert run_id not in adapter._active_run_agents
                assert run_id not in adapter._active_run_tasks

    @pytest.mark.asyncio
    async def test_stop_nonexistent_run_returns_404(self, adapter):
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/runs/run_nonexistent/stop")
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_stop_requires_auth(self, auth_adapter):
        app = _create_runs_app(auth_adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/runs/run_any/stop")
            assert resp.status == 401

    @pytest.mark.asyncio
    async def test_stop_already_completed_run_returns_404(self, adapter):
        """Stopping a run that already finished should return 404 (refs cleaned up)."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent = MagicMock()
                mock_agent.run_conversation.return_value = {"final_response": "done"}
                mock_agent.session_prompt_tokens = 0
                mock_agent.session_completion_tokens = 0
                mock_agent.session_total_tokens = 0
                mock_create.return_value = mock_agent

                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                run_id = (await resp.json())["run_id"]

                # Require run_conversation to have been called, so the check
                # cannot pass before the run has even registered its agent.
                assert await self._wait_until(
                    lambda: mock_agent.run_conversation.called
                    and run_id not in adapter._active_run_agents
                ), "run did not finish and clean up its agent ref"

                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 404

    @pytest.mark.asyncio
    async def test_stop_interrupt_exception_does_not_crash(self, adapter):
        """If agent.interrupt() raises, stop should still succeed."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, interrupted = _make_slow_agent()

                def _failing_interrupt(message=None):
                    # Release the blocked run_conversation thread before
                    # failing. Otherwise it lingers for the mock's full
                    # 10 s wait after this test has finished.
                    interrupted.set()
                    raise RuntimeError("interrupt failed")

                mock_agent.interrupt = MagicMock(side_effect=_failing_interrupt)
                mock_create.return_value = mock_agent

                run_id = await self._start_blocked_run(cli, adapter, agent_ready)

                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200
                stop_data = await stop_resp.json()
                assert stop_data["status"] == "stopping"
                mock_agent.interrupt.assert_called_once_with("Stop requested via API")

    @pytest.mark.asyncio
    async def test_stop_sends_sentinel_to_events_stream(self, adapter):
        """After stop, the events stream should close."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, _ = _make_slow_agent()
                mock_create.return_value = mock_agent

                run_id = await self._start_blocked_run(cli, adapter, agent_ready)

                # Subscribe in the background. cli.get() returns an awaitable
                # request context manager, not a coroutine, hence ensure_future.
                events_task = asyncio.ensure_future(
                    cli.get(f"/v1/runs/{run_id}/events")
                )
                # Give the subscription a moment to attach before stopping.
                await asyncio.sleep(0.1)

                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200

                # The events stream should terminate once the run is stopped.
                events_resp = await asyncio.wait_for(events_task, timeout=5.0)
                assert events_resp.status == 200
                body = await events_resp.text()
                assert "run.failed" in body or "stream closed" in body
|