mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 15:01:34 +08:00
Compare commits
1 Commits
codex-port
...
feat/strea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72decda522 |
29
cli.py
29
cli.py
@@ -1253,6 +1253,7 @@ class HermesCLI:
|
||||
# Background task tracking: {task_id: threading.Thread}
|
||||
self._background_tasks: Dict[str, threading.Thread] = {}
|
||||
self._background_task_counter = 0
|
||||
self._stream_buf = ""
|
||||
|
||||
def _invalidate(self, min_interval: float = 0.25) -> None:
|
||||
"""Throttled UI repaint — prevents terminal blinking on slow/SSH connections."""
|
||||
@@ -1495,6 +1496,7 @@ class HermesCLI:
|
||||
platform="cli",
|
||||
session_db=self._session_db,
|
||||
clarify_callback=self._clarify_callback,
|
||||
stream_delta_callback=self._stream_delta,
|
||||
honcho_session_key=self.session_id,
|
||||
fallback_model=self._fallback_model,
|
||||
thinking_callback=self._on_thinking,
|
||||
@@ -3339,6 +3341,28 @@ class HermesCLI:
|
||||
"Use your best judgement to make the choice and proceed."
|
||||
)
|
||||
|
||||
_stream_started = False
|
||||
|
||||
def _stream_delta(self, text: str):
|
||||
"""Buffer streaming tokens; emit complete lines via _cprint."""
|
||||
if not text:
|
||||
return
|
||||
if not self._stream_started:
|
||||
text = text.lstrip("\n")
|
||||
if not text:
|
||||
return
|
||||
self._stream_started = True
|
||||
self._stream_buf += text
|
||||
while "\n" in self._stream_buf:
|
||||
line, self._stream_buf = self._stream_buf.split("\n", 1)
|
||||
_cprint(line)
|
||||
|
||||
def _flush_stream(self):
|
||||
"""Emit any remaining partial line from the stream buffer."""
|
||||
if self._stream_buf:
|
||||
_cprint(self._stream_buf)
|
||||
self._stream_buf = ""
|
||||
|
||||
def _sudo_password_callback(self) -> str:
|
||||
"""
|
||||
Prompt for sudo password through the prompt_toolkit UI.
|
||||
@@ -3467,6 +3491,8 @@ class HermesCLI:
|
||||
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
self._stream_buf = ""
|
||||
self._stream_started = False
|
||||
|
||||
_cprint(f"{_GOLD}{'─' * 40}{_RST}")
|
||||
print(flush=True)
|
||||
@@ -3514,6 +3540,7 @@ class HermesCLI:
|
||||
agent_thread.join(0.1)
|
||||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
self._flush_stream()
|
||||
|
||||
# Drain any remaining agent output still in the StdoutProxy
|
||||
# buffer so tool/status lines render ABOVE our response box.
|
||||
@@ -3542,7 +3569,7 @@ class HermesCLI:
|
||||
if response and pending_message:
|
||||
response = response + "\n\n---\n_[Interrupted - processing new message]_"
|
||||
|
||||
if response:
|
||||
if response and not (self.agent and self.agent.stream_delta_callback):
|
||||
# Use a Rich Panel for the response box — adapts to terminal
|
||||
# width at render time instead of hard-coding border length.
|
||||
try:
|
||||
|
||||
@@ -413,6 +413,36 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
"""Whether this platform supports response streaming via message edits."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_draft_streaming(self) -> bool:
|
||||
"""Whether this platform supports native draft streaming (Bot API 9.3+)."""
|
||||
return False
|
||||
|
||||
async def send_draft(self, chat_id: str, draft_id: int, text: str, metadata: dict = None) -> bool:
|
||||
"""Push a draft text update. Override in subclasses."""
|
||||
return False
|
||||
|
||||
async def finalize_draft(self, chat_id: str, content: str, metadata: dict = None) -> "SendResult":
|
||||
"""Finalize a draft stream with the completed message."""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
async def delete_message(self, chat_id: str, message_id: str) -> SendResult:
|
||||
"""Delete a previously sent message."""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
async def send_raw(self, chat_id: str, content: str, metadata: dict = None) -> "SendResult":
|
||||
"""Send without formatting (default: delegates to send)."""
|
||||
return await self.send(chat_id=chat_id, content=content, metadata=metadata)
|
||||
|
||||
async def edit_message_raw(self, chat_id: str, message_id: str, content: str) -> "SendResult":
|
||||
"""Edit without formatting (default: delegates to edit_message)."""
|
||||
return await self.edit_message(chat_id=chat_id, message_id=message_id, content=content)
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
@@ -697,11 +727,20 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
handler_result = await self._message_handler(event)
|
||||
|
||||
# Normalise: handler may return str or dict(content, already_sent)
|
||||
already_sent = False
|
||||
if isinstance(handler_result, dict):
|
||||
response = handler_result.get("content") or ""
|
||||
already_sent = handler_result.get("already_sent", False)
|
||||
else:
|
||||
response = handler_result
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if not already_sent:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
@@ -712,7 +751,7 @@ class BasePlatformAdapter(ABC):
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
if text_content and not already_sent:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
|
||||
@@ -299,6 +299,99 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_raw(
|
||||
self, chat_id: str, content: str, metadata: dict = None,
|
||||
) -> SendResult:
|
||||
"""Send a plain-text message without MarkdownV2 formatting."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id), text=content, parse_mode=None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def edit_message_raw(
|
||||
self, chat_id: str, message_id: str, content: str,
|
||||
) -> SendResult:
|
||||
"""Edit a message with plain text (no MarkdownV2 formatting)."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
await self._bot.edit_message_text(
|
||||
chat_id=int(chat_id), message_id=int(message_id),
|
||||
text=content, parse_mode=None,
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_draft_streaming(self) -> bool:
|
||||
"""Whether this adapter supports Telegram Bot API sendMessageDraft (9.3+)."""
|
||||
return True
|
||||
|
||||
async def send_draft(
|
||||
self, chat_id: str, draft_id: int, text: str, metadata: dict = None,
|
||||
) -> bool:
|
||||
"""Push a draft update via sendMessageDraft (Bot API 9.3+)."""
|
||||
if not self._bot:
|
||||
return False
|
||||
try:
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
return await self._bot.send_message_draft(
|
||||
chat_id=int(chat_id), draft_id=draft_id, text=text,
|
||||
parse_mode=None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] send_message_draft failed: %s", self.name, e)
|
||||
return False
|
||||
|
||||
async def finalize_draft(
|
||||
self, chat_id: str, content: str, metadata: dict = None,
|
||||
) -> SendResult:
|
||||
"""Finalize a draft stream by sending the completed message with formatting."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
formatted = self.format_message(content)
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id), text=formatted,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id), text=content, parse_mode=None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def delete_message(self, chat_id: str, message_id: str) -> SendResult:
|
||||
"""Delete a Telegram message."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
await self._bot.delete_message(
|
||||
chat_id=int(chat_id), message_id=int(message_id),
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
159
run_agent.py
159
run_agent.py
@@ -175,6 +175,7 @@ class AIAgent:
|
||||
thinking_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
@@ -262,6 +263,7 @@ class AIAgent:
|
||||
self.thinking_callback = thinking_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
@@ -2060,6 +2062,147 @@ class AIAgent:
|
||||
return terminal_response
|
||||
raise RuntimeError("Responses create(stream=True) fallback did not emit a terminal response.")
|
||||
|
||||
def _interruptible_streaming_api_call(self, api_kwargs: dict, on_first_delta=None):
|
||||
"""Streaming variant of _interruptible_api_call for chat_completions.
|
||||
|
||||
Fires self.stream_delta_callback(text) as content tokens arrive and
|
||||
accumulates the full response into a SimpleNamespace matching the shape
|
||||
downstream code expects. Falls back to the non-streaming path when the
|
||||
provider rejects the stream request.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
first_delta_fired = [False]
|
||||
|
||||
def _stream():
|
||||
try:
|
||||
stream_kwargs = {**api_kwargs, "stream": True,
|
||||
"stream_options": {"include_usage": True}}
|
||||
stream_resp = self.client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts = []
|
||||
tool_calls_acc = {}
|
||||
finish_reason = "stop"
|
||||
usage = None
|
||||
reasoning_content = None
|
||||
model = None
|
||||
has_tool_calls = False
|
||||
|
||||
try:
|
||||
for chunk in stream_resp:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = chunk.usage
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
if model is None and hasattr(chunk, "model"):
|
||||
model = chunk.model
|
||||
|
||||
delta = choice.delta
|
||||
if delta is None:
|
||||
continue
|
||||
|
||||
if delta.content:
|
||||
content_parts.append(delta.content)
|
||||
if not first_delta_fired[0]:
|
||||
first_delta_fired[0] = True
|
||||
if on_first_delta:
|
||||
on_first_delta()
|
||||
if self.stream_delta_callback and not has_tool_calls:
|
||||
try:
|
||||
self.stream_delta_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if delta.tool_calls:
|
||||
has_tool_calls = True
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": tc_delta.type or "function",
|
||||
"function": {
|
||||
"name": getattr(tc_delta.function, "name", None) or "",
|
||||
"arguments": getattr(tc_delta.function, "arguments", None) or "",
|
||||
},
|
||||
}
|
||||
else:
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
fn = tc_delta.function
|
||||
if fn:
|
||||
if fn.name:
|
||||
entry["function"]["name"] = fn.name
|
||||
if fn.arguments:
|
||||
entry["function"]["arguments"] += fn.arguments
|
||||
|
||||
rc = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||
if rc:
|
||||
reasoning_content = (reasoning_content or "") + rc
|
||||
finally:
|
||||
close_fn = getattr(stream_resp, "close", None)
|
||||
if callable(close_fn):
|
||||
try:
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tool_calls_list = None
|
||||
if tool_calls_acc:
|
||||
tool_calls_list = [
|
||||
SimpleNamespace(
|
||||
id=tc["id"], call_id=tc["id"], type=tc["type"],
|
||||
function=SimpleNamespace(name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"]),
|
||||
)
|
||||
for idx, tc in sorted(tool_calls_acc.items())
|
||||
]
|
||||
|
||||
message = SimpleNamespace(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls_list,
|
||||
reasoning=reasoning_content,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning_details=None,
|
||||
)
|
||||
result["response"] = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
|
||||
usage=usage,
|
||||
model=model,
|
||||
)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
|
||||
t = threading.Thread(target=_stream, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
|
||||
if result["error"] is not None:
|
||||
err = result["error"]
|
||||
err_str = str(err).lower()
|
||||
if any(kw in err_str for kw in ("stream", "not support", "unsupported")):
|
||||
logger.debug("Streaming failed (%s), falling back to non-streaming.", err)
|
||||
return self._interruptible_api_call(api_kwargs)
|
||||
raise err
|
||||
return result["response"]
|
||||
|
||||
def _try_refresh_codex_client_credentials(self, *, force: bool = True) -> bool:
|
||||
if self.api_mode != "codex_responses" or self.provider != "openai-codex":
|
||||
return False
|
||||
@@ -3474,7 +3617,17 @@ class AIAgent:
|
||||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
if self.stream_delta_callback and self.api_mode != "codex_responses":
|
||||
def _stop_spinner():
|
||||
nonlocal thinking_spinner
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
|
||||
response = self._interruptible_streaming_api_call(
|
||||
api_kwargs, on_first_delta=_stop_spinner)
|
||||
else:
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
@@ -4230,8 +4383,8 @@ class AIAgent:
|
||||
turn_content = assistant_message.content or ""
|
||||
if turn_content and self._has_content_after_think_block(turn_content):
|
||||
self._last_content_with_tools = turn_content
|
||||
# Show intermediate commentary so the user can follow along
|
||||
if self.quiet_mode:
|
||||
# Show intermediate commentary — skip when streaming (already in buffer)
|
||||
if self.quiet_mode and not self.stream_delta_callback:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
print(f" ┊ 💬 {clean}")
|
||||
|
||||
257
tests/test_streaming.py
Normal file
257
tests/test_streaming.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Tests for streaming token output — accumulator shape, callback order, fallback."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
return [
|
||||
{"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent():
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
cb = MagicMock()
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=cb,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
a._stream_cb = cb
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — fake streaming chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None):
|
||||
delta = SimpleNamespace(content=content, tool_calls=tool_call_delta,
|
||||
reasoning_content=None, reasoning=None)
|
||||
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||
c = SimpleNamespace(choices=[choice])
|
||||
if usage is not None:
|
||||
c.usage = SimpleNamespace(**usage)
|
||||
if model:
|
||||
c.model = model
|
||||
return c
|
||||
|
||||
|
||||
def _usage_chunk(**kw):
|
||||
c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw))
|
||||
return c
|
||||
|
||||
|
||||
def _tc_delta(index, id=None, name=None, arguments=None, type=None):
|
||||
fn = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=id, type=type, function=fn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: accumulator shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingAccumulator:
|
||||
def test_text_only_response(self, agent):
|
||||
"""Streaming text-only response produces correct synthetic shape."""
|
||||
chunks = [
|
||||
_chunk(content="Hello", model="test/m"),
|
||||
_chunk(content=" world"),
|
||||
_chunk(finish_reason="stop"),
|
||||
_usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello world"
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.usage.prompt_tokens == 10
|
||||
assert resp.model == "test/m"
|
||||
|
||||
def test_tool_call_response(self, agent):
|
||||
"""Streaming tool-call response accumulates function name + arguments."""
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert tc is not None
|
||||
assert len(tc) == 1
|
||||
assert tc[0].id == "call_1"
|
||||
assert tc[0].function.name == "web_search"
|
||||
assert tc[0].function.arguments == '{"query": "hi"}'
|
||||
assert resp.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
def test_mixed_content_and_tool_calls(self, agent):
|
||||
"""Content + tool calls in same stream are both accumulated."""
|
||||
chunks = [
|
||||
_chunk(content="Let me check."),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Let me check."
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
||||
|
||||
class TestStreamingCallbacks:
|
||||
def test_deltas_fire_in_order(self, agent):
|
||||
"""stream_delta_callback receives content deltas in order."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == ["a", "b", "c"]
|
||||
|
||||
def test_on_first_delta_fires_once(self, agent):
|
||||
first = MagicMock()
|
||||
chunks = [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first)
|
||||
|
||||
first.assert_called_once()
|
||||
|
||||
def test_tool_only_does_not_fire_callback(self, agent):
|
||||
"""Tool-call-only stream does not invoke stream_delta_callback."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == []
|
||||
|
||||
|
||||
class TestStreamingFallback:
|
||||
def test_stream_error_falls_back(self, agent):
|
||||
"""When streaming fails with 'not support', falls back to non-streaming."""
|
||||
agent.client.chat.completions.create.side_effect = [
|
||||
Exception("streaming not supported by this provider"),
|
||||
SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(content="ok", tool_calls=None, reasoning=None, reasoning_content=None, reasoning_details=None),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
model="test/m",
|
||||
),
|
||||
]
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "ok"
|
||||
assert agent.client.chat.completions.create.call_count == 2
|
||||
|
||||
def test_non_stream_error_raises(self, agent):
|
||||
"""Non-stream-related errors propagate normally."""
|
||||
agent.client.chat.completions.create.side_effect = ValueError("bad request")
|
||||
|
||||
with pytest.raises(ValueError, match="bad request"):
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: base.py already_sent contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAlreadySentContract:
|
||||
def _make_adapter(self, send_side_effect=None):
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
class FakeAdapter(BasePlatformAdapter):
|
||||
async def connect(self): return True
|
||||
async def disconnect(self): pass
|
||||
async def get_chat_info(self, chat_id): return {"name": "test"}
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
if send_side_effect is not None:
|
||||
send_side_effect(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
adapter = FakeAdapter(cfg, Platform.TELEGRAM)
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_sent_skips_send(self):
|
||||
"""Handler returning already_sent=True prevents base from calling send()."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return {"content": "hello", "already_sent": True}
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert sent == [], "send() should not be called when already_sent=True"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_response_sends_normally(self):
|
||||
"""Handler returning a plain string triggers send() as before."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return "hello"
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert "hello" in sent
|
||||
Reference in New Issue
Block a user