mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 23:11:37 +08:00
Compare commits
2 Commits
fix/docker
...
feat/modal
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc5ca0fe42 | ||
|
|
f035796381 |
64
modal_app.py
Normal file
64
modal_app.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Modal deployment configuration for hermes-agent.
|
||||
|
||||
Deploys the FastAPI streaming wrapper as a serverless ASGI app on Modal.
|
||||
|
||||
Usage:
|
||||
modal deploy modal_app.py # Deploy to Modal
|
||||
modal serve modal_app.py # Local dev with hot-reload
|
||||
"""
|
||||
|
||||
import modal
|
||||
|
||||
# Container image for the deployment, assembled one layer at a time.
# Start from slim Debian with Python 3.11 and add git.
image = modal.Image.debian_slim(python_version="3.11")
image = image.apt_install("git")

# Python dependencies installed into the image.
image = image.pip_install(
    "fastapi[standard]",
    "uvicorn",
    "openai",
    "python-dotenv",
    "fire",
    "httpx",
    "rich",
    "tenacity",
    "pyyaml",
    "requests",
    "jinja2",
    "pydantic>=2.0",
    "prompt_toolkit",
    "firecrawl-py",
    "fal-client",
    "edge-tts",
    "litellm>=1.75.5",
    "typer",
    "platformdirs",
    "PyJWT[crypto]",
)

# Add the current directory to the image under /app, leaving out VCS
# metadata, bytecode caches and virtualenvs.
# NOTE(review): local ".env" files are not in the ignore list — confirm no
# secrets end up inside the image.
image = image.add_local_dir(
    ".",
    "/app",
    copy=True,
    ignore=[".git", "__pycache__", "venv", ".venv", "*.pyc"],
)

# The Modal application; functions registered on it use the image above.
app = modal.App("hermes-agent", image=image)
|
||||
|
||||
|
||||
@app.function(
    min_containers=0,
    scaledown_window=300,
    timeout=600,
    secrets=[modal.Secret.from_name("hermes-secrets")],
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
    """Return the FastAPI application from ``serve.py`` for Modal to host.

    Before importing the app this points ``HERMES_HOME`` at ``/tmp/hermes``
    (creating it and its ``logs`` subdirectory) and puts ``/app`` at the
    front of ``sys.path`` so that ``serve`` can be imported.
    """
    import os
    import sys
    from pathlib import Path

    # Force HERMES_HOME to a known writable path inside the container.
    hermes_home = "/tmp/hermes"
    os.environ["HERMES_HOME"] = hermes_home

    home_dir = Path(hermes_home)
    for directory in (home_dir, home_dir / "logs"):
        directory.mkdir(parents=True, exist_ok=True)

    # Make the code copied to /app importable, then hand its ASGI app back.
    sys.path.insert(0, "/app")
    import serve

    return serve.app
|
||||
@@ -39,6 +39,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
serve = ["fastapi[standard]", "uvicorn"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cron = ["croniter"]
|
||||
@@ -51,6 +52,7 @@ mcp = ["mcp>=1.2.0"]
|
||||
homeassistant = ["aiohttp>=3.9.0"]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[serve]",
|
||||
"hermes-agent[messaging]",
|
||||
"hermes-agent[cron]",
|
||||
"hermes-agent[cli]",
|
||||
|
||||
46
run_agent.py
46
run_agent.py
@@ -26,6 +26,7 @@ import json
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
@@ -140,6 +141,8 @@ class AIAgent:
|
||||
skip_memory: bool = False,
|
||||
session_db=None,
|
||||
honcho_session_key: str = None,
|
||||
event_queue: "queue.Queue | None" = None,
|
||||
extra_tags: List[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -217,6 +220,8 @@ class AIAgent:
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.event_queue: queue.Queue | None = event_queue
|
||||
self._extra_tags: List[str] = extra_tags or []
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
@@ -255,7 +260,7 @@ class AIAgent:
|
||||
# Persistent error log -- always writes WARNING+ to $HERMES_HOME/logs/errors.log (default: ~/.hermes/logs/errors.log)
|
||||
# so tool failures, API errors, etc. are inspectable after the fact.
|
||||
from agent.redact import RedactingFormatter
|
||||
_error_log_dir = Path.home() / ".hermes" / "logs"
|
||||
_error_log_dir = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "logs"
|
||||
_error_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
_error_log_path = _error_log_dir / "errors.log"
|
||||
from logging.handlers import RotatingFileHandler
|
||||
@@ -1305,6 +1310,19 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
logger.debug("Honcho sync failed (non-fatal): %s", e)
|
||||
|
||||
def _emit_event(self, event: Dict[str, Any]) -> None:
    """Push a structured event onto the event queue (if one is attached).

    Used by the serve layer to stream intermediate agent progress
    (text tokens, tool calls, tool results) back to callers over SSE.
    No-op when ``event_queue`` is ``None`` (CLI / gateway usage).

    Delivery is best-effort: the put never blocks, and an event that
    cannot be enqueued is dropped rather than raised, so emitting an
    event can never interrupt the agent loop. Drops are logged at debug
    level so they remain diagnosable.

    Args:
        event: The event payload to enqueue.
    """
    if self.event_queue is None:
        return

    try:
        self.event_queue.put_nowait(event)
    except queue.Full:
        # The consumer is not draining fast enough; drop instead of blocking.
        logging.getLogger(__name__).debug("Event queue full; dropping event")
    except Exception as exc:
        # Keep the best-effort contract for any other queue failure.
        logging.getLogger(__name__).debug("Failed to emit event: %s", exc)
|
||||
|
||||
def _build_system_prompt(self, system_message: str = None) -> str:
|
||||
"""
|
||||
Assemble the full system prompt from all layers.
|
||||
@@ -2136,9 +2154,11 @@ class AIAgent:
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
# Nous Portal product attribution
|
||||
# Nous Portal product attribution + caller-supplied tags
|
||||
if _is_nous:
|
||||
extra_body["tags"] = ["product=hermes-agent"]
|
||||
tags = list(self._extra_tags)
|
||||
tags.append("product=hermes-agent")
|
||||
extra_body["tags"] = tags
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
@@ -2454,6 +2474,13 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
self._emit_event({
|
||||
"type": "tool-call",
|
||||
"name": function_name,
|
||||
"args": function_args,
|
||||
"status": "calling",
|
||||
})
|
||||
|
||||
tool_start_time = time.time()
|
||||
|
||||
if function_name == "todo":
|
||||
@@ -2617,6 +2644,14 @@ class AIAgent:
|
||||
messages.append(tool_msg)
|
||||
self._log_msg_to_db(tool_msg)
|
||||
|
||||
self._emit_event({
|
||||
"type": "tool-result",
|
||||
"name": function_name,
|
||||
"output": function_result[:4000],
|
||||
"status": "complete",
|
||||
"duration": round(tool_duration, 2),
|
||||
})
|
||||
|
||||
if not self.quiet_mode:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
@@ -3779,6 +3814,9 @@ class AIAgent:
|
||||
|
||||
# Strip <think> blocks from user-facing response (keep raw in messages for trajectory)
|
||||
final_response = self._strip_think_blocks(final_response).strip()
|
||||
|
||||
if final_response:
|
||||
self._emit_event({"type": "text", "text": final_response})
|
||||
|
||||
final_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
|
||||
@@ -3875,6 +3913,8 @@ class AIAgent:
|
||||
|
||||
# Clear interrupt state after handling
|
||||
self.clear_interrupt()
|
||||
|
||||
self._emit_event({"type": "done"})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
124
serve.py
Normal file
124
serve.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""FastAPI streaming wrapper for AIAgent.
|
||||
|
||||
Exposes hermes-agent as an HTTP service with SSE streaming.
|
||||
Run locally with: uvicorn serve:app --host 0.0.0.0 --port 8000
|
||||
Deploy on Modal via modal_app.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Force HERMES_HOME to a writable path. Modal secrets may set HERMES_HOME to
# a non-existent path (e.g. /app/tinker-atropos) — override unconditionally.
_hermes_home = Path("/tmp/hermes")
_hermes_home.mkdir(parents=True, exist_ok=True)
(_hermes_home / "logs").mkdir(parents=True, exist_ok=True)
os.environ["HERMES_HOME"] = str(_hermes_home)

# Pre-import modules that register signal handlers so they run in the
# main thread (signal.signal() fails if called from a worker thread).
try:
    import tools.browser_tool  # noqa: F401
except Exception as e:
    # Best-effort import: a failure here must not prevent the service from
    # starting, but record the reason instead of discarding it.
    logger.debug("Failed to pre-import tools.browser_tool: %s", e)

try:
    from run_agent import AIAgent  # noqa: F401
except Exception as e:
    logger.warning("Failed to pre-import AIAgent: %s", e)

# The ASGI application served by uvicorn / Modal.
app = FastAPI(title="hermes-agent", version="0.1.0")
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe: always reports an ``ok`` status."""
    payload = {"status": "ok"}
    return payload
|
||||
|
||||
|
||||
@app.post("/v1/agent/stream")
async def agent_stream(request: Request):
    """Run one agent conversation and stream its progress as Server-Sent Events.

    The JSON request body may contain:

    - ``messages``: chat messages; a trailing ``user`` message is used as the
      prompt and everything before it as conversation history.
    - ``model``: model identifier (defaults to ``anthropic/claude-opus-4.6``).
    - ``system_prompt``: passed to the agent as ``ephemeral_system_prompt``.
    - ``toolsets``: passed to the agent as ``enabled_toolsets``.
    - ``max_iterations``: iteration cap for the agent (default 30).
    - ``base_url`` / ``api_key``: LLM endpoint overrides; fall back to the
      ``AGENT_LLM_BASE_URL`` / ``AGENT_LLM_API_KEY`` environment variables.
    - ``tags``: passed to the agent as ``extra_tags``.

    The agent runs in a worker thread and reports through a bounded queue.
    Every queued event is sent as one ``data: <json>`` SSE message, and the
    stream always finishes with exactly one ``{"type": "done"}`` message.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    body = await request.json()

    # Tolerate an explicit ``"messages": null`` as well as a missing key.
    messages = body.get("messages") or []
    model = body.get("model", "anthropic/claude-opus-4.6")
    system_prompt = body.get("system_prompt")
    toolsets = body.get("toolsets")
    max_iterations = body.get("max_iterations", 30)
    base_url = body.get("base_url") or os.getenv("AGENT_LLM_BASE_URL")
    api_key = body.get("api_key") or os.getenv("AGENT_LLM_API_KEY")
    tags = body.get("tags")

    # The prompt is the content of the most recent user message. When that
    # message is the final one it is removed from the history, so it is not
    # passed to the agent twice.
    conversation_history = list(messages)
    user_message = next(
        (
            msg.get("content", "")
            for msg in reversed(conversation_history)
            if msg.get("role") == "user"
        ),
        "",
    )
    if conversation_history and conversation_history[-1].get("role") == "user":
        conversation_history.pop()

    eq: queue.Queue = queue.Queue(maxsize=512)

    # Private end-of-stream marker. The worker enqueues it last, after any
    # error event, so the stream cannot be closed before an error is sent.
    stream_end = object()

    def run_agent():
        try:
            agent = AIAgent(
                model=model,
                base_url=base_url,
                api_key=api_key,
                max_iterations=max_iterations,
                quiet_mode=True,
                enabled_toolsets=toolsets,
                event_queue=eq,
                ephemeral_system_prompt=system_prompt,
                extra_tags=tags,
            )
            result = agent.run_conversation(
                user_message=user_message,
                conversation_history=conversation_history or None,
            )
            if result and result.get("failed"):
                eq.put({"type": "error", "error": result.get("error", "Agent failed")})
        except Exception as e:
            logger.exception("Agent error")
            eq.put({"type": "error", "error": str(e)})
        finally:
            eq.put(stream_end)

    thread = threading.Thread(target=run_agent, daemon=True)
    thread.start()

    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()

    async def event_generator():
        while True:
            try:
                # The blocking queue read runs in the default executor so the
                # event loop stays free while waiting for the agent.
                event = await loop.run_in_executor(None, lambda: eq.get(timeout=120))
            except queue.Empty:
                # No event for 120 seconds: give up and close the stream.
                break

            if event is stream_end:
                break

            # The agent may queue its own "done" before the worker has
            # reported a failure; hold it back and send the single closing
            # "done" below, once the worker has finished.
            if event.get("type") == "done":
                continue

            yield f"data: {json.dumps(event)}\n\n"

        yield "data: {\"type\": \"done\"}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
||||
365
tests/test_serve.py
Normal file
365
tests/test_serve.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Tests for the serve layer (serve.py) and event_queue integration.
|
||||
|
||||
Covers:
|
||||
- _emit_event: queue attached, no queue, queue full
|
||||
- extra_tags merging in _build_api_kwargs for Nous API
|
||||
- FastAPI /health endpoint
|
||||
- FastAPI /v1/agent/stream SSE endpoint (mocked AIAgent)
|
||||
|
||||
Run with: python -m pytest tests/test_serve.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import queue
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
    """Build a minimal function-tool definition for each given tool name."""
    tool_defs = []
    for tool_name in names:
        function_spec = {
            "name": tool_name,
            "description": f"{tool_name} tool",
            "parameters": {"type": "object", "properties": {}},
        }
        tool_defs.append({"type": "function", "function": function_spec})
    return tool_defs
|
||||
|
||||
|
||||
@pytest.fixture()
def agent_no_queue():
    """AIAgent without an event_queue (CLI/gateway mode)."""
    init_kwargs = {
        "api_key": "test-key-1234567890",
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
    }
    tool_defs = _make_tool_defs("web_search")
    with (
        patch("run_agent.get_tool_definitions", return_value=tool_defs),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(**init_kwargs)
        # Replace the client with a mock so no real client is ever used.
        agent.client = MagicMock()
        return agent
|
||||
|
||||
|
||||
@pytest.fixture()
def agent_with_queue():
    """AIAgent with an event_queue attached (serve mode).

    Yields the pair ``(agent, queue)`` so tests can inspect what was emitted.
    """
    event_q = queue.Queue(maxsize=128)
    init_kwargs = {
        "api_key": "test-key-1234567890",
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
        "event_queue": event_q,
    }
    tool_defs = _make_tool_defs("web_search")
    with (
        patch("run_agent.get_tool_definitions", return_value=tool_defs),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(**init_kwargs)
        # Replace the client with a mock so no real client is ever used.
        agent.client = MagicMock()
        return agent, event_q
|
||||
|
||||
|
||||
@pytest.fixture()
def nous_agent():
    """AIAgent pointing at a Nous inference URL with extra_tags."""
    init_kwargs = {
        "base_url": "https://stg-inference-api.nousresearch.com/v1",
        "api_key": "test-key-1234567890",
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
        "extra_tags": ["user=test-user", "tier=paid"],
    }
    tool_defs = _make_tool_defs("web_search")
    with (
        patch("run_agent.get_tool_definitions", return_value=tool_defs),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(**init_kwargs)
        # Replace the client with a mock so no real client is ever used.
        agent.client = MagicMock()
        return agent
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 1: _emit_event
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestEmitEvent:
    """AIAgent._emit_event with no queue, a roomy queue and a full queue."""

    def test_no_queue_is_noop(self, agent_no_queue):
        """_emit_event should silently do nothing when no queue is attached."""
        payload = {"type": "text", "text": "hello"}
        agent_no_queue._emit_event(payload)

    def test_event_pushed_to_queue(self, agent_with_queue):
        """A single emitted event is retrievable from the queue unchanged."""
        agent, eq = agent_with_queue
        payload = {"type": "text", "text": "hello"}

        agent._emit_event(payload)

        assert not eq.empty()
        assert eq.get_nowait() == payload

    def test_multiple_events_ordered(self, agent_with_queue):
        """Events come out of the queue in the order they were emitted."""
        agent, eq = agent_with_queue
        expected = [
            {"type": "tool-call", "name": "terminal", "status": "calling"},
            {"type": "tool-result", "name": "terminal", "status": "complete"},
            {"type": "text", "text": "done"},
            {"type": "done"},
        ]
        for payload in expected:
            agent._emit_event(payload)

        # Nothing else touches the queue here, so its size is exact.
        drained = [eq.get_nowait() for _ in range(eq.qsize())]

        assert eq.empty()
        assert drained == expected

    def test_full_queue_does_not_raise(self):
        """When the queue is full, _emit_event should silently drop the event."""
        tiny_q = queue.Queue(maxsize=1)
        tool_defs = _make_tool_defs("web_search")
        with (
            patch("run_agent.get_tool_definitions", return_value=tool_defs),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
        ):
            agent = AIAgent(
                api_key="test-key-1234567890",
                quiet_mode=True,
                skip_context_files=True,
                skip_memory=True,
                event_queue=tiny_q,
            )

        # Fill the only slot, then emit one event too many.
        tiny_q.put({"type": "filler"})
        assert tiny_q.full()

        agent._emit_event({"type": "text", "text": "overflow"})

        # The filler is still the only item; the overflow event was dropped.
        assert tiny_q.qsize() == 1
        assert tiny_q.get_nowait()["type"] == "filler"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 2: extra_tags in _build_api_kwargs
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestExtraTags:
    """How caller-supplied tags end up in the request built by _build_api_kwargs."""

    def test_no_tags_on_openrouter(self, agent_no_queue):
        """OpenRouter requests should NOT include Nous product tags."""
        api_kwargs = agent_no_queue._build_api_kwargs(
            [{"role": "user", "content": "hi"}]
        )
        extra_body = api_kwargs.get("extra_body", {})
        assert "tags" not in extra_body

    def test_default_product_tag_on_nous(self, nous_agent):
        """Nous API requests should always include product=hermes-agent."""
        api_kwargs = nous_agent._build_api_kwargs(
            [{"role": "user", "content": "hi"}]
        )
        assert "product=hermes-agent" in api_kwargs["extra_body"]["tags"]

    def test_extra_tags_merged(self, nous_agent):
        """Caller-supplied tags should appear alongside the product tag."""
        api_kwargs = nous_agent._build_api_kwargs(
            [{"role": "user", "content": "hi"}]
        )
        sent_tags = api_kwargs["extra_body"]["tags"]
        for expected_tag in ("user=test-user", "tier=paid", "product=hermes-agent"):
            assert expected_tag in sent_tags

    def test_extra_tags_empty_by_default(self, agent_no_queue):
        """Agent without extra_tags should have an empty list."""
        assert agent_no_queue._extra_tags == []

    def test_extra_tags_does_not_mutate_original(self, nous_agent):
        """Calling _build_api_kwargs should not grow _extra_tags each time."""
        conversation = [{"role": "user", "content": "hi"}]
        for _ in range(2):
            nous_agent._build_api_kwargs(conversation)

        stored_tags = nous_agent._extra_tags
        assert stored_tags.count("product=hermes-agent") == 0
        assert len(stored_tags) == 2
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Group 3: FastAPI endpoints (serve.py)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@pytest.fixture()
def fastapi_app():
    """Import the FastAPI app from serve.py."""
    # Imported lazily so that serve.py is only loaded by tests that need it.
    import serve

    return serve.app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestHealthEndpoint:
    """GET /health served in-process through the ASGI transport."""

    async def test_health_returns_ok(self, fastapi_app):
        """The probe answers 200 with a status of ``ok``."""
        async with AsyncClient(
            transport=ASGITransport(app=fastapi_app),
            base_url="http://test",
        ) as client:
            response = await client.get("/health")

        assert response.status_code == 200
        assert response.json()["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestAgentStreamEndpoint:
    """POST /v1/agent/stream with ``AIAgent.run_conversation`` stubbed out.

    Each test wraps ``AIAgent.__init__`` so the real constructor still runs,
    then swaps ``run_conversation`` on the new instance for a stub.
    """

    @staticmethod
    def _sse_events(body_text):
        """Decode the JSON payload of every ``data:`` line in an SSE body."""
        prefix = "data: "
        return [
            json.loads(line[len(prefix):])
            for line in body_text.strip().split("\n")
            if line.startswith(prefix)
        ]

    async def test_stream_returns_sse_events(self, fastapi_app):
        """Mock AIAgent to emit known events and verify SSE output."""
        mock_result = {
            "final_response": "Hello!",
            "messages": [],
            "api_calls": 1,
            "completed": True,
        }
        # Filled in by patched_init so the stub can reach the agent's queue.
        created = {}
        original_init = AIAgent.__init__

        def fake_run_conversation(user_message, conversation_history=None):
            agent_instance = created.get("agent")
            if agent_instance is not None and agent_instance.event_queue is not None:
                eq = agent_instance.event_queue
                eq.put({"type": "tool-call", "name": "terminal", "args": "echo hi", "status": "calling"})
                eq.put({"type": "tool-result", "name": "terminal", "output": "hi", "status": "complete", "duration": 0.1})
                eq.put({"type": "text", "text": "Hello!"})
                eq.put({"type": "done"})
            return mock_result

        def patched_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            self.run_conversation = fake_run_conversation
            created["agent"] = self

        with (
            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
            patch.object(AIAgent, "__init__", patched_init),
        ):
            transport = ASGITransport(app=fastapi_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.post(
                    "/v1/agent/stream",
                    json={
                        "messages": [{"role": "user", "content": "Say hello"}],
                        "model": "test/model",
                    },
                    timeout=30,
                )

        assert resp.status_code == 200
        assert "text/event-stream" in resp.headers["content-type"]

        events = self._sse_events(resp.text)
        types = [e["type"] for e in events]
        assert "tool-call" in types
        assert "tool-result" in types
        assert "text" in types
        assert types[-1] == "done"

        text_event = next(e for e in events if e["type"] == "text")
        assert text_event["text"] == "Hello!"

        tool_call = next(e for e in events if e["type"] == "tool-call")
        assert tool_call["name"] == "terminal"

    async def test_stream_error_propagated(self, fastapi_app):
        """When AIAgent raises, an error event should be streamed."""
        original_init = AIAgent.__init__

        def exploding_run(user_message, conversation_history=None):
            raise RuntimeError("kaboom")

        def patched_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            self.run_conversation = exploding_run

        with (
            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
            patch.object(AIAgent, "__init__", patched_init),
        ):
            transport = ASGITransport(app=fastapi_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.post(
                    "/v1/agent/stream",
                    json={
                        "messages": [{"role": "user", "content": "fail"}],
                        "model": "test/model",
                    },
                    timeout=30,
                )

        # The failure is reported in-band; the HTTP exchange itself succeeds.
        assert resp.status_code == 200
        events = self._sse_events(resp.text)

        error_events = [e for e in events if e["type"] == "error"]
        assert len(error_events) >= 1
        assert "kaboom" in error_events[0]["error"]
        assert events[-1]["type"] == "done"

    async def test_stream_passes_base_url_and_tags(self, fastapi_app):
        """Verify base_url, api_key, and tags from the request body reach AIAgent."""
        captured = {}
        original_init = AIAgent.__init__

        def patched_init(self, *args, **kwargs):
            captured["base_url"] = kwargs.get("base_url")
            captured["api_key"] = kwargs.get("api_key")
            captured["extra_tags"] = kwargs.get("extra_tags")
            original_init(self, *args, **kwargs)

            def fake_run(**kw):
                # Emit a minimal, well-formed stream so the request finishes.
                if self.event_queue is not None:
                    self.event_queue.put({"type": "text", "text": "ok"})
                    self.event_queue.put({"type": "done"})
                return {"final_response": "ok", "messages": [], "api_calls": 1, "completed": True}

            self.run_conversation = fake_run

        with (
            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
            patch.object(AIAgent, "__init__", patched_init),
        ):
            transport = ASGITransport(app=fastapi_app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                resp = await client.post(
                    "/v1/agent/stream",
                    json={
                        "messages": [{"role": "user", "content": "hi"}],
                        "model": "test/model",
                        "base_url": "https://my-api.example.com/v1",
                        "api_key": "sk-test-key",
                        "tags": ["user=alice", "tier=free"],
                    },
                    timeout=30,
                )

        # Check the request itself succeeded before inspecting what it forwarded.
        assert resp.status_code == 200
        assert captured["base_url"] == "https://my-api.example.com/v1"
        assert captured["api_key"] == "sk-test-key"
        assert captured["extra_tags"] == ["user=alice", "tier=free"]
|
||||
@@ -54,8 +54,9 @@ ENVIRONMENTS_DIR = TINKER_ATROPOS_ROOT / "tinker_atropos" / "environments"
|
||||
CONFIGS_DIR = TINKER_ATROPOS_ROOT / "configs"
|
||||
LOGS_DIR = TINKER_ATROPOS_ROOT / "logs"
|
||||
|
||||
# Ensure logs directory exists
|
||||
LOGS_DIR.mkdir(exist_ok=True)
|
||||
# Ensure logs directory exists (parents=True for container environments
|
||||
# where tinker-atropos may not be checked out)
|
||||
LOGS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user