mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 07:21:37 +08:00
Compare commits
5 Commits
opencode-p
...
feat/strea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d41115aa31 | ||
|
|
9f086173be | ||
|
|
57faddd808 | ||
|
|
af6a92a4c2 | ||
|
|
4d6c90c6d0 |
63
cli.py
63
cli.py
@@ -1187,6 +1187,7 @@ class HermesCLI:
|
|||||||
# History file for persistent input recall across sessions
|
# History file for persistent input recall across sessions
|
||||||
self._history_file = Path.home() / ".hermes_history"
|
self._history_file = Path.home() / ".hermes_history"
|
||||||
self._last_invalidate: float = 0.0 # throttle UI repaints
|
self._last_invalidate: float = 0.0 # throttle UI repaints
|
||||||
|
self._stream_buf = ""
|
||||||
|
|
||||||
def _invalidate(self, min_interval: float = 0.25) -> None:
|
def _invalidate(self, min_interval: float = 0.25) -> None:
|
||||||
"""Throttled UI repaint — prevents terminal blinking on slow/SSH connections."""
|
"""Throttled UI repaint — prevents terminal blinking on slow/SSH connections."""
|
||||||
@@ -1386,6 +1387,7 @@ class HermesCLI:
|
|||||||
platform="cli",
|
platform="cli",
|
||||||
session_db=self._session_db,
|
session_db=self._session_db,
|
||||||
clarify_callback=self._clarify_callback,
|
clarify_callback=self._clarify_callback,
|
||||||
|
stream_delta_callback=self._stream_delta,
|
||||||
honcho_session_key=self.session_id,
|
honcho_session_key=self.session_id,
|
||||||
fallback_model=self._fallback_model,
|
fallback_model=self._fallback_model,
|
||||||
)
|
)
|
||||||
@@ -2905,6 +2907,28 @@ class HermesCLI:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ❌ MCP reload failed: {e}")
|
print(f" ❌ MCP reload failed: {e}")
|
||||||
|
|
||||||
|
_stream_started = False
|
||||||
|
|
||||||
|
def _stream_delta(self, text: str):
|
||||||
|
"""Buffer streaming tokens; emit complete lines via _cprint."""
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
if not self._stream_started:
|
||||||
|
text = text.lstrip("\n")
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
self._stream_started = True
|
||||||
|
self._stream_buf += text
|
||||||
|
while "\n" in self._stream_buf:
|
||||||
|
line, self._stream_buf = self._stream_buf.split("\n", 1)
|
||||||
|
_cprint(line)
|
||||||
|
|
||||||
|
def _flush_stream(self):
|
||||||
|
"""Emit any remaining partial line from the stream buffer."""
|
||||||
|
if self._stream_buf:
|
||||||
|
_cprint(self._stream_buf)
|
||||||
|
self._stream_buf = ""
|
||||||
|
|
||||||
def _clarify_callback(self, question, choices):
|
def _clarify_callback(self, question, choices):
|
||||||
"""
|
"""
|
||||||
Platform callback for the clarify tool. Called from the agent thread.
|
Platform callback for the clarify tool. Called from the agent thread.
|
||||||
@@ -3076,12 +3100,12 @@ class HermesCLI:
|
|||||||
message if isinstance(message, str) else "", images
|
message if isinstance(message, str) else "", images
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add user message to history
|
|
||||||
self.conversation_history.append({"role": "user", "content": message})
|
self.conversation_history.append({"role": "user", "content": message})
|
||||||
|
self._stream_buf = ""
|
||||||
|
self._stream_started = False
|
||||||
|
|
||||||
w = shutil.get_terminal_size().columns
|
w = shutil.get_terminal_size().columns
|
||||||
_cprint(f"{_GOLD}{'─' * w}{_RST}")
|
_cprint(f"\n{_GOLD}╭─ ⚕ Hermes {'─' * max(w - 15, 0)}╮{_RST}")
|
||||||
print(flush=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the conversation with interrupt monitoring
|
# Run the conversation with interrupt monitoring
|
||||||
@@ -3127,43 +3151,28 @@ class HermesCLI:
|
|||||||
|
|
||||||
agent_thread.join() # Ensure agent thread completes
|
agent_thread.join() # Ensure agent thread completes
|
||||||
|
|
||||||
# Drain any remaining agent output still in the StdoutProxy
|
self._flush_stream()
|
||||||
# buffer so tool/status lines render ABOVE our response box.
|
|
||||||
# The flush pushes data into the renderer queue; the short
|
|
||||||
# sleep lets the renderer actually paint it before we draw.
|
|
||||||
import time as _time
|
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
import time as _time
|
||||||
_time.sleep(0.15)
|
_time.sleep(0.15)
|
||||||
|
|
||||||
# Update history with full conversation
|
|
||||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||||
|
|
||||||
# Get the final response
|
|
||||||
response = result.get("final_response", "") if result else ""
|
response = result.get("final_response", "") if result else ""
|
||||||
|
|
||||||
# Handle failed results (e.g., non-retryable errors like invalid model)
|
|
||||||
if result and result.get("failed") and not response:
|
if result and result.get("failed") and not response:
|
||||||
error_detail = result.get("error", "Unknown error")
|
response = f"Error: {result.get('error', 'Unknown error')}"
|
||||||
response = f"Error: {error_detail}"
|
|
||||||
|
|
||||||
# Handle interrupt - check if we were interrupted
|
|
||||||
pending_message = None
|
pending_message = None
|
||||||
if result and result.get("interrupted"):
|
if result and result.get("interrupted"):
|
||||||
pending_message = result.get("interrupt_message") or interrupt_msg
|
pending_message = result.get("interrupt_message") or interrupt_msg
|
||||||
# Add indicator that we were interrupted
|
|
||||||
if response and pending_message:
|
if response and pending_message:
|
||||||
response = response + "\n\n---\n_[Interrupted - processing new message]_"
|
response += "\n\n---\n_[Interrupted - processing new message]_"
|
||||||
|
|
||||||
|
if response and not (self.agent and self.agent.stream_delta_callback):
|
||||||
|
_cprint(f"\n{response}")
|
||||||
|
|
||||||
if response:
|
|
||||||
w = shutil.get_terminal_size().columns
|
w = shutil.get_terminal_size().columns
|
||||||
label = " ⚕ Hermes "
|
_cprint(f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}")
|
||||||
fill = w - 2 - len(label) # 2 for ╭ and ╮
|
|
||||||
top = f"{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}"
|
|
||||||
bot = f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}"
|
|
||||||
|
|
||||||
# Render box + response as a single _cprint call so
|
|
||||||
# nothing can interleave between the box borders.
|
|
||||||
_cprint(f"\n{top}\n{response}\n\n{bot}")
|
|
||||||
|
|
||||||
# Play terminal bell when agent finishes (if enabled).
|
# Play terminal bell when agent finishes (if enabled).
|
||||||
# Works over SSH — the bell propagates to the user's terminal.
|
# Works over SSH — the bell propagates to the user's terminal.
|
||||||
@@ -3620,7 +3629,7 @@ class HermesCLI:
|
|||||||
return ""
|
return ""
|
||||||
if cli_ref._agent_running:
|
if cli_ref._agent_running:
|
||||||
return "type a message + Enter to interrupt, Ctrl+C to cancel"
|
return "type a message + Enter to interrupt, Ctrl+C to cancel"
|
||||||
return ""
|
return "Ask Hermes anything... (Alt+Enter for newline)"
|
||||||
|
|
||||||
input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder))
|
input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder))
|
||||||
|
|
||||||
|
|||||||
@@ -319,7 +319,7 @@ class SendResult:
|
|||||||
raw_response: Any = None
|
raw_response: Any = None
|
||||||
|
|
||||||
|
|
||||||
# Type for message handlers
|
# Handler may return str (sent by base) or dict(content=..., already_sent=True).
|
||||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||||
|
|
||||||
|
|
||||||
@@ -691,10 +691,19 @@ class BasePlatformAdapter(ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Call the handler (this can take a while with tool calls)
|
# Call the handler (this can take a while with tool calls)
|
||||||
response = await self._message_handler(event)
|
handler_result = await self._message_handler(event)
|
||||||
|
|
||||||
|
# Normalise: handler may return str or dict(content, already_sent)
|
||||||
|
already_sent = False
|
||||||
|
if isinstance(handler_result, dict):
|
||||||
|
response = handler_result.get("content") or ""
|
||||||
|
already_sent = handler_result.get("already_sent", False)
|
||||||
|
else:
|
||||||
|
response = handler_result
|
||||||
|
|
||||||
# Send response if any
|
# Send response if any
|
||||||
if not response:
|
if not response:
|
||||||
|
if not already_sent:
|
||||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||||
if response:
|
if response:
|
||||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||||
@@ -706,7 +715,7 @@ class BasePlatformAdapter(ABC):
|
|||||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||||
|
|
||||||
# Send the text portion first (if any remains after extractions)
|
# Send the text portion first (if any remains after extractions)
|
||||||
if text_content:
|
if text_content and not already_sent:
|
||||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||||
result = await self.send(
|
result = await self.send(
|
||||||
chat_id=event.source.chat_id,
|
chat_id=event.source.chat_id,
|
||||||
|
|||||||
@@ -1297,6 +1297,8 @@ class GatewayRunner:
|
|||||||
# Update session
|
# Update session
|
||||||
self.session_store.update_session(session_entry.session_key)
|
self.session_store.update_session(session_entry.session_key)
|
||||||
|
|
||||||
|
if agent_result.get("already_sent"):
|
||||||
|
return {"content": response, "already_sent": True}
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -2456,6 +2458,83 @@ class GatewayRunner:
|
|||||||
progress_queue = queue.Queue() if tool_progress_enabled else None
|
progress_queue = queue.Queue() if tool_progress_enabled else None
|
||||||
last_tool = [None] # Mutable container for tracking in closure
|
last_tool = [None] # Mutable container for tracking in closure
|
||||||
|
|
||||||
|
# Streaming token queue — same pattern as progress_queue but for
|
||||||
|
# assistant text deltas. An async drain task sends/edits a single
|
||||||
|
# platform message with the accumulated text.
|
||||||
|
stream_queue = queue.Queue()
|
||||||
|
stream_sent = [False] # set True once any delta was delivered
|
||||||
|
|
||||||
|
def _stream_delta(text: str):
|
||||||
|
stream_queue.put(text)
|
||||||
|
|
||||||
|
async def send_stream_messages():
|
||||||
|
"""Drain stream_queue, deliver via send/edit_message."""
|
||||||
|
_adapter = self.adapters.get(source.platform)
|
||||||
|
if not _adapter:
|
||||||
|
return
|
||||||
|
|
||||||
|
accumulated = []
|
||||||
|
msg_id = None
|
||||||
|
can_edit = True
|
||||||
|
last_edit_ts = 0.0
|
||||||
|
EDIT_INTERVAL = 0.6 # seconds between edits (rate-limit safe)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
delta = stream_queue.get_nowait()
|
||||||
|
accumulated.append(delta)
|
||||||
|
stream_sent[0] = True
|
||||||
|
|
||||||
|
now = asyncio.get_event_loop().time()
|
||||||
|
if now - last_edit_ts < EDIT_INTERVAL:
|
||||||
|
# Coalesce — will flush on next poll cycle
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_text = "".join(accumulated)
|
||||||
|
if msg_id is None:
|
||||||
|
res = await _adapter.send(
|
||||||
|
chat_id=source.chat_id, content=full_text)
|
||||||
|
if res.success and res.message_id:
|
||||||
|
msg_id = res.message_id
|
||||||
|
elif can_edit:
|
||||||
|
res = await _adapter.edit_message(
|
||||||
|
chat_id=source.chat_id,
|
||||||
|
message_id=msg_id,
|
||||||
|
content=full_text,
|
||||||
|
)
|
||||||
|
if not res.success:
|
||||||
|
can_edit = False
|
||||||
|
last_edit_ts = now
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Final flush
|
||||||
|
while not stream_queue.empty():
|
||||||
|
try:
|
||||||
|
accumulated.append(stream_queue.get_nowait())
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
if accumulated:
|
||||||
|
full_text = "".join(accumulated)
|
||||||
|
if msg_id is None:
|
||||||
|
await _adapter.send(
|
||||||
|
chat_id=source.chat_id, content=full_text)
|
||||||
|
elif can_edit:
|
||||||
|
try:
|
||||||
|
await _adapter.edit_message(
|
||||||
|
chat_id=source.chat_id,
|
||||||
|
message_id=msg_id,
|
||||||
|
content=full_text,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Stream message error: %s", e)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
def progress_callback(tool_name: str, preview: str = None, args: dict = None):
|
def progress_callback(tool_name: str, preview: str = None, args: dict = None):
|
||||||
"""Callback invoked by agent when a tool is called."""
|
"""Callback invoked by agent when a tool is called."""
|
||||||
if not progress_queue:
|
if not progress_queue:
|
||||||
@@ -2698,6 +2777,7 @@ class GatewayRunner:
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||||
|
stream_delta_callback=_stream_delta,
|
||||||
platform=platform_key,
|
platform=platform_key,
|
||||||
honcho_session_key=session_key,
|
honcho_session_key=session_key,
|
||||||
session_db=self._session_db,
|
session_db=self._session_db,
|
||||||
@@ -2820,6 +2900,7 @@ class GatewayRunner:
|
|||||||
"api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0,
|
"api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0,
|
||||||
"tools": tools_holder[0] or [],
|
"tools": tools_holder[0] or [],
|
||||||
"history_offset": len(agent_history),
|
"history_offset": len(agent_history),
|
||||||
|
"already_sent": stream_sent[0],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Start progress message sender if enabled
|
# Start progress message sender if enabled
|
||||||
@@ -2827,6 +2908,9 @@ class GatewayRunner:
|
|||||||
if tool_progress_enabled:
|
if tool_progress_enabled:
|
||||||
progress_task = asyncio.create_task(send_progress_messages())
|
progress_task = asyncio.create_task(send_progress_messages())
|
||||||
|
|
||||||
|
# Start stream message sender
|
||||||
|
stream_task = asyncio.create_task(send_stream_messages())
|
||||||
|
|
||||||
# Track this agent as running for this session (for interrupt support)
|
# Track this agent as running for this session (for interrupt support)
|
||||||
# We do this in a callback after the agent is created
|
# We do this in a callback after the agent is created
|
||||||
async def track_agent():
|
async def track_agent():
|
||||||
@@ -2901,9 +2985,10 @@ class GatewayRunner:
|
|||||||
session_key=session_key
|
session_key=session_key
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Stop progress sender and interrupt monitor
|
# Stop progress sender, stream sender, and interrupt monitor
|
||||||
if progress_task:
|
if progress_task:
|
||||||
progress_task.cancel()
|
progress_task.cancel()
|
||||||
|
stream_task.cancel()
|
||||||
interrupt_monitor.cancel()
|
interrupt_monitor.cancel()
|
||||||
|
|
||||||
# Clean up tracking
|
# Clean up tracking
|
||||||
@@ -2912,7 +2997,7 @@ class GatewayRunner:
|
|||||||
del self._running_agents[session_key]
|
del self._running_agents[session_key]
|
||||||
|
|
||||||
# Wait for cancelled tasks
|
# Wait for cancelled tasks
|
||||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
for task in [progress_task, stream_task, interrupt_monitor, tracking_task]:
|
||||||
if task:
|
if task:
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
|
|||||||
156
run_agent.py
156
run_agent.py
@@ -174,6 +174,7 @@ class AIAgent:
|
|||||||
tool_progress_callback: callable = None,
|
tool_progress_callback: callable = None,
|
||||||
clarify_callback: callable = None,
|
clarify_callback: callable = None,
|
||||||
step_callback: callable = None,
|
step_callback: callable = None,
|
||||||
|
stream_delta_callback: callable = None,
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
reasoning_config: Dict[str, Any] = None,
|
reasoning_config: Dict[str, Any] = None,
|
||||||
prefill_messages: List[Dict[str, Any]] = None,
|
prefill_messages: List[Dict[str, Any]] = None,
|
||||||
@@ -258,6 +259,7 @@ class AIAgent:
|
|||||||
self.tool_progress_callback = tool_progress_callback
|
self.tool_progress_callback = tool_progress_callback
|
||||||
self.clarify_callback = clarify_callback
|
self.clarify_callback = clarify_callback
|
||||||
self.step_callback = step_callback
|
self.step_callback = step_callback
|
||||||
|
self.stream_delta_callback = stream_delta_callback
|
||||||
self._last_reported_tool = None # Track for "new tool" mode
|
self._last_reported_tool = None # Track for "new tool" mode
|
||||||
|
|
||||||
# Interrupt mechanism for breaking out of tool loops
|
# Interrupt mechanism for breaking out of tool loops
|
||||||
@@ -2158,6 +2160,137 @@ class AIAgent:
|
|||||||
raise result["error"]
|
raise result["error"]
|
||||||
return result["response"]
|
return result["response"]
|
||||||
|
|
||||||
|
def _interruptible_streaming_api_call(self, api_kwargs: dict, on_first_delta=None):
|
||||||
|
"""Streaming variant of _interruptible_api_call for chat_completions.
|
||||||
|
|
||||||
|
Fires self.stream_delta_callback(text) as content tokens arrive and
|
||||||
|
accumulates the full response into a SimpleNamespace matching the shape
|
||||||
|
downstream code expects. Falls back to the non-streaming path when the
|
||||||
|
provider rejects the stream request.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
result = {"response": None, "error": None}
|
||||||
|
first_delta_fired = [False]
|
||||||
|
|
||||||
|
def _stream():
|
||||||
|
try:
|
||||||
|
stream_kwargs = {**api_kwargs, "stream": True,
|
||||||
|
"stream_options": {"include_usage": True}}
|
||||||
|
stream = self.client.chat.completions.create(**stream_kwargs)
|
||||||
|
|
||||||
|
content_parts = []
|
||||||
|
tool_calls_acc = {}
|
||||||
|
finish_reason = "stop"
|
||||||
|
usage = None
|
||||||
|
reasoning_content = None
|
||||||
|
model = None
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
if not chunk.choices:
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage:
|
||||||
|
usage = chunk.usage
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = chunk.choices[0]
|
||||||
|
if choice.finish_reason:
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
if model is None and hasattr(chunk, "model"):
|
||||||
|
model = chunk.model
|
||||||
|
|
||||||
|
delta = choice.delta
|
||||||
|
if delta is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if delta.content:
|
||||||
|
content_parts.append(delta.content)
|
||||||
|
if not first_delta_fired[0]:
|
||||||
|
first_delta_fired[0] = True
|
||||||
|
if on_first_delta:
|
||||||
|
on_first_delta()
|
||||||
|
if self.stream_delta_callback:
|
||||||
|
try:
|
||||||
|
self.stream_delta_callback(delta.content)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if delta.tool_calls:
|
||||||
|
for tc_delta in delta.tool_calls:
|
||||||
|
idx = tc_delta.index
|
||||||
|
if idx not in tool_calls_acc:
|
||||||
|
tool_calls_acc[idx] = {
|
||||||
|
"id": tc_delta.id or "",
|
||||||
|
"type": tc_delta.type or "function",
|
||||||
|
"function": {
|
||||||
|
"name": getattr(tc_delta.function, "name", None) or "",
|
||||||
|
"arguments": getattr(tc_delta.function, "arguments", None) or "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
entry = tool_calls_acc[idx]
|
||||||
|
if tc_delta.id:
|
||||||
|
entry["id"] = tc_delta.id
|
||||||
|
fn = tc_delta.function
|
||||||
|
if fn:
|
||||||
|
if fn.name:
|
||||||
|
entry["function"]["name"] = fn.name
|
||||||
|
if fn.arguments:
|
||||||
|
entry["function"]["arguments"] += fn.arguments
|
||||||
|
|
||||||
|
rc = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||||
|
if rc:
|
||||||
|
reasoning_content = (reasoning_content or "") + rc
|
||||||
|
|
||||||
|
tool_calls_list = None
|
||||||
|
if tool_calls_acc:
|
||||||
|
tool_calls_list = [
|
||||||
|
SimpleNamespace(
|
||||||
|
id=tc["id"], call_id=tc["id"], type=tc["type"],
|
||||||
|
function=SimpleNamespace(name=tc["function"]["name"],
|
||||||
|
arguments=tc["function"]["arguments"]),
|
||||||
|
)
|
||||||
|
for idx, tc in sorted(tool_calls_acc.items())
|
||||||
|
]
|
||||||
|
|
||||||
|
message = SimpleNamespace(
|
||||||
|
content="".join(content_parts) or None,
|
||||||
|
tool_calls=tool_calls_list,
|
||||||
|
reasoning=reasoning_content,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
reasoning_details=None,
|
||||||
|
)
|
||||||
|
result["response"] = SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
|
||||||
|
usage=usage,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = e
|
||||||
|
|
||||||
|
t = threading.Thread(target=_stream, daemon=True)
|
||||||
|
t.start()
|
||||||
|
while t.is_alive():
|
||||||
|
t.join(timeout=0.3)
|
||||||
|
if self._interrupt_requested:
|
||||||
|
try:
|
||||||
|
self.client.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.client = OpenAI(**self._client_kwargs)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise InterruptedError("Agent interrupted during streaming API call")
|
||||||
|
|
||||||
|
if result["error"] is not None:
|
||||||
|
err = result["error"]
|
||||||
|
err_str = str(err).lower()
|
||||||
|
if any(kw in err_str for kw in ("stream", "not support", "unsupported")):
|
||||||
|
logger.debug("Streaming failed (%s), falling back to non-streaming.", err)
|
||||||
|
return self._interruptible_api_call(api_kwargs)
|
||||||
|
raise err
|
||||||
|
return result["response"]
|
||||||
|
|
||||||
# ── Provider fallback ──────────────────────────────────────────────────
|
# ── Provider fallback ──────────────────────────────────────────────────
|
||||||
|
|
||||||
# API-key providers: provider → (base_url, [env_var_names])
|
# API-key providers: provider → (base_url, [env_var_names])
|
||||||
@@ -3353,12 +3486,27 @@ class AIAgent:
|
|||||||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||||
|
|
||||||
|
if self.stream_delta_callback and self.api_mode != "codex_responses":
|
||||||
|
def _stop_spinner():
|
||||||
|
nonlocal thinking_spinner
|
||||||
|
if thinking_spinner:
|
||||||
|
thinking_spinner.stop("")
|
||||||
|
thinking_spinner = None
|
||||||
|
|
||||||
|
response = self._interruptible_streaming_api_call(
|
||||||
|
api_kwargs, on_first_delta=_stop_spinner)
|
||||||
|
|
||||||
|
# Separate streamed content from tool status lines
|
||||||
|
msg = getattr(response, "choices", [None])[0]
|
||||||
|
if msg and getattr(msg, "message", None):
|
||||||
|
m = msg.message
|
||||||
|
if m.content and m.tool_calls:
|
||||||
|
print(flush=True)
|
||||||
|
else:
|
||||||
response = self._interruptible_api_call(api_kwargs)
|
response = self._interruptible_api_call(api_kwargs)
|
||||||
|
|
||||||
api_duration = time.time() - api_start_time
|
api_duration = time.time() - api_start_time
|
||||||
|
|
||||||
# Stop thinking spinner silently -- the response box or tool
|
|
||||||
# execution messages that follow are more informative.
|
|
||||||
if thinking_spinner:
|
if thinking_spinner:
|
||||||
thinking_spinner.stop("")
|
thinking_spinner.stop("")
|
||||||
thinking_spinner = None
|
thinking_spinner = None
|
||||||
@@ -4055,8 +4203,8 @@ class AIAgent:
|
|||||||
turn_content = assistant_message.content or ""
|
turn_content = assistant_message.content or ""
|
||||||
if turn_content and self._has_content_after_think_block(turn_content):
|
if turn_content and self._has_content_after_think_block(turn_content):
|
||||||
self._last_content_with_tools = turn_content
|
self._last_content_with_tools = turn_content
|
||||||
# Show intermediate commentary so the user can follow along
|
# Show intermediate commentary — skip when streaming (already in buffer)
|
||||||
if self.quiet_mode:
|
if self.quiet_mode and not self.stream_delta_callback:
|
||||||
clean = self._strip_think_blocks(turn_content).strip()
|
clean = self._strip_think_blocks(turn_content).strip()
|
||||||
if clean:
|
if clean:
|
||||||
print(f" ┊ 💬 {clean}")
|
print(f" ┊ 💬 {clean}")
|
||||||
|
|||||||
256
tests/test_streaming.py
Normal file
256
tests/test_streaming.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Tests for streaming token output — accumulator shape, callback order, fallback."""
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch, call
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from run_agent import AIAgent
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_tool_defs(*names):
|
||||||
|
return [
|
||||||
|
{"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}}
|
||||||
|
for n in names
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def agent():
|
||||||
|
with (
|
||||||
|
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||||
|
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||||
|
patch("run_agent.OpenAI"),
|
||||||
|
):
|
||||||
|
cb = MagicMock()
|
||||||
|
a = AIAgent(
|
||||||
|
api_key="test-key-1234567890",
|
||||||
|
quiet_mode=True,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
stream_delta_callback=cb,
|
||||||
|
)
|
||||||
|
a.client = MagicMock()
|
||||||
|
a._stream_cb = cb
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers — fake streaming chunks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None):
|
||||||
|
delta = SimpleNamespace(content=content, tool_calls=tool_call_delta)
|
||||||
|
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||||
|
c = SimpleNamespace(choices=[choice])
|
||||||
|
if usage is not None:
|
||||||
|
c.usage = SimpleNamespace(**usage)
|
||||||
|
if model:
|
||||||
|
c.model = model
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def _usage_chunk(**kw):
|
||||||
|
c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw))
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
def _tc_delta(index, id=None, name=None, arguments=None, type=None):
|
||||||
|
fn = SimpleNamespace(name=name, arguments=arguments)
|
||||||
|
return SimpleNamespace(index=index, id=id, type=type, function=fn)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: accumulator shape
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingAccumulator:
|
||||||
|
def test_text_only_response(self, agent):
|
||||||
|
"""Streaming text-only response produces correct synthetic shape."""
|
||||||
|
chunks = [
|
||||||
|
_chunk(content="Hello", model="test/m"),
|
||||||
|
_chunk(content=" world"),
|
||||||
|
_chunk(finish_reason="stop"),
|
||||||
|
_usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||||
|
]
|
||||||
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
|
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||||
|
|
||||||
|
assert resp.choices[0].message.content == "Hello world"
|
||||||
|
assert resp.choices[0].message.tool_calls is None
|
||||||
|
assert resp.choices[0].finish_reason == "stop"
|
||||||
|
assert resp.usage.prompt_tokens == 10
|
||||||
|
assert resp.model == "test/m"
|
||||||
|
|
||||||
|
def test_tool_call_response(self, agent):
|
||||||
|
"""Streaming tool-call response accumulates function name + arguments."""
|
||||||
|
chunks = [
|
||||||
|
_chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]),
|
||||||
|
_chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]),
|
||||||
|
_chunk(finish_reason="tool_calls"),
|
||||||
|
]
|
||||||
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
|
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||||
|
|
||||||
|
tc = resp.choices[0].message.tool_calls
|
||||||
|
assert tc is not None
|
||||||
|
assert len(tc) == 1
|
||||||
|
assert tc[0].id == "call_1"
|
||||||
|
assert tc[0].function.name == "web_search"
|
||||||
|
assert tc[0].function.arguments == '{"query": "hi"}'
|
||||||
|
assert resp.choices[0].finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
def test_mixed_content_and_tool_calls(self, agent):
|
||||||
|
"""Content + tool calls in same stream are both accumulated."""
|
||||||
|
chunks = [
|
||||||
|
_chunk(content="Let me check."),
|
||||||
|
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]),
|
||||||
|
_chunk(finish_reason="tool_calls"),
|
||||||
|
]
|
||||||
|
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||||
|
|
||||||
|
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||||
|
|
||||||
|
assert resp.choices[0].message.content == "Let me check."
|
||||||
|
assert len(resp.choices[0].message.tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamingCallbacks:
    """Delta callbacks fire in order, once, and only for content chunks."""

    def test_deltas_fire_in_order(self, agent):
        """stream_delta_callback receives content deltas in order."""
        seen = []
        agent.stream_delta_callback = seen.append
        agent.client.chat.completions.create.return_value = iter(
            [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")]
        )

        agent._interruptible_streaming_api_call({"model": "test"})

        assert seen == ["a", "b", "c"]

    def test_on_first_delta_fires_once(self, agent):
        """on_first_delta is invoked exactly once even when several deltas arrive."""
        first = MagicMock()
        agent.client.chat.completions.create.return_value = iter(
            [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")]
        )

        agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first)

        first.assert_called_once()

    def test_tool_only_does_not_fire_callback(self, agent):
        """Tool-call-only stream does not invoke stream_delta_callback."""
        seen = []
        agent.stream_delta_callback = seen.append
        agent.client.chat.completions.create.return_value = iter([
            _chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]),
            _chunk(finish_reason="tool_calls"),
        ])

        agent._interruptible_streaming_api_call({"model": "test"})

        assert seen == []
||||||
|
class TestStreamingFallback:
    """Streaming errors either fall back to non-streaming or propagate."""

    def test_stream_error_falls_back(self, agent):
        """When streaming fails with 'not support', falls back to non-streaming."""
        non_streaming_response = SimpleNamespace(
            choices=[SimpleNamespace(
                message=SimpleNamespace(
                    content="ok",
                    tool_calls=None,
                    reasoning=None,
                    reasoning_content=None,
                    reasoning_details=None,
                ),
                finish_reason="stop",
            )],
            usage=None,
            model="test/m",
        )
        # First call (stream) raises; second call (fallback) succeeds.
        agent.client.chat.completions.create.side_effect = [
            Exception("streaming not supported by this provider"),
            non_streaming_response,
        ]

        resp = agent._interruptible_streaming_api_call({"model": "test"})

        assert resp.choices[0].message.content == "ok"
        assert agent.client.chat.completions.create.call_count == 2

    def test_non_stream_error_raises(self, agent):
        """Non-stream-related errors propagate normally."""
        agent.client.chat.completions.create.side_effect = ValueError("bad request")

        with pytest.raises(ValueError, match="bad request"):
            agent._interruptible_streaming_api_call({"model": "test"})
||||||
|
# ---------------------------------------------------------------------------
# Tests: base.py already_sent contract
# ---------------------------------------------------------------------------
|
class TestAlreadySentContract:
    """Handler results marked already_sent must suppress the base send() path."""

    def _make_adapter(self, send_side_effect=None):
        """Build a minimal concrete adapter whose send() forwards content to a hook."""
        from gateway.platforms.base import BasePlatformAdapter, SendResult
        from gateway.config import Platform, PlatformConfig

        class FakeAdapter(BasePlatformAdapter):
            async def connect(self):
                return True

            async def disconnect(self):
                pass

            async def get_chat_info(self, chat_id):
                return {"name": "test"}

            async def send(self, chat_id, content, reply_to=None, metadata=None):
                if send_side_effect is not None:
                    send_side_effect(content)
                return SendResult(success=True, message_id="1")

        adapter = FakeAdapter(PlatformConfig(enabled=True), Platform.TELEGRAM)
        adapter._running = True
        return adapter

    @pytest.mark.asyncio
    async def test_already_sent_skips_send(self):
        """Handler returning already_sent=True prevents base from calling send()."""
        from gateway.platforms.base import MessageEvent
        from gateway.config import Platform
        from gateway.session import SessionSource

        sent = []
        adapter = self._make_adapter(send_side_effect=sent.append)

        async def handler(event):
            return {"content": "hello", "already_sent": True}

        adapter.set_message_handler(handler)

        event = MessageEvent(
            text="hi",
            source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
        )
        await adapter._process_message_background(event, "s1")

        assert sent == [], "send() should not be called when already_sent=True"

    @pytest.mark.asyncio
    async def test_string_response_sends_normally(self):
        """Handler returning a plain string triggers send() as before."""
        from gateway.platforms.base import MessageEvent
        from gateway.config import Platform
        from gateway.session import SessionSource

        sent = []
        adapter = self._make_adapter(send_side_effect=sent.append)

        async def handler(event):
            return "hello"

        adapter.set_message_handler(handler)

        event = MessageEvent(
            text="hi",
            source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
        )
        await adapter._process_message_background(event, "s1")

        assert "hello" in sent
|
||||||
Reference in New Issue
Block a user