mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
|
|
"""WebSocket transport for the tui_gateway JSON-RPC server.
|
||
|
|
|
||
|
|
Reuses :func:`tui_gateway.server.dispatch` verbatim so every RPC method, every
|
||
|
|
slash command, every approval/clarify/sudo flow, and every agent event flows
|
||
|
|
through the same handlers whether the client is Ink over stdio or an iOS /
|
||
|
|
web client over WebSocket.
|
||
|
|
|
||
|
|
Wire protocol
|
||
|
|
-------------
|
||
|
|
Same JSON-RPC messages as stdio, but each WebSocket text frame carries exactly
one JSON message: the frame boundary takes the place of stdio's newline
delimiter, and outbound frames have no trailing newline. The server emits a
``gateway.ready`` event immediately after accepting the connection, then sends
responses/events for inbound requests.
|
||
|
|
|
||
|
|
Mounting
|
||
|
|
--------
|
||
|
|
from fastapi import WebSocket
|
||
|
|
from tui_gateway.ws import handle_ws
|
||
|
|
|
||
|
|
@app.websocket("/api/ws")
|
||
|
|
async def ws(ws: WebSocket):
|
||
|
|
await handle_ws(ws)
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from tui_gateway import server
|
||
|
|
|
||
|
|
# Upper bound, in seconds, on how long WSTransport.write (called from pool
# worker threads) blocks waiting for the event loop to flush a frame. Past it
# the transport is treated as dead, so a wedged socket cannot pin a handler
# thread indefinitely.
_WS_WRITE_TIMEOUT_S: float = 10.0

_log = logging.getLogger(name=__name__)

# starlette stays an optional import. When present, handle_ws catches its real
# WebSocketDisconnect. When absent, the alias degrades to plain Exception, so
# any receive_text() error is treated as a disconnect.
try:
    from starlette.websockets import WebSocketDisconnect as _WebSocketDisconnect
except ImportError:  # pragma: no cover - starlette is a required install path
    _WebSocketDisconnect = Exception  # type: ignore[assignment]
|
||
|
|
|
||
|
|
|
||
|
|
class WSTransport:
    """Per-connection WebSocket transport.

    ``write`` may be called from any thread.

    Off-loop callers (pool workers, or anything else not running on the owning
    event loop) marshal the send onto the loop with
    :func:`asyncio.run_coroutine_threadsafe`. They then block on
    ``future.result()`` for at most ``_WS_WRITE_TIMEOUT_S`` seconds. On timeout
    the pending send is cancelled and the transport is marked closed.

    Blocking like that from the loop thread itself would deadlock: we would wait
    on work scheduled onto the loop we are blocking. On-loop calls are therefore
    fire-and-forget: a task is scheduled and ``True`` is returned immediately.
    A strong reference to each such task is held until it finishes, because the
    event loop keeps only weak references to tasks.

    Loop-thread callers that need to know when the bytes are on the wire should
    use :meth:`write_async`. Any send failure marks the transport closed, and
    every later write returns ``False``.
    """

    def __init__(self, ws: Any, loop: asyncio.AbstractEventLoop) -> None:
        # Object exposing ``async send_text(str)``; every send goes through it.
        self._ws = ws
        # Event loop that owns ``ws``; all sends are executed on it.
        self._loop = loop
        # Set once a send fails or times out, or close() is called.
        self._closed = False
        # Strong references to fire-and-forget send tasks, so they are not
        # garbage-collected mid-flight. Each task removes itself when done.
        self._pending: set[asyncio.Task] = set()

    def write(self, obj: dict) -> bool:
        """Serialize *obj* as JSON and send it as a single text frame.

        Returns ``False`` if the transport is already closed, or if the send
        failed or timed out. Otherwise returns ``True``. For callers running on
        the owning loop, ``True`` only means the send was scheduled.
        """
        if self._closed:
            return False

        line = json.dumps(obj, ensure_ascii=False)

        try:
            on_loop = asyncio.get_running_loop() is self._loop
        except RuntimeError:  # no running loop in this thread
            on_loop = False

        if on_loop:
            # Fire-and-forget — don't block the loop waiting on itself.
            task = self._loop.create_task(self._safe_send(line))
            self._pending.add(task)
            task.add_done_callback(self._pending.discard)
            return True

        coro = self._safe_send(line)
        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except Exception as exc:
            # The loop is closed/unusable, so the coroutine will never run.
            # Close it to avoid a "coroutine was never awaited" warning.
            coro.close()
            self._closed = True
            _log.debug("ws write failed: %s", exc)
            return False

        try:
            fut.result(timeout=_WS_WRITE_TIMEOUT_S)
        except Exception as exc:
            # Timeout (wedged socket), or the send was cancelled on the loop.
            # Cancel the abandoned send rather than leave it queued on the loop.
            fut.cancel()
            self._closed = True
            _log.debug("ws write failed: %s", exc)
            return False

        return not self._closed

    async def write_async(self, obj: dict) -> bool:
        """Send from the owning event loop. Awaits until the frame is on the wire."""
        if self._closed:
            return False
        await self._safe_send(json.dumps(obj, ensure_ascii=False))
        return not self._closed

    async def _safe_send(self, line: str) -> None:
        """Send *line*; on any error, mark the transport closed instead of raising."""
        try:
            await self._ws.send_text(line)
        except Exception as exc:
            self._closed = True
            _log.debug("ws send failed: %s", exc)

    def close(self) -> None:
        """Mark the transport closed; subsequent writes return ``False``."""
        self._closed = True
|
||
|
|
|
||
|
|
|
||
|
|
async def handle_ws(ws: Any) -> None:
    """Run one WebSocket session. Wire-compatible with ``tui_gateway.entry``.

    Accepts the connection and sends a ``gateway.ready`` event. It then reads
    one JSON-RPC message per text frame and passes it to ``server.dispatch``,
    until the client disconnects or a write fails.

    On exit, the transport is closed and any session still bound to it is
    re-pointed at ``server._stdio_transport``. The socket is then closed
    (best-effort).

    Args:
        ws: WebSocket object providing ``accept``, ``receive_text``,
            ``send_text`` and ``close`` coroutines (e.g. a Starlette/FastAPI
            ``WebSocket``).
    """
    await ws.accept()

    transport = WSTransport(ws, asyncio.get_running_loop())

    try:
        # Built and sent inside the try, so a failure in resolve_skin() or in
        # the send still runs the cleanup in ``finally`` instead of leaving the
        # socket open.
        ready = {
            "jsonrpc": "2.0",
            "method": "event",
            "params": {
                "type": "gateway.ready",
                "payload": {"skin": server.resolve_skin()},
            },
        }
        if not await transport.write_async(ready):
            # The transport is now closed, so every later write would fail too.
            return

        while True:
            try:
                raw = await ws.receive_text()
            except _WebSocketDisconnect:
                break

            line = raw.strip()
            if not line:
                continue

            try:
                req = json.loads(line)
            except json.JSONDecodeError:
                ok = await transport.write_async(
                    {
                        "jsonrpc": "2.0",
                        "error": {"code": -32700, "message": "parse error"},
                        "id": None,
                    }
                )
                if not ok:
                    break
                continue

            # dispatch() may schedule long handlers on the pool. In that case
            # it returns None, and the worker writes the response itself via
            # the transport we pass in (from a separate thread, so
            # transport.write is the safe path there). For inline handlers it
            # returns the response dict, which we write here from the loop.
            try:
                resp = await asyncio.to_thread(server.dispatch, req, transport)
            except Exception:
                # A failing handler should cost one request, not the whole
                # connection: report a JSON-RPC internal error and keep reading.
                _log.exception("ws dispatch failed")
                resp = {
                    "jsonrpc": "2.0",
                    "error": {"code": -32603, "message": "internal error"},
                    "id": req.get("id") if isinstance(req, dict) else None,
                }
            if resp is not None and not await transport.write_async(resp):
                break
    finally:
        transport.close()

        # Detach the transport from any sessions it owned, so later emits
        # fall back to stdio instead of crashing into a closed socket.
        # list() iterates over a snapshot of the session dicts.
        for sess in list(server._sessions.values()):
            if sess.get("transport") is transport:
                sess["transport"] = server._stdio_transport

        try:
            await ws.close()
        except Exception as exc:
            # Best-effort: the socket may already be closed by the peer.
            _log.debug("ws close failed: %s", exc)
|