diff --git a/run_agent.py b/run_agent.py index c8c471e352..459b8d0ef1 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3208,23 +3208,17 @@ class AIAgent: else: result["response"] = _call_chat_completions() except Exception as e: - err_text = str(e).lower() - # Fall back to non-streaming if provider doesn't support it. - # Be specific in matching — "stream" alone is too broad and - # catches unrelated errors like "stream_options" rejections. - stream_unsupported = any( - kw in err_text - for kw in ("streaming is not", "streaming not support", - "does not support stream", "not available") - ) - if stream_unsupported: - logger.info("Streaming not supported by provider, falling back to non-streaming: %s", e) - try: - result["response"] = self._interruptible_api_call(api_kwargs) - except Exception as fallback_err: - result["error"] = fallback_err - else: - result["error"] = e + # Always fall back to non-streaming on ANY streaming error. + # Many third-party/extrinsic providers have partial or broken + # streaming support — rejecting stream=True, crashing on + # stream_options, dropping connections mid-stream, etc. + # A clean fallback to the standard request path ensures the + # agent still works even if streaming doesn't. + logger.info("Streaming failed, falling back to non-streaming: %s", e) + try: + result["response"] = self._interruptible_api_call(api_kwargs) + except Exception as fallback_err: + result["error"] = fallback_err finally: request_client = request_client_holder.get("client") if request_client is not None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 3615c2a94b..6cc34d972c 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -321,7 +321,7 @@ class TestStreamingCallbacks: class TestStreamingFallback: - """Verify fallback to non-streaming on unsupported providers.""" + """Verify fallback to non-streaming on ANY streaming error.""" @patch("run_agent.AIAgent._interruptible_api_call") @patch("run_agent.AIAgent._create_request_openai_client") @@ -367,16 +367,63 @@ class TestStreamingFallback: assert response.choices[0].message.content == "fallback response" mock_non_stream.assert_called_once() + @patch("run_agent.AIAgent._interruptible_api_call") @patch("run_agent.AIAgent._create_request_openai_client") @patch("run_agent.AIAgent._close_request_openai_client") - def test_non_stream_error_raises(self, mock_close, mock_create): - """Non-streaming errors propagate normally.""" + def test_any_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream): + """ANY streaming error triggers fallback — not just specific messages.""" from run_agent import AIAgent mock_client = MagicMock() - mock_client.chat.completions.create.side_effect = Exception("Rate limit exceeded") + mock_client.chat.completions.create.side_effect = Exception( + "Connection reset by peer" + ) mock_create.return_value = mock_client + fallback_response = SimpleNamespace( + id="fallback", + model="test", + choices=[SimpleNamespace( + index=0, + message=SimpleNamespace( + role="assistant", + content="fallback after connection error", + tool_calls=None, + reasoning_content=None, + ), + finish_reason="stop", + )], + usage=None, + ) + mock_non_stream.return_value = fallback_response + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "fallback after connection error" + mock_non_stream.assert_called_once() + + @patch("run_agent.AIAgent._interruptible_api_call") + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_fallback_error_propagates(self, mock_close, mock_create, mock_non_stream): + """When both streaming AND fallback fail, the fallback error propagates.""" + from run_agent import AIAgent + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = Exception("stream broke") + mock_create.return_value = mock_client + + mock_non_stream.side_effect = Exception("Rate limit exceeded") + agent = AIAgent( model="test/model", quiet_mode=True,