mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 23:11:37 +08:00
Compare commits
1 Commits
feat/langf
...
fix/stop-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2579245ae2 |
@@ -1103,7 +1103,7 @@ class GatewayRunner:
|
|||||||
if override_runtime.get("api_key"):
|
if override_runtime.get("api_key"):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Session model override (fast): session=%s config_model=%s -> override_model=%s provider=%s",
|
"Session model override (fast): session=%s config_model=%s -> override_model=%s provider=%s",
|
||||||
(resolved_session_key or "")[:30], model, override_model,
|
resolved_session_key or "", model, override_model,
|
||||||
override_runtime.get("provider"),
|
override_runtime.get("provider"),
|
||||||
)
|
)
|
||||||
return override_model, override_runtime
|
return override_model, override_runtime
|
||||||
@@ -1111,12 +1111,12 @@ class GatewayRunner:
|
|||||||
# resolution and apply model/provider from the override on top.
|
# resolution and apply model/provider from the override on top.
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Session model override (no api_key, fallback): session=%s config_model=%s override_model=%s",
|
"Session model override (no api_key, fallback): session=%s config_model=%s override_model=%s",
|
||||||
(resolved_session_key or "")[:30], model, override_model,
|
resolved_session_key or "", model, override_model,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"No session model override: session=%s config_model=%s override_keys=%s",
|
"No session model override: session=%s config_model=%s override_keys=%s",
|
||||||
(resolved_session_key or "")[:30], model,
|
resolved_session_key or "", model,
|
||||||
list(self._session_model_overrides.keys())[:5] if self._session_model_overrides else "[]",
|
list(self._session_model_overrides.keys())[:5] if self._session_model_overrides else "[]",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1687,7 +1687,7 @@ class GatewayRunner:
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
agent.interrupt(reason)
|
agent.interrupt(reason)
|
||||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
logger.debug("Interrupted running agent for session %s during shutdown", session_key)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||||
|
|
||||||
@@ -1859,7 +1859,7 @@ class GatewayRunner:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Auto-suspended stuck session %s (active across %d "
|
"Auto-suspended stuck session %s (active across %d "
|
||||||
"consecutive restarts — likely a stuck loop)",
|
"consecutive restarts — likely a stuck loop)",
|
||||||
session_key[:30], counts[session_key],
|
session_key, counts[session_key],
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -2681,7 +2681,7 @@ class GatewayRunner:
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"mark_resume_pending failed for %s: %s",
|
"mark_resume_pending failed for %s: %s",
|
||||||
_sk[:20], _e,
|
_sk, _e,
|
||||||
)
|
)
|
||||||
self._interrupt_running_agents(
|
self._interrupt_running_agents(
|
||||||
_INTERRUPT_REASON_GATEWAY_RESTART if self._restart_requested else _INTERRUPT_REASON_GATEWAY_SHUTDOWN
|
_INTERRUPT_REASON_GATEWAY_RESTART if self._restart_requested else _INTERRUPT_REASON_GATEWAY_SHUTDOWN
|
||||||
@@ -3347,7 +3347,7 @@ class GatewayRunner:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Evicting stale _running_agents entry for %s "
|
"Evicting stale _running_agents entry for %s "
|
||||||
"(age: %.0fs, idle: %.0fs, timeout: %.0fs)%s",
|
"(age: %.0fs, idle: %.0fs, timeout: %.0fs)%s",
|
||||||
_quick_key[:30], _stale_age, _stale_idle,
|
_quick_key, _stale_age, _stale_idle,
|
||||||
_raw_stale_timeout, _stale_detail,
|
_raw_stale_timeout, _stale_detail,
|
||||||
)
|
)
|
||||||
self._invalidate_session_run_generation(
|
self._invalidate_session_run_generation(
|
||||||
@@ -3383,7 +3383,7 @@ class GatewayRunner:
|
|||||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||||
invalidation_reason="stop_command",
|
invalidation_reason="stop_command",
|
||||||
)
|
)
|
||||||
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
|
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key)
|
||||||
return "⚡ Stopped. You can continue this session."
|
return "⚡ Stopped. You can continue this session."
|
||||||
|
|
||||||
# /reset and /new must bypass the running-agent guard so they
|
# /reset and /new must bypass the running-agent guard so they
|
||||||
@@ -3449,7 +3449,7 @@ class GatewayRunner:
|
|||||||
try:
|
try:
|
||||||
accepted = running_agent.steer(steer_text)
|
accepted = running_agent.steer(steer_text)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Steer failed for session %s: %s", _quick_key[:20], exc)
|
logger.warning("Steer failed for session %s: %s", _quick_key, exc)
|
||||||
return f"⚠️ Steer failed: {exc}"
|
return f"⚠️ Steer failed: {exc}"
|
||||||
if accepted:
|
if accepted:
|
||||||
preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "")
|
preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "")
|
||||||
@@ -3532,7 +3532,7 @@ class GatewayRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if event.message_type == MessageType.PHOTO:
|
if event.message_type == MessageType.PHOTO:
|
||||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key)
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
if adapter:
|
if adapter:
|
||||||
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
merge_pending_message_event(adapter._pending_messages, _quick_key, event)
|
||||||
@@ -3552,7 +3552,7 @@ class GatewayRunner:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"Telegram follow-up arrived %.2fs after run start for %s — queueing without interrupt",
|
"Telegram follow-up arrived %.2fs after run start for %s — queueing without interrupt",
|
||||||
time.time() - _started_at,
|
time.time() - _started_at,
|
||||||
_quick_key[:20],
|
_quick_key,
|
||||||
)
|
)
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
if adapter:
|
if adapter:
|
||||||
@@ -3570,7 +3570,7 @@ class GatewayRunner:
|
|||||||
if event.get_command() == "stop":
|
if event.get_command() == "stop":
|
||||||
# Force-clean the sentinel so the session is unlocked.
|
# Force-clean the sentinel so the session is unlocked.
|
||||||
self._release_running_agent_state(_quick_key)
|
self._release_running_agent_state(_quick_key)
|
||||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key[:20])
|
logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key)
|
||||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||||
# Queue the message so it will be picked up after the
|
# Queue the message so it will be picked up after the
|
||||||
# agent starts.
|
# agent starts.
|
||||||
@@ -3592,10 +3592,10 @@ class GatewayRunner:
|
|||||||
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
else f"⏳ Gateway is {self._status_action_gerund()} and is not accepting another turn right now."
|
||||||
)
|
)
|
||||||
if self._busy_input_mode == "queue":
|
if self._busy_input_mode == "queue":
|
||||||
logger.debug("PRIORITY queue follow-up for session %s", _quick_key[:20])
|
logger.debug("PRIORITY queue follow-up for session %s", _quick_key)
|
||||||
self._queue_or_replace_pending_event(_quick_key, event)
|
self._queue_or_replace_pending_event(_quick_key, event)
|
||||||
return None
|
return None
|
||||||
logger.debug("PRIORITY interrupt for session %s", _quick_key[:20])
|
logger.debug("PRIORITY interrupt for session %s", _quick_key)
|
||||||
running_agent.interrupt(event.text)
|
running_agent.interrupt(event.text)
|
||||||
if _quick_key in self._pending_messages:
|
if _quick_key in self._pending_messages:
|
||||||
self._pending_messages[_quick_key] += "\n" + event.text
|
self._pending_messages[_quick_key] += "\n" + event.text
|
||||||
@@ -4593,7 +4593,7 @@ class GatewayRunner:
|
|||||||
if not self._is_session_run_current(_quick_key, run_generation):
|
if not self._is_session_run_current(_quick_key, run_generation):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Discarding stale agent result for %s — generation %d is no longer current",
|
"Discarding stale agent result for %s — generation %d is no longer current",
|
||||||
_quick_key[:20] if _quick_key else "?",
|
_quick_key or "?",
|
||||||
run_generation,
|
run_generation,
|
||||||
)
|
)
|
||||||
_stale_adapter = self.adapters.get(source.platform)
|
_stale_adapter = self.adapters.get(source.platform)
|
||||||
@@ -4644,7 +4644,7 @@ class GatewayRunner:
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"clear_resume_pending failed for %s: %s",
|
"clear_resume_pending failed for %s: %s",
|
||||||
session_key[:20], _e,
|
session_key, _e,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Surface error details when the agent failed silently (final_response=None)
|
# Surface error details when the agent failed silently (final_response=None)
|
||||||
@@ -5291,7 +5291,7 @@ class GatewayRunner:
|
|||||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||||
invalidation_reason="stop_command_pending",
|
invalidation_reason="stop_command_pending",
|
||||||
)
|
)
|
||||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
logger.info("STOP (pending) for session %s — sentinel cleared", session_key)
|
||||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||||
if agent:
|
if agent:
|
||||||
# Force-clean the session lock so a truly hung agent doesn't
|
# Force-clean the session lock so a truly hung agent doesn't
|
||||||
@@ -8798,7 +8798,7 @@ class GatewayRunner:
|
|||||||
if reason:
|
if reason:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Invalidated run generation for %s → %d (%s)",
|
"Invalidated run generation for %s → %d (%s)",
|
||||||
session_key[:20],
|
session_key,
|
||||||
generation,
|
generation,
|
||||||
reason,
|
reason,
|
||||||
)
|
)
|
||||||
@@ -9205,7 +9205,7 @@ class GatewayRunner:
|
|||||||
if not _run_still_current():
|
if not _run_still_current():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Discarding stale proxy stream for %s — generation %d is no longer current",
|
"Discarding stale proxy stream for %s — generation %d is no longer current",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
run_generation or 0,
|
run_generation or 0,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
@@ -9269,7 +9269,7 @@ class GatewayRunner:
|
|||||||
if not _run_still_current():
|
if not _run_still_current():
|
||||||
logger.info(
|
logger.info(
|
||||||
"Discarding stale proxy result for %s — generation %d is no longer current",
|
"Discarding stale proxy result for %s — generation %d is no longer current",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
run_generation or 0,
|
run_generation or 0,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
@@ -9711,7 +9711,7 @@ class GatewayRunner:
|
|||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"run_agent resolved: model=%s provider=%s session=%s",
|
"run_agent resolved: model=%s provider=%s session=%s",
|
||||||
model, runtime_kwargs.get("provider"), (session_key or "")[:30],
|
model, runtime_kwargs.get("provider"), session_key or "",
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {
|
return {
|
||||||
@@ -10322,7 +10322,7 @@ class GatewayRunner:
|
|||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Skipping stale agent promotion for %s — generation %s is no longer current",
|
"Skipping stale agent promotion for %s — generation %s is no longer current",
|
||||||
(session_key or "")[:20],
|
session_key or "",
|
||||||
run_generation,
|
run_generation,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -10469,7 +10469,7 @@ class GatewayRunner:
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Backup interrupt detected for session %s "
|
"Backup interrupt detected for session %s "
|
||||||
"(monitor task state: %s)",
|
"(monitor task state: %s)",
|
||||||
session_key[:20],
|
session_key,
|
||||||
"done" if interrupt_monitor.done() else "running",
|
"done" if interrupt_monitor.done() else "running",
|
||||||
)
|
)
|
||||||
_backup_agent.interrupt(_bp_text)
|
_backup_agent.interrupt(_bp_text)
|
||||||
@@ -10529,7 +10529,7 @@ class GatewayRunner:
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Backup interrupt detected for session %s "
|
"Backup interrupt detected for session %s "
|
||||||
"(monitor task state: %s)",
|
"(monitor task state: %s)",
|
||||||
session_key[:20],
|
session_key,
|
||||||
"done" if interrupt_monitor.done() else "running",
|
"done" if interrupt_monitor.done() else "running",
|
||||||
)
|
)
|
||||||
_backup_agent.interrupt(_bp_text)
|
_backup_agent.interrupt(_bp_text)
|
||||||
@@ -10631,7 +10631,7 @@ class GatewayRunner:
|
|||||||
if _is_control_interrupt_message(interrupt_message):
|
if _is_control_interrupt_message(interrupt_message):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Ignoring control interrupt message for session %s: %s",
|
"Ignoring control interrupt message for session %s: %s",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
interrupt_message,
|
interrupt_message,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -10675,7 +10675,7 @@ class GatewayRunner:
|
|||||||
if self._draining and (pending_event or pending):
|
if self._draining and (pending_event or pending):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Discarding pending follow-up for session %s during gateway %s",
|
"Discarding pending follow-up for session %s during gateway %s",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
self._status_action_label(),
|
self._status_action_label(),
|
||||||
)
|
)
|
||||||
pending_event = None
|
pending_event = None
|
||||||
@@ -10732,7 +10732,7 @@ class GatewayRunner:
|
|||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Queued follow-up for session %s: final stream delivery not confirmed; sending first response before continuing.",
|
"Queued follow-up for session %s: final stream delivery not confirmed; sending first response before continuing.",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
)
|
)
|
||||||
await adapter.send(
|
await adapter.send(
|
||||||
source.chat_id,
|
source.chat_id,
|
||||||
@@ -10744,7 +10744,7 @@ class GatewayRunner:
|
|||||||
elif first_response:
|
elif first_response:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Queued follow-up for session %s: skipping resend because final streamed delivery was confirmed.",
|
"Queued follow-up for session %s: skipping resend because final streamed delivery was confirmed.",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
)
|
)
|
||||||
# Release deferred bg-review notifications now that the
|
# Release deferred bg-review notifications now that the
|
||||||
# first response has been delivered. Pop from the
|
# first response has been delivered. Pop from the
|
||||||
@@ -10879,7 +10879,7 @@ class GatewayRunner:
|
|||||||
if not _is_empty_sentinel and (_streamed or _previewed):
|
if not _is_empty_sentinel and (_streamed or _previewed):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Suppressing normal final send for session %s: final delivery already confirmed (streamed=%s previewed=%s).",
|
"Suppressing normal final send for session %s: final delivery already confirmed (streamed=%s previewed=%s).",
|
||||||
session_key[:20] if session_key else "?",
|
session_key or "?",
|
||||||
_streamed,
|
_streamed,
|
||||||
_previewed,
|
_previewed,
|
||||||
)
|
)
|
||||||
|
|||||||
10
run_agent.py
10
run_agent.py
@@ -5137,6 +5137,8 @@ class AIAgent:
|
|||||||
# response.incomplete instead of response.completed).
|
# response.incomplete instead of response.completed).
|
||||||
self._codex_streamed_text_parts: list = []
|
self._codex_streamed_text_parts: list = []
|
||||||
for attempt in range(max_stream_retries + 1):
|
for attempt in range(max_stream_retries + 1):
|
||||||
|
if self._interrupt_requested:
|
||||||
|
raise InterruptedError("Agent interrupted before Codex stream retry")
|
||||||
collected_output_items: list = []
|
collected_output_items: list = []
|
||||||
try:
|
try:
|
||||||
with active_client.responses.stream(**api_kwargs) as stream:
|
with active_client.responses.stream(**api_kwargs) as stream:
|
||||||
@@ -6306,6 +6308,14 @@ class AIAgent:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for _stream_attempt in range(_max_stream_retries + 1):
|
for _stream_attempt in range(_max_stream_retries + 1):
|
||||||
|
# Check for interrupt before each retry attempt. Without
|
||||||
|
# this, /stop closes the HTTP connection (outer poll loop),
|
||||||
|
# but the retry loop opens a FRESH connection — negating the
|
||||||
|
# interrupt entirely. On slow providers (ollama-cloud) each
|
||||||
|
# retry can block for the full stream-read timeout (120s+),
|
||||||
|
# causing multi-minute delays between /stop and response.
|
||||||
|
if self._interrupt_requested:
|
||||||
|
raise InterruptedError("Agent interrupted before stream retry")
|
||||||
try:
|
try:
|
||||||
if self.api_mode == "anthropic_messages":
|
if self.api_mode == "anthropic_messages":
|
||||||
self._try_refresh_anthropic_client_credentials()
|
self._try_refresh_anthropic_client_credentials()
|
||||||
|
|||||||
162
tests/run_agent/test_stream_interrupt_retry.py
Normal file
162
tests/run_agent/test_stream_interrupt_retry.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""Tests that /stop interrupts streaming retry loops immediately.
|
||||||
|
|
||||||
|
When the agent is interrupted during a streaming API call, the outer poll
|
||||||
|
loop closes the HTTP connection. The inner `_call()` thread sees a
|
||||||
|
connection error and enters its retry loop. Before this fix, the retry
|
||||||
|
loop would open a FRESH connection without checking `_interrupt_requested`,
|
||||||
|
making /stop take multiple retry cycles × read-timeout to actually stop
|
||||||
|
(510+ seconds observed on slow ollama-cloud providers).
|
||||||
|
|
||||||
|
The fix adds an `_interrupt_requested` check at the top of the retry loop
|
||||||
|
so the agent exits immediately instead of retrying.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_agent(**kwargs):
    """Build a minimally configured ``AIAgent`` for the streaming tests.

    Any keyword argument overrides the matching baseline constructor
    setting below.  The returned agent always has ``api_mode`` set to
    ``"chat_completions"``, regardless of the overrides.
    """
    from run_agent import AIAgent

    # Baseline settings first, caller overrides last so they win.
    settings = {
        "api_key": "test-key",
        "base_url": "https://example.com/v1",
        "model": "test/model",
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
        **kwargs,
    }

    agent = AIAgent(**settings)
    agent.api_mode = "chat_completions"
    return agent
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamInterruptBeforeRetry:
    """The streaming retry loop must honour ``_interrupt_requested``."""

    @pytest.mark.filterwarnings(
        "ignore::pytest.PytestUnhandledThreadExceptionWarning"
    )
    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_interrupt_prevents_stream_retry(self, mock_close, mock_create):
        """An interrupt raised while a transient stream error is being
        handled must stop the retry loop: ``InterruptedError`` is expected
        and no second connection attempt may be made."""
        import httpx

        calls = 0

        def interrupt_on_first_failure(*args, **kwargs):
            nonlocal calls
            calls += 1
            if calls > 1:
                # Reaching a second attempt means the interrupt was ignored.
                raise httpx.ConnectError("unexpected retry — interrupt not checked!")
            # First attempt: flag the interrupt (as if /stop arrived while the
            # error is being processed), then fail like a dropped connection.
            agent._interrupt_requested = True
            raise httpx.ConnectError("connection reset by /stop")

        client = MagicMock()
        client.chat.completions.create.side_effect = interrupt_on_first_failure
        mock_create.return_value = client

        agent = _make_agent()
        agent._interrupt_requested = False

        with pytest.raises(InterruptedError, match="interrupted"):
            agent._interruptible_streaming_api_call({})

        # Exactly one attempt: the interrupt must pre-empt the retry.
        assert calls == 1, (
            f"Expected 1 attempt but got {calls}. "
            "The retry loop retried despite _interrupt_requested being set."
        )

    @pytest.mark.filterwarnings(
        "ignore::pytest.PytestUnhandledThreadExceptionWarning"
    )
    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_interrupt_before_first_attempt(self, mock_close, mock_create):
        """With ``_interrupt_requested`` already set on entry, the streaming
        call must raise straight away and never touch the API client."""
        client = MagicMock()
        mock_create.return_value = client

        agent = _make_agent()
        agent._interrupt_requested = True  # set before the call starts

        with pytest.raises(InterruptedError, match="interrupted"):
            agent._interruptible_streaming_api_call({})

        # The completions endpoint must not have been invoked at all.
        assert not client.chat.completions.create.called

    @patch("run_agent.AIAgent._create_request_openai_client")
    @patch("run_agent.AIAgent._close_request_openai_client")
    def test_normal_retry_still_works_without_interrupt(self, mock_close, mock_create):
        """Transient errors keep being retried when no interrupt is pending."""
        import httpx

        def make_chunk(content, finish_reason):
            # One streamed chunk carrying a single choice with the given delta.
            delta = SimpleNamespace(
                content=content,
                tool_calls=None,
                reasoning_content=None,
                reasoning=None,
            )
            choice = SimpleNamespace(
                index=0,
                delta=delta,
                finish_reason=finish_reason,
            )
            return SimpleNamespace(choices=[choice], model="test/model", usage=None)

        calls = 0

        def fail_twice_then_succeed(*args, **kwargs):
            nonlocal calls
            calls += 1
            if calls <= 2:
                raise httpx.ConnectError("transient failure")

            # Third attempt: a two-chunk stream, text first, then the stop marker.
            chunks = [make_chunk("ok", None), make_chunk(None, "stop")]
            stream = MagicMock()
            stream.__iter__.return_value = iter(chunks)
            stream.response = MagicMock()
            stream.response.headers = {}
            return stream

        client = MagicMock()
        client.chat.completions.create.side_effect = fail_twice_then_succeed
        mock_create.return_value = client

        agent = _make_agent()
        agent._interrupt_requested = False

        # Two failures followed by a success: three attempts in total.
        result = agent._interruptible_streaming_api_call({})
        assert result is not None
        assert calls == 3
|
||||||
Reference in New Issue
Block a user