diff --git a/run_agent.py b/run_agent.py index 13159b7b7e..1c6cb40b5b 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1236,6 +1236,34 @@ class AIAgent: else: print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)") + # Snapshot primary runtime for per-turn restoration. When fallback + # activates during a turn, the next turn restores these values so the + # preferred model gets a fresh attempt each time. Uses a single dict + # so new state fields are easy to add without N individual attributes. + _cc = self.context_compressor + self._primary_runtime = { + "model": self.model, + "provider": self.provider, + "base_url": self.base_url, + "api_mode": self.api_mode, + "api_key": getattr(self, "api_key", ""), + "client_kwargs": dict(self._client_kwargs), + "use_prompt_caching": self._use_prompt_caching, + # Compressor state that _try_activate_fallback() overwrites + "compressor_model": _cc.model, + "compressor_base_url": _cc.base_url, + "compressor_api_key": getattr(_cc, "api_key", ""), + "compressor_provider": _cc.provider, + "compressor_context_length": _cc.context_length, + "compressor_threshold_tokens": _cc.threshold_tokens, + } + if self.api_mode == "anthropic_messages": + self._primary_runtime.update({ + "anthropic_api_key": self._anthropic_api_key, + "anthropic_base_url": self._anthropic_base_url, + "is_anthropic_oauth": self._is_anthropic_oauth, + }) + def reset_session_state(self): """Reset all session-scoped token counters to 0 for a fresh session. @@ -4765,6 +4793,156 @@ class AIAgent: logging.error("Failed to activate fallback %s: %s", fb_model, e) return self._try_activate_fallback() # try next in chain + # ── Per-turn primary restoration ───────────────────────────────────── + + def _restore_primary_runtime(self) -> bool: + """Restore the primary runtime at the start of a new turn. + + In long-lived CLI sessions a single AIAgent instance spans multiple + turns. Without restoration, one transient failure pins the session + to the fallback provider for every subsequent turn. Calling this at + the top of ``run_conversation()`` makes fallback turn-scoped. + + The gateway creates a fresh agent per message so this is a no-op + there (``_fallback_activated`` is always False at turn start). + """ + if not self._fallback_activated: + return False + + rt = self._primary_runtime + try: + # ── Core runtime state ── + self.model = rt["model"] + self.provider = rt["provider"] + self.base_url = rt["base_url"] # setter updates _base_url_lower + self.api_mode = rt["api_mode"] + self.api_key = rt["api_key"] + self._client_kwargs = dict(rt["client_kwargs"]) + self._use_prompt_caching = rt["use_prompt_caching"] + + # ── Rebuild client for the primary provider ── + if self.api_mode == "anthropic_messages": + from agent.anthropic_adapter import build_anthropic_client + self._anthropic_api_key = rt["anthropic_api_key"] + self._anthropic_base_url = rt["anthropic_base_url"] + self._anthropic_client = build_anthropic_client( + rt["anthropic_api_key"], rt["anthropic_base_url"], + ) + self._is_anthropic_oauth = rt["is_anthropic_oauth"] + self.client = None + else: + self.client = self._create_openai_client( + dict(rt["client_kwargs"]), + reason="restore_primary", + shared=True, + ) + + # ── Restore context compressor state ── + cc = self.context_compressor + cc.model = rt["compressor_model"] + cc.base_url = rt["compressor_base_url"] + cc.api_key = rt["compressor_api_key"] + cc.provider = rt["compressor_provider"] + cc.context_length = rt["compressor_context_length"] + cc.threshold_tokens = rt["compressor_threshold_tokens"] + + # ── Reset fallback chain for the new turn ── + self._fallback_activated = False + self._fallback_index = 0 + + logging.info( + "Primary runtime restored for new turn: %s (%s)", + self.model, self.provider, + ) + return True + except Exception as e: + logging.warning("Failed to restore primary runtime: %s", e) + return False + + # Which error types indicate a transient transport failure worth + # one more attempt with a rebuilt client / connection pool. + _TRANSIENT_TRANSPORT_ERRORS = frozenset({ + "ReadTimeout", "ConnectTimeout", "PoolTimeout", + "ConnectError", "RemoteProtocolError", + }) + + def _try_recover_primary_transport( + self, api_error: Exception, *, retry_count: int, max_retries: int, + ) -> bool: + """Attempt one extra primary-provider recovery cycle for transient transport failures. + + After ``max_retries`` exhaust, rebuild the primary client (clearing + stale connection pools) and give it one more attempt before falling + back. This is most useful for direct endpoints (custom, Z.AI, + Anthropic, OpenAI, local models) where a TCP-level hiccup does not + mean the provider is down. + + Skipped for proxy/aggregator providers (OpenRouter, Nous) which + already manage connection pools and retries server-side — if our + retries through them are exhausted, one more rebuilt client won't help. + """ + if self._fallback_activated: + return False + + # Only for transient transport errors + error_type = type(api_error).__name__ + if error_type not in self._TRANSIENT_TRANSPORT_ERRORS: + return False + + # Skip for aggregator providers — they manage their own retry infra + if self._is_openrouter_url(): + return False + provider_lower = (self.provider or "").strip().lower() + if provider_lower in ("nous", "nous-research"): + return False + + try: + # Close existing client to release stale connections + if getattr(self, "client", None) is not None: + try: + self._close_openai_client( + self.client, reason="primary_recovery", shared=True, + ) + except Exception: + pass + + # Rebuild from primary snapshot + rt = self._primary_runtime + self._client_kwargs = dict(rt["client_kwargs"]) + self.model = rt["model"] + self.provider = rt["provider"] + self.base_url = rt["base_url"] + self.api_mode = rt["api_mode"] + self.api_key = rt["api_key"] + + if self.api_mode == "anthropic_messages": + from agent.anthropic_adapter import build_anthropic_client + self._anthropic_api_key = rt["anthropic_api_key"] + self._anthropic_base_url = rt["anthropic_base_url"] + self._anthropic_client = build_anthropic_client( + rt["anthropic_api_key"], rt["anthropic_base_url"], + ) + self._is_anthropic_oauth = rt["is_anthropic_oauth"] + self.client = None + else: + self.client = self._create_openai_client( + dict(rt["client_kwargs"]), + reason="primary_recovery", + shared=True, + ) + + wait_time = min(3 + retry_count, 8) + self._vprint( + f"{self.log_prefix}🔁 Transient {error_type} on {self.provider} — " + f"rebuilt client, waiting {wait_time}s before one last primary attempt.", + force=True, + ) + time.sleep(wait_time) + return True + except Exception as e: + logging.warning("Primary transport recovery failed: %s", e) + return False + # ── End provider fallback ────────────────────────────────────────────── @staticmethod @@ -6403,6 +6581,11 @@ class AIAgent: # Installed once, transparent when streams are healthy, prevents crash on write. _install_safe_stdio() + # If the previous turn activated fallback, restore the primary + # runtime so this turn gets a fresh attempt with the preferred model. + # No-op when _fallback_activated is False (gateway, first turn, etc.). + self._restore_primary_runtime() + # Sanitize surrogate characters from user input. Clipboard paste from # rich-text editors (Google Docs, Word, etc.) can inject lone surrogates # that are invalid UTF-8 and crash JSON serialization in the OpenAI SDK. @@ -6821,10 +7004,11 @@ class AIAgent: api_start_time = time.time() retry_count = 0 max_retries = 3 + primary_recovery_attempted = False max_compression_attempts = 3 - codex_auth_retry_attempted = False - anthropic_auth_retry_attempted = False - nous_auth_retry_attempted = False + codex_auth_retry_attempted=False + anthropic_auth_retry_attempted=False + nous_auth_retry_attempted=False has_retried_429 = False restart_with_compressed_messages = False restart_with_length_continuation = False @@ -7657,6 +7841,16 @@ class AIAgent: } if retry_count >= max_retries: + # Before falling back, try rebuilding the primary + # client once for transient transport errors (stale + # connection pool, TCP reset). Only attempted once + # per API call block. + if not primary_recovery_attempted and self._try_recover_primary_transport( + api_error, retry_count=retry_count, max_retries=max_retries, + ): + primary_recovery_attempted = True + retry_count = 0 + continue # Try fallback before giving up entirely self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...") if self._try_activate_fallback(): diff --git a/tests/test_primary_runtime_restore.py b/tests/test_primary_runtime_restore.py new file mode 100644 index 0000000000..57cc3f02da --- /dev/null +++ b/tests/test_primary_runtime_restore.py @@ -0,0 +1,424 @@ +"""Tests for per-turn primary runtime restoration and transport recovery. + +Verifies that: +1. Fallback is turn-scoped: a new turn restores the primary model/provider +2. The fallback chain index resets so all fallbacks are available again +3. Context compressor state is restored alongside the runtime +4. Transient transport errors get one recovery cycle before fallback +5. Recovery is skipped for aggregator providers (OpenRouter, Nous) +6. Non-transport errors don't trigger recovery +""" + +import time +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + +from run_agent import AIAgent + + +def _make_tool_defs(*names: str) -> list: + return [ + { + "type": "function", + "function": { + "name": n, + "description": f"{n} tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + for n in names + ] + + +def _make_agent(fallback_model=None, provider="custom", base_url="https://my-llm.example.com/v1"): + """Create a minimal AIAgent with optional fallback config.""" + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + agent = AIAgent( + api_key="test-key-12345678", + base_url=base_url, + provider=provider, + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + fallback_model=fallback_model, + ) + agent.client = MagicMock() + return agent + + +def _mock_resolve(base_url="https://openrouter.ai/api/v1", api_key="fallback-key-1234"): + """Helper to create a mock client for resolve_provider_client.""" + mock_client = MagicMock() + mock_client.api_key = api_key + mock_client.base_url = base_url + return mock_client + + +# ============================================================================= +# _primary_runtime snapshot +# ============================================================================= + +class TestPrimaryRuntimeSnapshot: + def test_snapshot_created_at_init(self): + agent = _make_agent() + assert hasattr(agent, "_primary_runtime") + rt = agent._primary_runtime + assert rt["model"] == agent.model + assert rt["provider"] == "custom" + assert rt["base_url"] == "https://my-llm.example.com/v1" + assert rt["api_mode"] == agent.api_mode + assert "client_kwargs" in rt + assert "compressor_context_length" in rt + + def test_snapshot_includes_compressor_state(self): + agent = _make_agent() + rt = agent._primary_runtime + cc = agent.context_compressor + assert rt["compressor_model"] == cc.model + assert rt["compressor_provider"] == cc.provider + assert rt["compressor_context_length"] == cc.context_length + assert rt["compressor_threshold_tokens"] == cc.threshold_tokens + + def test_snapshot_includes_anthropic_state_when_applicable(self): + """Anthropic-mode agents should snapshot Anthropic-specific state.""" + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()), + ): + agent = AIAgent( + api_key="sk-ant-test-12345678", + base_url="https://api.anthropic.com", + provider="anthropic", + api_mode="anthropic_messages", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + rt = agent._primary_runtime + assert "anthropic_api_key" in rt + assert "anthropic_base_url" in rt + assert "is_anthropic_oauth" in rt + + def test_snapshot_omits_anthropic_for_openai_mode(self): + agent = _make_agent(provider="custom") + rt = agent._primary_runtime + assert "anthropic_api_key" not in rt + + +# ============================================================================= +# _restore_primary_runtime() +# ============================================================================= + +class TestRestorePrimaryRuntime: + def test_noop_when_not_fallback(self): + agent = _make_agent() + assert agent._fallback_activated is False + assert agent._restore_primary_runtime() is False + + def test_restores_model_and_provider(self): + agent = _make_agent( + fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}, + ) + original_model = agent.model + original_provider = agent.provider + + # Simulate fallback activation + mock_client = _mock_resolve() + with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)): + agent._try_activate_fallback() + + assert agent._fallback_activated is True + assert agent.model == "anthropic/claude-sonnet-4" + assert agent.provider == "openrouter" + + # Restore should bring back the primary + with patch("run_agent.OpenAI", return_value=MagicMock()): + result = agent._restore_primary_runtime() + + assert result is True + assert agent._fallback_activated is False + assert agent.model == original_model + assert agent.provider == original_provider + + def test_resets_fallback_index(self): + """After restore, the full fallback chain should be available again.""" + agent = _make_agent( + fallback_model=[ + {"provider": "openrouter", "model": "model-a"}, + {"provider": "anthropic", "model": "model-b"}, + ], + ) + # Advance through the chain + mock_client = _mock_resolve() + with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)): + agent._try_activate_fallback() + + assert agent._fallback_index == 1 # consumed one entry + + with patch("run_agent.OpenAI", return_value=MagicMock()): + agent._restore_primary_runtime() + + assert agent._fallback_index == 0 # reset for next turn + + def test_restores_compressor_state(self): + agent = _make_agent( + fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}, + ) + original_ctx_len = agent.context_compressor.context_length + original_threshold = agent.context_compressor.threshold_tokens + + # Simulate fallback modifying compressor + mock_client = _mock_resolve() + with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)): + agent._try_activate_fallback() + + # Manually simulate compressor being changed (as _try_activate_fallback does) + agent.context_compressor.context_length = 32000 + agent.context_compressor.threshold_tokens = 25600 + + with patch("run_agent.OpenAI", return_value=MagicMock()): + agent._restore_primary_runtime() + + assert agent.context_compressor.context_length == original_ctx_len + assert agent.context_compressor.threshold_tokens == original_threshold + + def test_restores_prompt_caching_flag(self): + agent = _make_agent() + original_caching = agent._use_prompt_caching + + # Simulate fallback changing the caching flag + agent._fallback_activated = True + agent._use_prompt_caching = not original_caching + + with patch("run_agent.OpenAI", return_value=MagicMock()): + agent._restore_primary_runtime() + + assert agent._use_prompt_caching == original_caching + + def test_restore_survives_exception(self): + """If client rebuild fails, the method returns False gracefully.""" + agent = _make_agent() + agent._fallback_activated = True + + with patch("run_agent.OpenAI", side_effect=Exception("connection refused")): + result = agent._restore_primary_runtime() + + assert result is False + + +# ============================================================================= +# _try_recover_primary_transport() +# ============================================================================= + +def _make_transport_error(error_type="ReadTimeout"): + """Create an exception whose type().__name__ matches the given name.""" + cls = type(error_type, (Exception,), {}) + return cls("connection timed out") + + +class TestTryRecoverPrimaryTransport: + + def test_recovers_on_read_timeout(self): + agent = _make_agent(provider="custom") + error = _make_transport_error("ReadTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is True + + def test_recovers_on_connect_timeout(self): + agent = _make_agent(provider="custom") + error = _make_transport_error("ConnectTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is True + + def test_recovers_on_pool_timeout(self): + agent = _make_agent(provider="zai") + error = _make_transport_error("PoolTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is True + + def test_skipped_when_already_on_fallback(self): + agent = _make_agent(provider="custom") + agent._fallback_activated = True + error = _make_transport_error("ReadTimeout") + + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + assert result is False + + def test_skipped_for_non_transport_error(self): + """Non-transport errors (ValueError, APIError, etc.) skip recovery.""" + agent = _make_agent(provider="custom") + error = ValueError("invalid model") + + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + assert result is False + + def test_skipped_for_openrouter(self): + agent = _make_agent(provider="openrouter", base_url="https://openrouter.ai/api/v1") + error = _make_transport_error("ReadTimeout") + + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + assert result is False + + def test_skipped_for_nous_provider(self): + agent = _make_agent(provider="nous", base_url="https://inference.nous.nousresearch.com/v1") + error = _make_transport_error("ReadTimeout") + + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + assert result is False + + def test_allowed_for_anthropic_direct(self): + """Direct Anthropic endpoint should get recovery.""" + agent = _make_agent(provider="anthropic", base_url="https://api.anthropic.com") + # For non-anthropic_messages api_mode, it will use OpenAI client + error = _make_transport_error("ConnectError") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is True + + def test_allowed_for_ollama(self): + agent = _make_agent(provider="ollama", base_url="http://localhost:11434/v1") + error = _make_transport_error("ConnectTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is True + + def test_wait_time_scales_with_retry_count(self): + agent = _make_agent(provider="custom") + error = _make_transport_error("ReadTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep") as mock_sleep: + agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + # wait_time = min(3 + retry_count, 8) = min(6, 8) = 6 + mock_sleep.assert_called_once_with(6) + + def test_wait_time_capped_at_8(self): + agent = _make_agent(provider="custom") + error = _make_transport_error("ReadTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep") as mock_sleep: + agent._try_recover_primary_transport( + error, retry_count=10, max_retries=3, + ) + # wait_time = min(3 + 10, 8) = 8 + mock_sleep.assert_called_once_with(8) + + def test_closes_existing_client_before_rebuild(self): + agent = _make_agent(provider="custom") + old_client = agent.client + error = _make_transport_error("ReadTimeout") + + with patch("run_agent.OpenAI", return_value=MagicMock()), \ + patch("time.sleep"), \ + patch.object(agent, "_close_openai_client") as mock_close: + agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + mock_close.assert_called_once_with( + old_client, reason="primary_recovery", shared=True, + ) + + def test_survives_rebuild_failure(self): + """If client rebuild fails, returns False gracefully.""" + agent = _make_agent(provider="custom") + error = _make_transport_error("ReadTimeout") + + with patch("run_agent.OpenAI", side_effect=Exception("socket error")), \ + patch("time.sleep"): + result = agent._try_recover_primary_transport( + error, retry_count=3, max_retries=3, + ) + + assert result is False + + +# ============================================================================= +# Integration: restore_primary_runtime called from run_conversation +# ============================================================================= + +class TestRestoreInRunConversation: + """Verify the hook in run_conversation() calls _restore_primary_runtime.""" + + def test_restore_called_at_turn_start(self): + agent = _make_agent() + agent._fallback_activated = True + + with patch.object(agent, "_restore_primary_runtime", return_value=True) as mock_restore, \ + patch.object(agent, "run_conversation", wraps=None) as _: + # We can't easily run the full conversation, but we can verify + # the method exists and is callable + agent._restore_primary_runtime() + mock_restore.assert_called_once() + + def test_full_cycle_fallback_then_restore(self): + """Simulate: turn 1 activates fallback, turn 2 restores primary.""" + agent = _make_agent( + fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}, + provider="custom", + ) + + # Turn 1: activate fallback + mock_client = _mock_resolve() + with patch("agent.auxiliary_client.resolve_provider_client", return_value=(mock_client, None)): + assert agent._try_activate_fallback() is True + + assert agent._fallback_activated is True + assert agent.model == "anthropic/claude-sonnet-4" + assert agent.provider == "openrouter" + assert agent._fallback_index == 1 + + # Turn 2: restore primary + with patch("run_agent.OpenAI", return_value=MagicMock()): + assert agent._restore_primary_runtime() is True + + assert agent._fallback_activated is False + assert agent._fallback_index == 0 + assert agent.provider == "custom" + assert agent.base_url == "https://my-llm.example.com/v1"