mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
366 lines
14 KiB
Python
366 lines
14 KiB
Python
|
|
"""Tests for /v1/runs endpoints: start, events, and stop.
|
||
|
|
|
||
|
|
Covers:
|
||
|
|
- POST /v1/runs — start a run (202)
|
||
|
|
- GET /v1/runs/{run_id}/events — SSE event stream
|
||
|
|
- POST /v1/runs/{run_id}/stop — interrupt a running agent
|
||
|
|
- Auth, error handling, and cleanup
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import threading
|
||
|
|
import time as _time
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from aiohttp import web
|
||
|
|
from aiohttp.test_utils import TestClient, TestServer
|
||
|
|
|
||
|
|
from gateway.config import PlatformConfig
|
||
|
|
from gateway.platforms.api_server import (
|
||
|
|
APIServerAdapter,
|
||
|
|
cors_middleware,
|
||
|
|
security_headers_middleware,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
def _make_adapter(api_key: str = "") -> APIServerAdapter:
    """Build an ``APIServerAdapter`` for the tests.

    A non-empty *api_key* is handed to the adapter through the platform
    config's ``extra["key"]`` entry; with an empty key ``extra`` stays empty.
    """
    extra = {"key": api_key} if api_key else {}
    return APIServerAdapter(PlatformConfig(enabled=True, extra=extra))
|
||
|
|
|
||
|
|
|
||
|
|
def _create_runs_app(adapter: APIServerAdapter) -> web.Application:
    """Build an aiohttp application exposing the adapter's /v1/runs handlers.

    The CORS and security-header middlewares are installed when they are not
    ``None``, and the adapter is stored under ``app["api_server_adapter"]``.
    """
    middlewares = [
        middleware
        for middleware in (cors_middleware, security_headers_middleware)
        if middleware is not None
    ]
    app = web.Application(middlewares=middlewares)
    app["api_server_adapter"] = adapter

    # (registration method, path, handler) — registered in this order.
    router = app.router
    route_table = (
        (router.add_post, "/v1/runs", adapter._handle_runs),
        (router.add_get, "/v1/runs/{run_id}/events", adapter._handle_run_events),
        (router.add_post, "/v1/runs/{run_id}/stop", adapter._handle_stop_run),
    )
    for register, path, handler in route_table:
        register(path, handler)
    return app
|
||
|
|
|
||
|
|
|
||
|
|
def _make_slow_agent(**kwargs):
|
||
|
|
"""Create a mock agent that blocks in run_conversation until interrupted.
|
||
|
|
|
||
|
|
Returns (mock_agent, agent_ready_event, interrupt_event) where
|
||
|
|
agent_ready_event is set once run_conversation starts, and
|
||
|
|
interrupt_event is set when interrupt() is called.
|
||
|
|
"""
|
||
|
|
ready = threading.Event()
|
||
|
|
interrupted = threading.Event()
|
||
|
|
|
||
|
|
mock_agent = MagicMock()
|
||
|
|
|
||
|
|
def _do_interrupt(message=None):
|
||
|
|
interrupted.set()
|
||
|
|
|
||
|
|
mock_agent.interrupt = MagicMock(side_effect=_do_interrupt)
|
||
|
|
|
||
|
|
def _slow_run(user_message=None, conversation_history=None, task_id=None):
|
||
|
|
ready.set()
|
||
|
|
# Block until interrupt() is called
|
||
|
|
interrupted.wait(timeout=10)
|
||
|
|
return {"final_response": "interrupted"}
|
||
|
|
|
||
|
|
mock_agent.run_conversation.side_effect = _slow_run
|
||
|
|
mock_agent.session_prompt_tokens = 0
|
||
|
|
mock_agent.session_completion_tokens = 0
|
||
|
|
mock_agent.session_total_tokens = 0
|
||
|
|
|
||
|
|
return mock_agent, ready, interrupted
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def adapter() -> APIServerAdapter:
    """Adapter built without an API key."""
    return _make_adapter(api_key="")
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def auth_adapter() -> APIServerAdapter:
    """Adapter built with the API key ``sk-secret``."""
    secret = "sk-secret"
    return _make_adapter(api_key=secret)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# POST /v1/runs — start a run
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestStartRun:
    """POST /v1/runs: starting a run, request validation, and auth."""

    @pytest.mark.asyncio
    async def test_start_returns_202(self, adapter):
        """A valid request is accepted with 202 and a ``run_``-prefixed id."""
        agent = MagicMock(
            session_prompt_tokens=10,
            session_completion_tokens=5,
            session_total_tokens=15,
        )
        agent.run_conversation.return_value = {"final_response": "done"}

        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            with patch.object(adapter, "_create_agent", return_value=agent):
                response = await client.post("/v1/runs", json={"input": "hello"})
                assert response.status == 202
                payload = await response.json()
                assert payload["status"] == "started"
                assert payload["run_id"].startswith("run_")

    @pytest.mark.asyncio
    async def test_start_invalid_json_returns_400(self, adapter):
        """A body that is not valid JSON is rejected with 400."""
        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            response = await client.post(
                "/v1/runs",
                data="not json",
                headers={"Content-Type": "application/json"},
            )
            assert response.status == 400

    @pytest.mark.asyncio
    async def test_start_missing_input_returns_400(self, adapter):
        """A body without ``input`` is rejected with 400 and names the field."""
        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            response = await client.post("/v1/runs", json={"model": "test"})
            assert response.status == 400
            payload = await response.json()
            assert "input" in payload["error"]["message"]

    @pytest.mark.asyncio
    async def test_start_empty_input_returns_400(self, adapter):
        """An empty ``input`` string is rejected with 400."""
        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            response = await client.post("/v1/runs", json={"input": ""})
            assert response.status == 400

    @pytest.mark.asyncio
    async def test_start_requires_auth(self, auth_adapter):
        """With an API key configured, a request without credentials gets 401."""
        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as client:
            response = await client.post("/v1/runs", json={"input": "hello"})
            assert response.status == 401

    @pytest.mark.asyncio
    async def test_start_with_valid_auth(self, auth_adapter):
        """A request carrying the configured bearer token is accepted with 202."""
        agent = MagicMock(
            session_prompt_tokens=0,
            session_completion_tokens=0,
            session_total_tokens=0,
        )
        agent.run_conversation.return_value = {"final_response": "ok"}

        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as client:
            with patch.object(auth_adapter, "_create_agent", return_value=agent):
                response = await client.post(
                    "/v1/runs",
                    json={"input": "hello"},
                    headers={"Authorization": "Bearer sk-secret"},
                )
                assert response.status == 202
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# GET /v1/runs/{run_id}/events — SSE event stream
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestRunEvents:
    """GET /v1/runs/{run_id}/events: stream contents, unknown runs, and auth."""

    @pytest.mark.asyncio
    async def test_events_stream_returns_completed(self, adapter):
        """Events stream should receive run.completed when agent finishes."""
        agent = MagicMock(
            session_prompt_tokens=10,
            session_completion_tokens=5,
            session_total_tokens=15,
        )
        agent.run_conversation.return_value = {"final_response": "Hello!"}

        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            with patch.object(adapter, "_create_agent", return_value=agent):
                # Kick off the run and capture its id.
                start = await client.post("/v1/runs", json={"input": "hello"})
                assert start.status == 202
                run_id = (await start.json())["run_id"]

                # Read the event stream through to the end.
                stream = await client.get(f"/v1/runs/{run_id}/events")
                assert stream.status == 200
                body = await stream.text()

                # The completion event and the agent's reply must both appear.
                assert "run.completed" in body
                assert "Hello!" in body

    @pytest.mark.asyncio
    async def test_events_not_found_returns_404(self, adapter):
        """Subscribing to an unknown run id yields 404."""
        async with TestClient(TestServer(_create_runs_app(adapter))) as client:
            response = await client.get("/v1/runs/run_nonexistent/events")
            assert response.status == 404

    @pytest.mark.asyncio
    async def test_events_requires_auth(self, auth_adapter):
        """With an API key configured, an unauthenticated subscribe gets 401."""
        async with TestClient(TestServer(_create_runs_app(auth_adapter))) as client:
            response = await client.get("/v1/runs/run_any/events")
            assert response.status == 401
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# POST /v1/runs/{run_id}/stop — interrupt a running agent
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
|
||
|
|
class TestStopRun:
    """POST /v1/runs/{run_id}/stop: interrupting runs, error cases, cleanup."""

    @staticmethod
    async def _wait_for_agent_start(agent_ready, timeout=3.0):
        """Wait until the mock agent has entered ``run_conversation``.

        ``threading.Event.wait`` blocks the calling thread. Calling it
        directly from a coroutine would freeze the event loop, and with it
        the in-process test server, for up to *timeout* seconds. The wait is
        therefore pushed to an executor thread, and its result is asserted so
        that a run that never started fails here with a clear message.
        """
        loop = asyncio.get_running_loop()
        started = await loop.run_in_executor(None, agent_ready.wait, timeout)
        assert started, "agent did not enter run_conversation in time"
        # Short grace period kept from the original tests; presumably lets the
        # adapter finish registering the run before the test inspects it.
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_stop_running_agent(self, adapter):
        """Stop should interrupt the agent and cancel the task."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, _ = _make_slow_agent()
                mock_create.return_value = mock_agent

                # Start run
                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                data = await resp.json()
                run_id = data["run_id"]

                # Wait for agent to start running in the thread
                await self._wait_for_agent_start(agent_ready)

                # Verify agent ref is stored
                assert run_id in adapter._active_run_agents

                # Stop the run
                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200
                stop_data = await stop_resp.json()
                assert stop_data["run_id"] == run_id
                assert stop_data["status"] == "stopping"

                # Agent interrupt should have been called
                mock_agent.interrupt.assert_called_once_with("Stop requested via API")

                # Refs should be cleaned up
                await asyncio.sleep(0.5)
                assert run_id not in adapter._active_run_agents
                assert run_id not in adapter._active_run_tasks

    @pytest.mark.asyncio
    async def test_stop_nonexistent_run_returns_404(self, adapter):
        """Stopping an unknown run id yields 404."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/runs/run_nonexistent/stop")
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_stop_requires_auth(self, auth_adapter):
        """With an API key configured, an unauthenticated stop gets 401."""
        app = _create_runs_app(auth_adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/runs/run_any/stop")
            assert resp.status == 401

    @pytest.mark.asyncio
    async def test_stop_already_completed_run_returns_404(self, adapter):
        """Stopping a run that already finished should return 404 (refs cleaned up)."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent = MagicMock()
                mock_agent.run_conversation.return_value = {"final_response": "done"}
                mock_agent.session_prompt_tokens = 0
                mock_agent.session_completion_tokens = 0
                mock_agent.session_total_tokens = 0
                mock_create.return_value = mock_agent

                # Start and wait for completion
                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                data = await resp.json()
                run_id = data["run_id"]

                await asyncio.sleep(0.3)

                # Run should be done, refs cleaned up
                assert run_id not in adapter._active_run_agents

                # Stop should return 404
                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 404

    @pytest.mark.asyncio
    async def test_stop_interrupt_exception_does_not_crash(self, adapter):
        """If agent.interrupt() raises, stop should still succeed."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, _ = _make_slow_agent()
                # Override the interrupt side_effect to raise
                mock_agent.interrupt = MagicMock(side_effect=RuntimeError("interrupt failed"))
                mock_create.return_value = mock_agent

                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                data = await resp.json()
                run_id = data["run_id"]

                await self._wait_for_agent_start(agent_ready)

                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200
                stop_data = await stop_resp.json()
                assert stop_data["status"] == "stopping"

    @pytest.mark.asyncio
    async def test_stop_sends_sentinel_to_events_stream(self, adapter):
        """After stop, the events stream should close."""
        app = _create_runs_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_create_agent") as mock_create:
                mock_agent, agent_ready, _ = _make_slow_agent()
                mock_create.return_value = mock_agent

                # Start run
                resp = await cli.post("/v1/runs", json={"input": "hello"})
                assert resp.status == 202
                data = await resp.json()
                run_id = data["run_id"]

                await self._wait_for_agent_start(agent_ready)

                # Subscribe to events in background
                events_task = asyncio.ensure_future(
                    cli.get(f"/v1/runs/{run_id}/events")
                )

                # Let the subscription request get under way before stopping.
                await asyncio.sleep(0.1)

                # Stop the run
                stop_resp = await cli.post(f"/v1/runs/{run_id}/stop")
                assert stop_resp.status == 200

                # Events stream should close
                events_resp = await asyncio.wait_for(events_task, timeout=5.0)
                assert events_resp.status == 200
                body = await events_resp.text()
                # Stream should have received run.failed and closed
                assert "run.failed" in body or "stream closed" in body
|