mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 15:31:38 +08:00
feat: per-turn primary runtime restoration and transport recovery
Make provider fallback turn-scoped in long-lived CLI sessions. Previously, a single transient failure pinned the session to the fallback provider for every subsequent turn. Now the primary model/provider is restored at the start of each run_conversation() call. Three pieces: 1. _primary_runtime dict snapshot at __init__ — captures model, provider, base_url, api_mode, api_key, client_kwargs, prompt caching flag, and context compressor state. A single dict is easy to extend; it avoids the brittleness of N individual _primary_* attributes. 2. _restore_primary_runtime() — called at the top of run_conversation(). Restores all primary state, rebuilds the client, resets the fallback chain index (so all fallbacks are available again), and restores the context compressor's model/context_length/threshold that _try_activate_fallback() overwrites. No-op in the gateway (fresh agent per message). 3. _try_recover_primary_transport() — once max_retries is exhausted, rebuilds the primary client (clearing stale connection pools) and gives the primary one more attempt before falling through to fallback. Only for transient transport errors (ReadTimeout, ConnectTimeout, PoolTimeout, ConnectError, RemoteProtocolError). Skipped for aggregator providers (OpenRouter, Nous), which manage their own retry infrastructure. Inspired by PR #4612 (betamod), which identified this gap. Includes 25 tests covering snapshot creation, restore correctness, fallback index reset, compressor state restoration, transport recovery scoping, wait time behavior, and error resilience.
This commit is contained in:
200
run_agent.py
200
run_agent.py
@@ -1236,6 +1236,34 @@ class AIAgent:
|
||||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
# Snapshot primary runtime for per-turn restoration. When fallback
|
||||
# activates during a turn, the next turn restores these values so the
|
||||
# preferred model gets a fresh attempt each time. Uses a single dict
|
||||
# so new state fields are easy to add without N individual attributes.
|
||||
_cc = self.context_compressor
|
||||
self._primary_runtime = {
|
||||
"model": self.model,
|
||||
"provider": self.provider,
|
||||
"base_url": self.base_url,
|
||||
"api_mode": self.api_mode,
|
||||
"api_key": getattr(self, "api_key", ""),
|
||||
"client_kwargs": dict(self._client_kwargs),
|
||||
"use_prompt_caching": self._use_prompt_caching,
|
||||
# Compressor state that _try_activate_fallback() overwrites
|
||||
"compressor_model": _cc.model,
|
||||
"compressor_base_url": _cc.base_url,
|
||||
"compressor_api_key": getattr(_cc, "api_key", ""),
|
||||
"compressor_provider": _cc.provider,
|
||||
"compressor_context_length": _cc.context_length,
|
||||
"compressor_threshold_tokens": _cc.threshold_tokens,
|
||||
}
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._primary_runtime.update({
|
||||
"anthropic_api_key": self._anthropic_api_key,
|
||||
"anthropic_base_url": self._anthropic_base_url,
|
||||
"is_anthropic_oauth": self._is_anthropic_oauth,
|
||||
})
|
||||
|
||||
def reset_session_state(self):
|
||||
"""Reset all session-scoped token counters to 0 for a fresh session.
|
||||
|
||||
@@ -4765,6 +4793,156 @@ class AIAgent:
|
||||
logging.error("Failed to activate fallback %s: %s", fb_model, e)
|
||||
return self._try_activate_fallback() # try next in chain
|
||||
|
||||
# ── Per-turn primary restoration ─────────────────────────────────────
|
||||
|
||||
def _restore_primary_runtime(self) -> bool:
    """Restore the primary runtime at the start of a new turn.

    In long-lived CLI sessions a single AIAgent instance spans multiple
    turns.  Without restoration, one transient failure pins the session
    to the fallback provider for every subsequent turn.  Calling this at
    the top of ``run_conversation()`` makes fallback turn-scoped.

    The gateway creates a fresh agent per message so this is a no-op
    there (``_fallback_activated`` is always False at turn start).

    The restore is all-or-nothing: if rebuilding the primary client (or
    any later step) raises, every field written so far is rolled back so
    the agent keeps a self-consistent fallback runtime instead of mixing
    the primary model/provider with the fallback client.

    Returns:
        True when fallback was active and the primary runtime was fully
        restored; False when fallback was not active or the restore
        failed (in which case the fallback runtime is left in place).
    """
    if not self._fallback_activated:
        return False

    rt = self._primary_runtime
    cc = getattr(self, "context_compressor", None)

    # Capture the live (fallback) state up front so a failed rebuild can
    # be undone.  ``missing`` marks attributes that do not exist yet.
    missing = object()
    agent_fields = (
        "model", "provider", "base_url", "api_mode", "api_key",
        "_client_kwargs", "_use_prompt_caching", "client",
        "_anthropic_api_key", "_anthropic_base_url",
        "_anthropic_client", "_is_anthropic_oauth",
    )
    compressor_fields = (
        "model", "base_url", "api_key", "provider",
        "context_length", "threshold_tokens",
    )
    previous_agent = {name: getattr(self, name, missing) for name in agent_fields}
    previous_cc = {name: getattr(cc, name, missing) for name in compressor_fields}

    try:
        # ── Core runtime state ──
        self.model = rt["model"]
        self.provider = rt["provider"]
        self.base_url = rt["base_url"]  # setter updates _base_url_lower
        self.api_mode = rt["api_mode"]
        self.api_key = rt["api_key"]
        self._client_kwargs = dict(rt["client_kwargs"])
        self._use_prompt_caching = rt["use_prompt_caching"]

        # ── Rebuild client for the primary provider ──
        if self.api_mode == "anthropic_messages":
            from agent.anthropic_adapter import build_anthropic_client

            self._anthropic_api_key = rt["anthropic_api_key"]
            self._anthropic_base_url = rt["anthropic_base_url"]
            self._anthropic_client = build_anthropic_client(
                rt["anthropic_api_key"], rt["anthropic_base_url"],
            )
            self._is_anthropic_oauth = rt["is_anthropic_oauth"]
            self.client = None
        else:
            self.client = self._create_openai_client(
                dict(rt["client_kwargs"]),
                reason="restore_primary",
                shared=True,
            )

        # ── Restore context compressor state ──
        cc.model = rt["compressor_model"]
        cc.base_url = rt["compressor_base_url"]
        cc.api_key = rt["compressor_api_key"]
        cc.provider = rt["compressor_provider"]
        cc.context_length = rt["compressor_context_length"]
        cc.threshold_tokens = rt["compressor_threshold_tokens"]

        # ── Reset fallback chain for the new turn ──
        # Done last so the flags only flip once everything above succeeded.
        self._fallback_activated = False
        self._fallback_index = 0

        logging.info(
            "Primary runtime restored for new turn: %s (%s)",
            self.model, self.provider,
        )
        return True
    except Exception as e:
        # Roll back partial writes: the fallback client is still the one
        # in use, so model/provider/credentials must keep matching it.
        for name, value in previous_agent.items():
            if value is not missing:
                setattr(self, name, value)
        if cc is not None:
            for name, value in previous_cc.items():
                if value is not missing:
                    setattr(cc, name, value)
        logging.warning("Failed to restore primary runtime: %s", e)
        return False
|
||||
|
||||
# Exception class names (matched via ``type(err).__name__``) that indicate
# a transient transport failure worth one more attempt with a rebuilt
# client / connection pool.
_TRANSIENT_TRANSPORT_ERRORS = frozenset({
    "ReadTimeout", "ConnectTimeout", "PoolTimeout",
    "ConnectError", "RemoteProtocolError",
})


def _try_recover_primary_transport(
    self, api_error: Exception, *, retry_count: int, max_retries: int,
) -> bool:
    """Attempt one extra primary-provider recovery cycle for transient transport failures.

    Once ``max_retries`` is exhausted, rebuild the primary client (clearing
    stale connection pools) and give it one more attempt before falling
    back.  This is most useful for direct endpoints (custom, Z.AI,
    Anthropic, OpenAI, local models) where a TCP-level hiccup does not
    mean the provider is down.

    Skipped for proxy/aggregator providers (OpenRouter, Nous) which
    already manage connection pools and retries server-side — if our
    retries through them are exhausted, one more rebuilt client won't help.

    Args:
        api_error: The exception raised by the last API attempt; only its
            class name is inspected.
        retry_count: Retries already performed; used to scale the wait
            before the extra attempt (``min(3 + retry_count, 8)`` seconds).
        max_retries: Retry budget of the caller.  Not read by this method.

    Returns:
        True when the client was rebuilt and the caller should retry the
        primary provider once more; False when recovery does not apply or
        the rebuild itself failed.
    """
    if self._fallback_activated:
        return False

    # Only for transient transport errors
    error_type = type(api_error).__name__
    if error_type not in self._TRANSIENT_TRANSPORT_ERRORS:
        return False

    # Skip for aggregator providers — they manage their own retry infra
    if self._is_openrouter_url():
        return False
    provider_lower = (self.provider or "").strip().lower()
    if provider_lower in ("nous", "nous-research"):
        return False

    try:
        # Close existing client to release stale connections.  Closing is
        # best-effort: a failure here must not prevent the rebuild, but it
        # is logged rather than silently discarded.
        if getattr(self, "client", None) is not None:
            try:
                self._close_openai_client(
                    self.client, reason="primary_recovery", shared=True,
                )
            except Exception as close_err:
                logging.debug(
                    "Ignoring error while closing stale primary client: %s",
                    close_err,
                )

        # Rebuild from primary snapshot
        rt = self._primary_runtime
        self._client_kwargs = dict(rt["client_kwargs"])
        self.model = rt["model"]
        self.provider = rt["provider"]
        self.base_url = rt["base_url"]
        self.api_mode = rt["api_mode"]
        self.api_key = rt["api_key"]

        if self.api_mode == "anthropic_messages":
            from agent.anthropic_adapter import build_anthropic_client

            self._anthropic_api_key = rt["anthropic_api_key"]
            self._anthropic_base_url = rt["anthropic_base_url"]
            self._anthropic_client = build_anthropic_client(
                rt["anthropic_api_key"], rt["anthropic_base_url"],
            )
            self._is_anthropic_oauth = rt["is_anthropic_oauth"]
            self.client = None
        else:
            self.client = self._create_openai_client(
                dict(rt["client_kwargs"]),
                reason="primary_recovery",
                shared=True,
            )

        # Short, capped pause so the endpoint gets a moment to recover.
        wait_time = min(3 + retry_count, 8)
        self._vprint(
            f"{self.log_prefix}🔁 Transient {error_type} on {self.provider} — "
            f"rebuilt client, waiting {wait_time}s before one last primary attempt.",
            force=True,
        )
        time.sleep(wait_time)
        return True
    except Exception as e:
        logging.warning("Primary transport recovery failed: %s", e)
        return False
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
@@ -6403,6 +6581,11 @@ class AIAgent:
|
||||
# Installed once, transparent when streams are healthy, prevents crash on write.
|
||||
_install_safe_stdio()
|
||||
|
||||
# If the previous turn activated fallback, restore the primary
|
||||
# runtime so this turn gets a fresh attempt with the preferred model.
|
||||
# No-op when _fallback_activated is False (gateway, first turn, etc.).
|
||||
self._restore_primary_runtime()
|
||||
|
||||
# Sanitize surrogate characters from user input. Clipboard paste from
|
||||
# rich-text editors (Google Docs, Word, etc.) can inject lone surrogates
|
||||
# that are invalid UTF-8 and crash JSON serialization in the OpenAI SDK.
|
||||
@@ -6821,10 +7004,11 @@ class AIAgent:
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
primary_recovery_attempted = False
|
||||
max_compression_attempts = 3
|
||||
codex_auth_retry_attempted = False
|
||||
anthropic_auth_retry_attempted = False
|
||||
nous_auth_retry_attempted = False
|
||||
codex_auth_retry_attempted=False
|
||||
anthropic_auth_retry_attempted=False
|
||||
nous_auth_retry_attempted=False
|
||||
has_retried_429 = False
|
||||
restart_with_compressed_messages = False
|
||||
restart_with_length_continuation = False
|
||||
@@ -7657,6 +7841,16 @@ class AIAgent:
|
||||
}
|
||||
|
||||
if retry_count >= max_retries:
|
||||
# Before falling back, try rebuilding the primary
|
||||
# client once for transient transport errors (stale
|
||||
# connection pool, TCP reset). Only attempted once
|
||||
# per API call block.
|
||||
if not primary_recovery_attempted and self._try_recover_primary_transport(
|
||||
api_error, retry_count=retry_count, max_retries=max_retries,
|
||||
):
|
||||
primary_recovery_attempted = True
|
||||
retry_count = 0
|
||||
continue
|
||||
# Try fallback before giving up entirely
|
||||
self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...")
|
||||
if self._try_activate_fallback():
|
||||
|
||||
424
tests/test_primary_runtime_restore.py
Normal file
424
tests/test_primary_runtime_restore.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""Tests for per-turn primary runtime restoration and transport recovery.
|
||||
|
||||
Verifies that:
|
||||
1. Fallback is turn-scoped: a new turn restores the primary model/provider
|
||||
2. The fallback chain index resets so all fallbacks are available again
|
||||
3. Context compressor state is restored alongside the runtime
|
||||
4. Transient transport errors get one recovery cycle before fallback
|
||||
5. Recovery is skipped for aggregator providers (OpenRouter, Nous)
|
||||
6. Non-transport errors don't trigger recovery
|
||||
"""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
    """Return one minimal function-tool definition per given tool name."""
    tool_defs = []
    for tool_name in names:
        function_spec = {
            "name": tool_name,
            "description": f"{tool_name} tool",
            "parameters": {"type": "object", "properties": {}},
        }
        tool_defs.append({"type": "function", "function": function_spec})
    return tool_defs
|
||||
|
||||
|
||||
def _make_agent(fallback_model=None, provider="custom", base_url="https://my-llm.example.com/v1"):
    """Build a minimal AIAgent for tests, optionally with a fallback config.

    Tool discovery, toolset checks and the OpenAI client class are patched
    while the agent is constructed; the agent's ``client`` is then replaced
    by a ``MagicMock``.
    """
    init_kwargs = {
        "api_key": "test-key-12345678",
        "base_url": base_url,
        "provider": provider,
        "quiet_mode": True,
        "skip_context_files": True,
        "skip_memory": True,
        "fallback_model": fallback_model,
    }
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(**init_kwargs)
    agent.client = MagicMock()
    return agent
|
||||
|
||||
|
||||
def _mock_resolve(base_url="https://openrouter.ai/api/v1", api_key="fallback-key-1234"):
    """Return a mock client suitable as resolve_provider_client's client result."""
    client = MagicMock()
    client.configure_mock(api_key=api_key, base_url=base_url)
    return client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _primary_runtime snapshot
|
||||
# =============================================================================
|
||||
|
||||
class TestPrimaryRuntimeSnapshot:
    """Checks on the ``_primary_runtime`` snapshot taken during agent init."""

    def test_snapshot_created_at_init(self):
        agent = _make_agent()
        assert hasattr(agent, "_primary_runtime")

        snapshot = agent._primary_runtime
        assert snapshot["model"] == agent.model
        assert snapshot["provider"] == "custom"
        assert snapshot["base_url"] == "https://my-llm.example.com/v1"
        assert snapshot["api_mode"] == agent.api_mode
        for required_key in ("client_kwargs", "compressor_context_length"):
            assert required_key in snapshot

    def test_snapshot_includes_compressor_state(self):
        agent = _make_agent()
        snapshot = agent._primary_runtime
        compressor = agent.context_compressor

        # Each compressor field is stored under a "compressor_" prefixed key.
        for field in ("model", "provider", "context_length", "threshold_tokens"):
            assert snapshot[f"compressor_{field}"] == getattr(compressor, field)

    def test_snapshot_includes_anthropic_state_when_applicable(self):
        """Anthropic-mode agents should snapshot Anthropic-specific state."""
        anthropic_kwargs = {
            "api_key": "sk-ant-test-12345678",
            "base_url": "https://api.anthropic.com",
            "provider": "anthropic",
            "api_mode": "anthropic_messages",
            "quiet_mode": True,
            "skip_context_files": True,
            "skip_memory": True,
        }
        with (
            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
            patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
        ):
            agent = AIAgent(**anthropic_kwargs)

        snapshot = agent._primary_runtime
        for anthropic_key in ("anthropic_api_key", "anthropic_base_url", "is_anthropic_oauth"):
            assert anthropic_key in snapshot

    def test_snapshot_omits_anthropic_for_openai_mode(self):
        snapshot = _make_agent(provider="custom")._primary_runtime
        assert "anthropic_api_key" not in snapshot
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _restore_primary_runtime()
|
||||
# =============================================================================
|
||||
|
||||
class TestRestorePrimaryRuntime:
    """Behaviour of ``_restore_primary_runtime()`` after a fallback turn."""

    @staticmethod
    def _activate_fallback(agent):
        """Run ``_try_activate_fallback()`` with provider resolution mocked."""
        resolved = (_mock_resolve(), None)
        with patch("agent.auxiliary_client.resolve_provider_client", return_value=resolved):
            return agent._try_activate_fallback()

    @staticmethod
    def _restore(agent):
        """Run ``_restore_primary_runtime()`` with the OpenAI class mocked."""
        with patch("run_agent.OpenAI", return_value=MagicMock()):
            return agent._restore_primary_runtime()

    def test_noop_when_not_fallback(self):
        agent = _make_agent()
        assert agent._fallback_activated is False
        assert agent._restore_primary_runtime() is False

    def test_restores_model_and_provider(self):
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        primary_model, primary_provider = agent.model, agent.provider

        # Simulate fallback activation
        self._activate_fallback(agent)
        assert agent._fallback_activated is True
        assert agent.model == "anthropic/claude-sonnet-4"
        assert agent.provider == "openrouter"

        # Restore should bring back the primary
        restored = self._restore(agent)

        assert restored is True
        assert agent._fallback_activated is False
        assert agent.model == primary_model
        assert agent.provider == primary_provider

    def test_resets_fallback_index(self):
        """After restore, the full fallback chain should be available again."""
        chain = [
            {"provider": "openrouter", "model": "model-a"},
            {"provider": "anthropic", "model": "model-b"},
        ]
        agent = _make_agent(fallback_model=chain)

        # Advance through the chain
        self._activate_fallback(agent)
        assert agent._fallback_index == 1  # consumed one entry

        self._restore(agent)
        assert agent._fallback_index == 0  # reset for next turn

    def test_restores_compressor_state(self):
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
        )
        compressor = agent.context_compressor
        primary_ctx_len = compressor.context_length
        primary_threshold = compressor.threshold_tokens

        # Simulate fallback modifying compressor
        self._activate_fallback(agent)

        # Manually simulate compressor being changed (as _try_activate_fallback does)
        agent.context_compressor.context_length = 32000
        agent.context_compressor.threshold_tokens = 25600

        self._restore(agent)

        assert agent.context_compressor.context_length == primary_ctx_len
        assert agent.context_compressor.threshold_tokens == primary_threshold

    def test_restores_prompt_caching_flag(self):
        agent = _make_agent()
        primary_caching = agent._use_prompt_caching

        # Simulate fallback changing the caching flag
        agent._fallback_activated = True
        agent._use_prompt_caching = not primary_caching

        self._restore(agent)

        assert agent._use_prompt_caching == primary_caching

    def test_restore_survives_exception(self):
        """If client rebuild fails, the method returns False gracefully."""
        agent = _make_agent()
        agent._fallback_activated = True

        with patch("run_agent.OpenAI", side_effect=Exception("connection refused")):
            outcome = agent._restore_primary_runtime()

        assert outcome is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _try_recover_primary_transport()
|
||||
# =============================================================================
|
||||
|
||||
def _make_transport_error(error_type="ReadTimeout"):
    """Return an exception instance whose class is named ``error_type``.

    The class is created on the fly as a direct ``Exception`` subclass, so
    ``type(err).__name__`` equals the requested name.
    """
    return type(error_type, (Exception,), {})("connection timed out")
|
||||
|
||||
|
||||
class TestTryRecoverPrimaryTransport:
    """Scoping and behaviour of ``_try_recover_primary_transport()``."""

    @staticmethod
    def _attempt(agent, error, retry_count=3):
        """Invoke the recovery hook with the retry budget used by the tests."""
        return agent._try_recover_primary_transport(
            error, retry_count=retry_count, max_retries=3,
        )

    def test_recovers_on_read_timeout(self):
        agent = _make_agent(provider="custom")
        error = _make_transport_error("ReadTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is True

    def test_recovers_on_connect_timeout(self):
        agent = _make_agent(provider="custom")
        error = _make_transport_error("ConnectTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is True

    def test_recovers_on_pool_timeout(self):
        agent = _make_agent(provider="zai")
        error = _make_transport_error("PoolTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is True

    def test_skipped_when_already_on_fallback(self):
        agent = _make_agent(provider="custom")
        agent._fallback_activated = True

        assert self._attempt(agent, _make_transport_error("ReadTimeout")) is False

    def test_skipped_for_non_transport_error(self):
        """Non-transport errors (ValueError, APIError, etc.) skip recovery."""
        agent = _make_agent(provider="custom")

        assert self._attempt(agent, ValueError("invalid model")) is False

    def test_skipped_for_openrouter(self):
        agent = _make_agent(provider="openrouter", base_url="https://openrouter.ai/api/v1")

        assert self._attempt(agent, _make_transport_error("ReadTimeout")) is False

    def test_skipped_for_nous_provider(self):
        agent = _make_agent(provider="nous", base_url="https://inference.nous.nousresearch.com/v1")

        assert self._attempt(agent, _make_transport_error("ReadTimeout")) is False

    def test_allowed_for_anthropic_direct(self):
        """Direct Anthropic endpoint should get recovery."""
        agent = _make_agent(provider="anthropic", base_url="https://api.anthropic.com")
        # For non-anthropic_messages api_mode, it will use OpenAI client
        error = _make_transport_error("ConnectError")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is True

    def test_allowed_for_ollama(self):
        agent = _make_agent(provider="ollama", base_url="http://localhost:11434/v1")
        error = _make_transport_error("ConnectTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is True

    def test_wait_time_scales_with_retry_count(self):
        agent = _make_agent(provider="custom")
        error = _make_transport_error("ReadTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep") as fake_sleep,
        ):
            self._attempt(agent, error, retry_count=3)

        # wait_time = min(3 + retry_count, 8) = min(6, 8) = 6
        fake_sleep.assert_called_once_with(6)

    def test_wait_time_capped_at_8(self):
        agent = _make_agent(provider="custom")
        error = _make_transport_error("ReadTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep") as fake_sleep,
        ):
            self._attempt(agent, error, retry_count=10)

        # wait_time = min(3 + 10, 8) = 8
        fake_sleep.assert_called_once_with(8)

    def test_closes_existing_client_before_rebuild(self):
        agent = _make_agent(provider="custom")
        stale_client = agent.client
        error = _make_transport_error("ReadTimeout")

        with (
            patch("run_agent.OpenAI", return_value=MagicMock()),
            patch("time.sleep"),
            patch.object(agent, "_close_openai_client") as fake_close,
        ):
            self._attempt(agent, error)

        fake_close.assert_called_once_with(
            stale_client, reason="primary_recovery", shared=True,
        )

    def test_survives_rebuild_failure(self):
        """If client rebuild fails, returns False gracefully."""
        agent = _make_agent(provider="custom")
        error = _make_transport_error("ReadTimeout")

        with (
            patch("run_agent.OpenAI", side_effect=Exception("socket error")),
            patch("time.sleep"),
        ):
            recovered = self._attempt(agent, error)

        assert recovered is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration: restore_primary_runtime called from run_conversation
|
||||
# =============================================================================
|
||||
|
||||
class TestRestoreInRunConversation:
    """Verify the hook in run_conversation() calls _restore_primary_runtime."""

    def test_restore_called_at_turn_start(self):
        """run_conversation() must reach the restore hook at the start of a turn.

        The hook is patched to raise a private sentinel exception, so the
        turn aborts as soon as the hook is reached and no API call is made.
        Seeing the sentinel propagate proves run_conversation() itself
        invoked the hook (rather than the test calling the mock directly).
        """
        agent = _make_agent()
        agent._fallback_activated = True

        class _StopAfterRestore(Exception):
            """Sentinel raised by the patched hook to end the turn early."""

        with patch.object(
            agent, "_restore_primary_runtime", side_effect=_StopAfterRestore,
        ) as mock_restore:
            # NOTE(review): assumes the user message is the first positional
            # parameter of run_conversation() — confirm against its signature.
            with pytest.raises(_StopAfterRestore):
                agent.run_conversation("hello")

        mock_restore.assert_called_once_with()

    def test_full_cycle_fallback_then_restore(self):
        """Simulate: turn 1 activates fallback, turn 2 restores primary."""
        agent = _make_agent(
            fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
            provider="custom",
        )

        # Turn 1: activate fallback
        mock_client = _mock_resolve()
        with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)):
            assert agent._try_activate_fallback() is True

        assert agent._fallback_activated is True
        assert agent.model == "anthropic/claude-sonnet-4"
        assert agent.provider == "openrouter"
        assert agent._fallback_index == 1

        # Turn 2: restore primary
        with patch("run_agent.OpenAI", return_value=MagicMock()):
            assert agent._restore_primary_runtime() is True

        assert agent._fallback_activated is False
        assert agent._fallback_index == 0
        assert agent.provider == "custom"
        assert agent.base_url == "https://my-llm.example.com/v1"
|
||||
Reference in New Issue
Block a user