diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 6b7bf19668..a4c4df8e97 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -1021,6 +1021,23 @@ _AUTO_PROVIDER_LABELS = { _AGGREGATOR_PROVIDERS = frozenset({"openrouter", "nous"}) +_MAIN_RUNTIME_FIELDS = ("provider", "model", "base_url", "api_key", "api_mode") + + +def _normalize_main_runtime(main_runtime: Optional[Dict[str, Any]]) -> Dict[str, str]: + """Return a sanitized copy of a live main-runtime override.""" + if not isinstance(main_runtime, dict): + return {} + normalized: Dict[str, str] = {} + for field in _MAIN_RUNTIME_FIELDS: + value = main_runtime.get(field) + if isinstance(value, str) and value.strip(): + normalized[field] = value.strip() + provider = normalized.get("provider") + if provider: + normalized["provider"] = provider.lower() + return normalized + def _get_provider_chain() -> List[tuple]: """Return the ordered provider detection chain. @@ -1130,7 +1147,7 @@ def _try_payment_fallback( return None, None, "" -def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: +def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]: """Full auto-detection chain. Priority: @@ -1142,6 +1159,12 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: """ global auxiliary_is_nous, _stale_base_url_warned auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins + runtime = _normalize_main_runtime(main_runtime) + runtime_provider = runtime.get("provider", "") + runtime_model = runtime.get("model", "") + runtime_base_url = runtime.get("base_url", "") + runtime_api_key = runtime.get("api_key", "") + runtime_api_mode = runtime.get("api_mode", "") # ── Warn once if OPENAI_BASE_URL is set but config.yaml uses a named # provider (not 'custom'). This catches the common "env poisoning" @@ -1149,7 +1172,7 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: # old OPENAI_BASE_URL lingers in ~/.hermes/.env. 
     if not _stale_base_url_warned:
         _env_base = os.getenv("OPENAI_BASE_URL", "").strip()
-        _cfg_provider = _read_main_provider()
+        _cfg_provider = runtime_provider or _read_main_provider()
         if (_env_base and _cfg_provider
                 and _cfg_provider != "custom"
                 and not _cfg_provider.startswith("custom:")):
@@ -1163,12 +1186,25 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
         _stale_base_url_warned = True
 
     # ── Step 1: non-aggregator main provider → use main model directly ──
-    main_provider = _read_main_provider()
-    main_model = _read_main_model()
+    main_provider = runtime_provider or _read_main_provider()
+    main_model = runtime_model or _read_main_model()
     if (main_provider and main_model
             and main_provider not in _AGGREGATOR_PROVIDERS
             and main_provider not in ("auto", "")):
-        client, resolved = resolve_provider_client(main_provider, main_model)
+        resolved_provider = main_provider
+        explicit_base_url = None
+        explicit_api_key = None
+        if runtime_base_url and (main_provider == "custom" or main_provider.startswith("custom:")):
+            resolved_provider = "custom"
+            explicit_base_url = runtime_base_url
+            explicit_api_key = runtime_api_key or None
+        client, resolved = resolve_provider_client(
+            resolved_provider,
+            main_model,
+            explicit_base_url=explicit_base_url,
+            explicit_api_key=explicit_api_key,
+            api_mode=runtime_api_mode or None,
+        )
         if client is not None:
             logger.info("Auxiliary auto-detect: using main provider %s (%s)",
                         main_provider, resolved or main_model)
@@ -1249,6 +1285,7 @@ def resolve_provider_client(
     explicit_base_url: str = None,
     explicit_api_key: str = None,
     api_mode: str = None,
+    main_runtime: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Optional[Any], Optional[str]]:
     """Central router: given a provider name and optional model, return a
     configured client with the correct auth, base URL, and API format.
@@ -1319,7 +1356,7 @@ def resolve_provider_client(
 
     # ── Auto: try all providers in priority order ────────────────────
     if provider == "auto":
-        client, resolved = _resolve_auto()
+        client, resolved = _resolve_auto(main_runtime=main_runtime)
         if client is None:
             return None, None
         # When auto-detection lands on a non-OpenRouter provider (e.g. a
@@ -1543,7 +1580,11 @@ def resolve_provider_client(
 
 # ── Public API ──────────────────────────────────────────────────────────────
 
-def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
+def get_text_auxiliary_client(
+    task: str = "",
+    *,
+    main_runtime: Optional[Dict[str, Any]] = None,
+) -> Tuple[Optional[OpenAI], Optional[str]]:
     """Return (client, default_model_slug) for text-only auxiliary tasks.
 
     Args:
@@ -1560,10 +1601,11 @@ def get_text_auxiliary_client(task: str = "") -> Tuple[Optiona
         explicit_base_url=base_url,
         explicit_api_key=api_key,
         api_mode=api_mode,
+        main_runtime=main_runtime,
     )
 
 
-def get_async_text_auxiliary_client(task: str = ""):
+def get_async_text_auxiliary_client(task: str = "", *, main_runtime: Optional[Dict[str, Any]] = None):
     """Return (async_client, model_slug) for async consumers.
 
     For standard providers returns (AsyncOpenAI, model). For Codex returns
@@ -1578,6 +1620,7 @@ def get_async_text_auxiliary_client(task: str = ""):
         explicit_base_url=base_url,
         explicit_api_key=api_key,
         api_mode=api_mode,
+        main_runtime=main_runtime,
     )
 
 
@@ -1892,6 +1935,7 @@ def _get_cached_client(
     base_url: str = None,
     api_key: str = None,
     api_mode: str = None,
+    main_runtime: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Optional[Any], Optional[str]]:
     """Get or create a cached client for the given provider.
 
@@ -1915,7 +1959,9 @@ def _get_cached_client(
             loop_id = id(current_loop)
         except RuntimeError:
             pass
-    cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id)
+    runtime = _normalize_main_runtime(main_runtime)
+    runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
+    cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id, runtime_key)
     with _client_cache_lock:
         if cache_key in _client_cache:
             cached_client, cached_default, cached_loop = _client_cache[cache_key]
@@ -1940,6 +1986,7 @@ def _get_cached_client(
         explicit_base_url=base_url,
         explicit_api_key=api_key,
         api_mode=api_mode,
+        main_runtime=runtime,
     )
     if client is not None:
         # For async clients, remember which loop they were created on so we
@@ -2149,6 +2196,7 @@ def call_llm(
     model: str = None,
     base_url: str = None,
     api_key: str = None,
+    main_runtime: Optional[Dict[str, Any]] = None,
     messages: list,
     temperature: float = None,
     max_tokens: int = None,
@@ -2214,6 +2262,7 @@ def call_llm(
         base_url=resolved_base_url,
         api_key=resolved_api_key,
         api_mode=resolved_api_mode,
+        main_runtime=main_runtime,
     )
     if client is None:
         # When the user explicitly chose a non-OpenRouter provider but no
@@ -2234,7 +2283,7 @@ def call_llm(
         if not resolved_base_url:
             logger.info("Auxiliary %s: provider %s unavailable, trying auto-detection chain",
                         task or "call", resolved_provider)
-            client, final_model = _get_cached_client("auto")
+            client, final_model = _get_cached_client("auto", main_runtime=main_runtime)
             if client is None:
                 raise RuntimeError(
                     f"No LLM provider configured for task={task} provider={resolved_provider}. "
" diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 2701997fa6..4163966aaa 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -86,12 +86,14 @@ class ContextCompressor(ContextEngine): base_url: str = "", api_key: str = "", provider: str = "", + api_mode: str = "", ) -> None: """Update model info after a model switch or fallback activation.""" self.model = model self.base_url = base_url self.api_key = api_key self.provider = provider + self.api_mode = api_mode self.context_length = context_length self.threshold_tokens = max( int(context_length * self.threshold_percent), @@ -111,11 +113,13 @@ class ContextCompressor(ContextEngine): api_key: str = "", config_context_length: int | None = None, provider: str = "", + api_mode: str = "", ): self.model = model self.base_url = base_url self.api_key = api_key self.provider = provider + self.api_mode = api_mode self.threshold_percent = threshold_percent self.protect_first_n = protect_first_n self.protect_last_n = protect_last_n @@ -438,6 +442,13 @@ The user has requested that this compaction PRIORITISE preserving all informatio try: call_kwargs = { "task": "compression", + "main_runtime": { + "model": self.model, + "provider": self.provider, + "base_url": self.base_url, + "api_key": self.api_key, + "api_mode": self.api_mode, + }, "messages": [{"role": "user", "content": prompt}], "max_tokens": summary_budget * 2, # timeout resolved from auxiliary.compression.timeout config by call_llm diff --git a/run_agent.py b/run_agent.py index b230354542..01863a0a14 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1307,6 +1307,7 @@ class AIAgent: api_key=getattr(self, "api_key", ""), config_context_length=_config_context_length, provider=self.provider, + api_mode=self.api_mode, ) self.compression_enabled = compression_enabled @@ -1563,6 +1564,7 @@ class AIAgent: base_url=self.base_url, api_key=getattr(self, "api_key", ""), provider=self.provider, + api_mode=self.api_mode, ) # ── Invalidate cached system prompt so it rebuilds next turn ── @@ -1696,6 +1698,16 @@ class AIAgent: except Exception: logger.debug("status_callback error in _emit_status", exc_info=True) + def _current_main_runtime(self) -> Dict[str, str]: + """Return the live main runtime for session-scoped auxiliary routing.""" + return { + "model": getattr(self, "model", "") or "", + "provider": getattr(self, "provider", "") or "", + "base_url": getattr(self, "base_url", "") or "", + "api_key": getattr(self, "api_key", "") or "", + "api_mode": getattr(self, "api_mode", "") or "", + } + def _check_compression_model_feasibility(self) -> None: """Warn at session start if the auxiliary compression model's context window is smaller than the main model's compression threshold. 
@@ -1716,7 +1728,10 @@ class AIAgent: from agent.auxiliary_client import get_text_auxiliary_client from agent.model_metadata import get_model_context_length - client, aux_model = get_text_auxiliary_client("compression") + client, aux_model = get_text_auxiliary_client( + "compression", + main_runtime=self._current_main_runtime(), + ) if client is None or not aux_model: msg = ( "⚠ No auxiliary LLM provider configured — context " diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index a38b62568a..e1164ace8a 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -971,6 +971,74 @@ class TestTaskSpecificOverrides: client, model = get_text_auxiliary_client("compression") assert model == "google/gemini-3-flash-preview" # auto → OpenRouter + def test_resolve_auto_prefers_live_main_runtime_over_persisted_config(self, monkeypatch, tmp_path): + """Session-only live model switches should override persisted config for auto routing.""" + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + (hermes_home / "config.yaml").write_text( + """model: + default: glm-5.1 + provider: opencode-go +compression: + summary_provider: auto +""" + ) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + calls = [] + + def _fake_resolve(provider, model=None, *args, **kwargs): + calls.append((provider, model, kwargs)) + return MagicMock(), model or "resolved-model" + + with patch("agent.auxiliary_client.resolve_provider_client", side_effect=_fake_resolve): + client, model = _resolve_auto( + main_runtime={ + "provider": "openai-codex", + "model": "gpt-5.4", + "api_mode": "codex_responses", + } + ) + + assert client is not None + assert model == "gpt-5.4" + assert calls[0][0] == "openai-codex" + assert calls[0][1] == "gpt-5.4" + assert calls[0][2]["api_mode"] == "codex_responses" + + def test_explicit_compression_pin_still_wins_over_live_main_runtime(self, monkeypatch, tmp_path): + """Task-level compression config should beat a live session override.""" + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + (hermes_home / "config.yaml").write_text( + """auxiliary: + compression: + provider: openrouter + model: google/gemini-3-flash-preview +model: + default: glm-5.1 + provider: opencode-go +""" + ) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + with patch("agent.auxiliary_client.resolve_provider_client", return_value=(MagicMock(), "google/gemini-3-flash-preview")) as mock_resolve: + client, model = get_text_auxiliary_client( + "compression", + main_runtime={ + "provider": "openai-codex", + "model": "gpt-5.4", + }, + ) + + assert client is not None + assert model == "google/gemini-3-flash-preview" + assert mock_resolve.call_args.args[0] == "openrouter" + assert mock_resolve.call_args.kwargs["main_runtime"] == { + "provider": "openai-codex", + "model": "gpt-5.4", + } + def test_compression_summary_base_url_from_config(self, monkeypatch, tmp_path): """compression.summary_base_url should produce a custom-endpoint client.""" hermes_home = tmp_path / "hermes" diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index f4cf19666f..6164d812f6 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -191,6 +191,37 @@ class TestNonStringContent: kwargs = mock_call.call_args.kwargs assert "temperature" not in kwargs + def test_summary_call_passes_live_main_runtime(self): + mock_response = MagicMock() 
+ mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "ok" + + with patch("agent.context_compressor.get_model_context_length", return_value=100000): + c = ContextCompressor( + model="gpt-5.4", + provider="openai-codex", + base_url="https://chatgpt.com/backend-api/codex", + api_key="codex-token", + api_mode="codex_responses", + quiet_mode=True, + ) + + messages = [ + {"role": "user", "content": "do something"}, + {"role": "assistant", "content": "ok"}, + ] + + with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_call: + c._generate_summary(messages) + + assert mock_call.call_args.kwargs["main_runtime"] == { + "model": "gpt-5.4", + "provider": "openai-codex", + "base_url": "https://chatgpt.com/backend-api/codex", + "api_key": "codex-token", + "api_mode": "codex_responses", + } + class TestSummaryFailureCooldown: def test_summary_failure_enters_cooldown_and_skips_retry(self): diff --git a/tests/run_agent/test_compression_feasibility.py b/tests/run_agent/test_compression_feasibility.py index 1b4423414e..0738b1d438 100644 --- a/tests/run_agent/test_compression_feasibility.py +++ b/tests/run_agent/test_compression_feasibility.py @@ -26,6 +26,7 @@ def _make_agent( agent.provider = "openrouter" agent.base_url = "https://openrouter.ai/api/v1" agent.api_key = "sk-test" + agent.api_mode = "chat_completions" agent.quiet_mode = True agent.log_prefix = "" agent.compression_enabled = compression_enabled @@ -99,6 +100,36 @@ def test_no_warning_when_aux_context_sufficient(mock_get_client, mock_ctx_len): assert agent._compression_warning is None +def test_feasibility_check_passes_live_main_runtime(): + """Compression feasibility should probe using the live session runtime.""" + agent = _make_agent(main_context=200_000, threshold_percent=0.50) + agent.model = "gpt-5.4" + agent.provider = "openai-codex" + agent.base_url = "https://chatgpt.com/backend-api/codex" + agent.api_key = "codex-token" + agent.api_mode = "codex_responses" + + mock_client = MagicMock() + mock_client.base_url = "https://chatgpt.com/backend-api/codex" + mock_client.api_key = "codex-token" + + with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_client, "gpt-5.4")) as mock_get_client, \ + patch("agent.model_metadata.get_model_context_length", return_value=200_000): + agent._emit_status = lambda msg: None + agent._check_compression_model_feasibility() + + mock_get_client.assert_called_once_with( + "compression", + main_runtime={ + "model": "gpt-5.4", + "provider": "openai-codex", + "base_url": "https://chatgpt.com/backend-api/codex", + "api_key": "codex-token", + "api_mode": "codex_responses", + }, + ) + + @patch("agent.auxiliary_client.get_text_auxiliary_client") def test_warns_when_no_auxiliary_provider(mock_get_client): """Warning emitted when no auxiliary provider is configured."""
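Illustrative usage, not part of the patch: a minimal sketch of how the new main_runtime override is meant to flow through call_llm, mirroring the call_kwargs wiring in ContextCompressor._generate_summary above. The provider, model, and credential values are the same hypothetical ones used in the tests.

    from agent.auxiliary_client import call_llm

    # Live, session-scoped runtime (e.g. after a mid-session model switch
    # that was never persisted to config.yaml).
    main_runtime = {
        "provider": "openai-codex",
        "model": "gpt-5.4",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "codex-token",
        "api_mode": "codex_responses",
    }

    # With no task-level pin, the auto chain prefers this dict over the
    # persisted config; an explicit auxiliary.compression pin still wins.
    response = call_llm(
        task="compression",
        main_runtime=main_runtime,
        messages=[{"role": "user", "content": "Summarize the conversation."}],
    )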