diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 84f023f83b..c1c2e2f9a3 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -1438,10 +1438,14 @@ def resolve_provider_client( custom_entry = _get_named_custom_provider(provider) if custom_entry: custom_base = custom_entry.get("base_url", "").strip() - custom_key = custom_entry.get("api_key", "").strip() or "no-key-required" + custom_key = custom_entry.get("api_key", "").strip() + custom_key_env = custom_entry.get("key_env", "").strip() + if not custom_key and custom_key_env: + custom_key = os.getenv(custom_key_env, "").strip() + custom_key = custom_key or "no-key-required" if custom_base: final_model = _normalize_resolved_model( - model or _read_main_model() or "gpt-4o-mini", + model or custom_entry.get("model") or _read_main_model() or "gpt-4o-mini", provider, ) client = OpenAI(api_key=custom_key, base_url=custom_base) diff --git a/agent/credential_pool.py b/agent/credential_pool.py index e067fb9014..ea9ad92329 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -289,6 +289,14 @@ def _iter_custom_providers(config: Optional[dict] = None): return custom_providers = config.get("custom_providers") if not isinstance(custom_providers, list): + # Fall back to the v12+ providers dict via the compatibility layer + try: + from hermes_cli.config import get_compatible_custom_providers + + custom_providers = get_compatible_custom_providers(config) + except Exception: + return + if not custom_providers: return for entry in custom_providers: if not isinstance(entry, dict): diff --git a/cli.py b/cli.py index a61bcd9d33..dcb5bfcc5f 100644 --- a/cli.py +++ b/cli.py @@ -4710,10 +4710,10 @@ class HermesCLI: user_provs = None custom_provs = None try: - from hermes_cli.config import load_config + from hermes_cli.config import get_compatible_custom_providers, load_config cfg = load_config() user_provs = cfg.get("providers") - custom_provs = cfg.get("custom_providers") + custom_provs = get_compatible_custom_providers(cfg) except Exception: pass diff --git a/gateway/run.py b/gateway/run.py index 4c30db7db8..afc5aa035e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3330,21 +3330,26 @@ class GatewayRunner: # Must run after runtime resolution so _hyg_base_url is set. if _hyg_config_context_length is None and _hyg_base_url: try: - _hyg_custom_providers = _hyg_data.get("custom_providers") - if isinstance(_hyg_custom_providers, list): - for _cp in _hyg_custom_providers: - if not isinstance(_cp, dict): - continue - _cp_url = (_cp.get("base_url") or "").rstrip("/") - if _cp_url and _cp_url == _hyg_base_url.rstrip("/"): - _cp_models = _cp.get("models", {}) - if isinstance(_cp_models, dict): - _cp_model_cfg = _cp_models.get(_hyg_model, {}) - if isinstance(_cp_model_cfg, dict): - _cp_ctx = _cp_model_cfg.get("context_length") - if _cp_ctx is not None: - _hyg_config_context_length = int(_cp_ctx) - break + try: + from hermes_cli.config import get_compatible_custom_providers as _gw_gcp + _hyg_custom_providers = _gw_gcp(_hyg_data) + except Exception: + _hyg_custom_providers = _hyg_data.get("custom_providers") + if not isinstance(_hyg_custom_providers, list): + _hyg_custom_providers = [] + for _cp in _hyg_custom_providers: + if not isinstance(_cp, dict): + continue + _cp_url = (_cp.get("base_url") or "").rstrip("/") + if _cp_url and _cp_url == _hyg_base_url.rstrip("/"): + _cp_models = _cp.get("models", {}) + if isinstance(_cp_models, dict): + _cp_model_cfg = _cp_models.get(_hyg_model, {}) + if isinstance(_cp_model_cfg, dict): + _cp_ctx = _cp_model_cfg.get("context_length") + if _cp_ctx is not None: + _hyg_config_context_length = int(_cp_ctx) + break except (TypeError, ValueError): pass except Exception: @@ -4296,7 +4301,11 @@ class GatewayRunner: current_provider = model_cfg.get("provider", current_provider) current_base_url = model_cfg.get("base_url", "") user_provs = cfg.get("providers") - custom_provs = cfg.get("custom_providers") + try: + from hermes_cli.config import get_compatible_custom_providers + custom_provs = get_compatible_custom_providers(cfg) + except Exception: + custom_provs = cfg.get("custom_providers") except Exception: pass diff --git a/hermes_cli/auth_commands.py b/hermes_cli/auth_commands.py index 0532faa770..c1cf0ff618 100644 --- a/hermes_cli/auth_commands.py +++ b/hermes_cli/auth_commands.py @@ -36,25 +36,23 @@ _OAUTH_CAPABLE_PROVIDERS = {"anthropic", "nous", "openai-codex", "qwen-oauth"} def _get_custom_provider_names() -> list: - """Return list of (display_name, pool_key) tuples for custom_providers in config.""" + """Return list of (display_name, pool_key, provider_key) tuples.""" try: - from hermes_cli.config import load_config + from hermes_cli.config import get_compatible_custom_providers, load_config config = load_config() except Exception: return [] - custom_providers = config.get("custom_providers") - if not isinstance(custom_providers, list): - return [] result = [] - for entry in custom_providers: + for entry in get_compatible_custom_providers(config): if not isinstance(entry, dict): continue name = entry.get("name") if not isinstance(name, str) or not name.strip(): continue pool_key = f"{CUSTOM_POOL_PREFIX}{_normalize_custom_pool_name(name)}" - result.append((name.strip(), pool_key)) + provider_key = str(entry.get("provider_key", "") or "").strip() + result.append((name.strip(), pool_key, provider_key)) return result @@ -66,9 +64,11 @@ def _resolve_custom_provider_input(raw: str) -> str | None: # Direct match on 'custom:name' format if normalized.startswith(CUSTOM_POOL_PREFIX): return normalized - for display_name, pool_key in _get_custom_provider_names(): + for display_name, pool_key, provider_key in _get_custom_provider_names(): if _normalize_custom_pool_name(display_name) == normalized: return pool_key + if provider_key and provider_key.strip().lower() == normalized: + return pool_key return None @@ -405,7 +405,7 @@ def _pick_provider(prompt: str = "Provider") -> str: known = sorted(set(list(PROVIDER_REGISTRY.keys()) + ["openrouter"])) custom_names = _get_custom_provider_names() if custom_names: - custom_display = [name for name, _key in custom_names] + custom_display = [name for name, _key, _provider_key in custom_names] print(f"\nKnown providers: {', '.join(known)}") print(f"Custom endpoints: {', '.join(custom_display)}") else: diff --git a/hermes_cli/config.py b/hermes_cli/config.py index ef4e04b716..f524e792a5 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1544,6 +1544,136 @@ def get_missing_skill_config_vars() -> List[Dict[str, Any]]: return missing +def _normalize_custom_provider_entry( + entry: Any, + *, + provider_key: str = "", +) -> Optional[Dict[str, Any]]: + """Return a runtime-compatible custom provider entry or ``None``.""" + if not isinstance(entry, dict): + return None + + base_url = "" + for url_key in ("api", "url", "base_url"): + raw_url = entry.get(url_key) + if isinstance(raw_url, str) and raw_url.strip(): + base_url = raw_url.strip() + break + if not base_url: + return None + + name = "" + raw_name = entry.get("name") + if isinstance(raw_name, str) and raw_name.strip(): + name = raw_name.strip() + elif provider_key.strip(): + name = provider_key.strip() + if not name: + return None + + normalized: Dict[str, Any] = { + "name": name, + "base_url": base_url, + } + + provider_key = provider_key.strip() + if provider_key: + normalized["provider_key"] = provider_key + + api_key = entry.get("api_key") + if isinstance(api_key, str) and api_key.strip(): + normalized["api_key"] = api_key.strip() + + key_env = entry.get("key_env") + if isinstance(key_env, str) and key_env.strip(): + normalized["key_env"] = key_env.strip() + + api_mode = entry.get("api_mode") or entry.get("transport") + if isinstance(api_mode, str) and api_mode.strip(): + normalized["api_mode"] = api_mode.strip() + + model_name = entry.get("model") or entry.get("default_model") + if isinstance(model_name, str) and model_name.strip(): + normalized["model"] = model_name.strip() + + models = entry.get("models") + if isinstance(models, dict) and models: + normalized["models"] = models + + context_length = entry.get("context_length") + if isinstance(context_length, int) and context_length > 0: + normalized["context_length"] = context_length + + rate_limit_delay = entry.get("rate_limit_delay") + if isinstance(rate_limit_delay, (int, float)) and rate_limit_delay >= 0: + normalized["rate_limit_delay"] = rate_limit_delay + + return normalized + + +def providers_dict_to_custom_providers(providers_dict: Any) -> List[Dict[str, Any]]: + """Normalize ``providers`` config entries into the legacy custom-provider shape.""" + if not isinstance(providers_dict, dict): + return [] + + custom_providers: List[Dict[str, Any]] = [] + for key, entry in providers_dict.items(): + normalized = _normalize_custom_provider_entry(entry, provider_key=str(key)) + if normalized is not None: + custom_providers.append(normalized) + + return custom_providers + + +def get_compatible_custom_providers( + config: Optional[Dict[str, Any]] = None, +) -> List[Dict[str, Any]]: + """Return a deduplicated custom-provider view across legacy and v12+ config. + + ``custom_providers`` remains the on-disk legacy format, while ``providers`` + is the newer keyed schema. Runtime and picker flows still need a single + list-shaped view, but we should not materialise that compatibility layer + back into config.yaml because it duplicates entries in UIs. + """ + if config is None: + config = load_config() + + compatible: List[Dict[str, Any]] = [] + seen_provider_keys: set = set() + seen_name_url_pairs: set = set() + + def _append_if_new(entry: Optional[Dict[str, Any]]) -> None: + if entry is None: + return + provider_key = str(entry.get("provider_key", "") or "").strip().lower() + name = str(entry.get("name", "") or "").strip().lower() + base_url = str(entry.get("base_url", "") or "").strip().rstrip("/").lower() + pair = (name, base_url) + + if provider_key and provider_key in seen_provider_keys: + return + if name and base_url and pair in seen_name_url_pairs: + return + + compatible.append(entry) + if provider_key: + seen_provider_keys.add(provider_key) + if name and base_url: + seen_name_url_pairs.add(pair) + + custom_providers = config.get("custom_providers") + if custom_providers is not None: + if not isinstance(custom_providers, list): + return [] + for entry in custom_providers: + _append_if_new(_normalize_custom_provider_entry(entry)) + + for entry in providers_dict_to_custom_providers(config.get("providers")): + _append_if_new(entry) + + return compatible + + def check_config_version() -> Tuple[int, int]: """ Check config version. @@ -1861,8 +1991,8 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A if migrated_count > 0: config["providers"] = providers_dict - # Remove the old list - del config["custom_providers"] + # Remove the old list — runtime reads via get_compatible_custom_providers() + config.pop("custom_providers", None) save_config(config) if not quiet: print(f" ✓ Migrated {migrated_count} custom provider(s) to providers: section") diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 24ba11f20f..f653b4cd07 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -999,7 +999,7 @@ def select_provider_and_model(args=None): from hermes_cli.auth import ( resolve_provider, AuthError, format_auth_error, ) - from hermes_cli.config import load_config, get_env_value + from hermes_cli.config import get_compatible_custom_providers, load_config, get_env_value config = load_config() current_model = config.get("model") @@ -1090,11 +1090,8 @@ def select_provider_and_model(args=None): ] def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]: - custom_providers_cfg = cfg.get("custom_providers") or [] custom_provider_map = {} - if not isinstance(custom_providers_cfg, list): - return custom_provider_map - for entry in custom_providers_cfg: + for entry in get_compatible_custom_providers(cfg): if not isinstance(entry, dict): continue name = (entry.get("name") or "").strip() @@ -1102,12 +1099,20 @@ def select_provider_and_model(args=None): if not name or not base_url: continue key = "custom:" + name.lower().replace(" ", "-") + provider_key = (entry.get("provider_key") or "").strip() + if provider_key: + try: + resolve_provider(provider_key) + except AuthError: + key = provider_key custom_provider_map[key] = { "name": name, "base_url": base_url, "api_key": entry.get("api_key", ""), + "key_env": entry.get("key_env", ""), "model": entry.get("model", ""), "api_mode": entry.get("api_mode", ""), + "provider_key": provider_key, } return custom_provider_map @@ -1157,7 +1162,8 @@ def select_provider_and_model(args=None): if selected_provider == "more": ext_ordered = list(extended_providers) ext_ordered.append(("custom", "Custom endpoint (enter URL manually)")) - if _custom_provider_map: + _has_saved_custom_list = isinstance(config.get("custom_providers"), list) and bool(config.get("custom_providers")) + if _has_saved_custom_list: ext_ordered.append(("remove-custom", "Remove a saved custom provider")) ext_ordered.append(("cancel", "Cancel")) @@ -1184,7 +1190,7 @@ def select_provider_and_model(args=None): _model_flow_copilot(config, current_model) elif selected_provider == "custom": _model_flow_custom(config) - elif selected_provider.startswith("custom:"): + elif selected_provider.startswith("custom:") or selected_provider in _custom_provider_map: provider_info = _named_custom_provider_map(load_config()).get(selected_provider) if provider_info is None: print( @@ -1869,7 +1875,9 @@ def _model_flow_named_custom(config, provider_info): name = provider_info["name"] base_url = provider_info["base_url"] api_key = provider_info.get("api_key", "") + key_env = provider_info.get("key_env", "") saved_model = provider_info.get("model", "") + provider_key = (provider_info.get("provider_key") or "").strip() print(f" Provider: {name}") print(f" URL: {base_url}") @@ -1952,10 +1960,15 @@ def _model_flow_named_custom(config, provider_info): if not isinstance(model, dict): model = {"default": model} if model else {} cfg["model"] = model - model["provider"] = "custom" - model["base_url"] = base_url - if api_key: - model["api_key"] = api_key + if provider_key: + model["provider"] = provider_key + model.pop("base_url", None) + model.pop("api_key", None) + else: + model["provider"] = "custom" + model["base_url"] = base_url + if api_key: + model["api_key"] = api_key # Apply api_mode from custom_providers entry, or clear stale value custom_api_mode = provider_info.get("api_mode", "") if custom_api_mode: @@ -1965,8 +1978,23 @@ def _model_flow_named_custom(config, provider_info): save_config(cfg) deactivate_provider() - # Save model name to the custom_providers entry for next time - _save_custom_provider(base_url, api_key, model_name) + # Persist the selected model back to whichever schema owns this endpoint. + if provider_key: + cfg = load_config() + providers_cfg = cfg.get("providers") + if isinstance(providers_cfg, dict): + provider_entry = providers_cfg.get(provider_key) + if isinstance(provider_entry, dict): + provider_entry["default_model"] = model_name + if api_key and not str(provider_entry.get("api_key", "") or "").strip(): + provider_entry["api_key"] = api_key + if key_env and not str(provider_entry.get("key_env", "") or "").strip(): + provider_entry["key_env"] = key_env + cfg["providers"] = providers_cfg + save_config(cfg) + else: + # Save model name to the custom_providers entry for next time + _save_custom_provider(base_url, api_key, model_name) print(f"\n✅ Model set to: {model_name}") print(f" Provider: {name} ({base_url})") diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index d8854b893d..6957c80b6e 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -26,7 +26,7 @@ from hermes_cli.auth import ( resolve_external_process_provider_credentials, has_usable_secret, ) -from hermes_cli.config import load_config +from hermes_cli.config import get_compatible_custom_providers, load_config from hermes_constants import OPENROUTER_BASE_URL @@ -315,13 +315,16 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An # Fall back to custom_providers: list (legacy format) custom_providers = config.get("custom_providers") - if not isinstance(custom_providers, list): - if isinstance(custom_providers, dict): - logger.warning( - "custom_providers in config.yaml is a dict, not a list. " - "Each entry must be prefixed with '-' in YAML. " - "Run 'hermes doctor' for details." - ) + if isinstance(custom_providers, dict): + logger.warning( + "custom_providers in config.yaml is a dict, not a list. " + "Each entry must be prefixed with '-' in YAML. " + "Run 'hermes doctor' for details." + ) + return None + + custom_providers = get_compatible_custom_providers(config) + if not custom_providers: return None for entry in custom_providers: @@ -333,13 +336,21 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An continue name_norm = _normalize_custom_provider_name(name) menu_key = f"custom:{name_norm}" - if requested_norm not in {name_norm, menu_key}: + provider_key = str(entry.get("provider_key", "") or "").strip() + provider_key_norm = _normalize_custom_provider_name(provider_key) if provider_key else "" + provider_menu_key = f"custom:{provider_key_norm}" if provider_key_norm else "" + if requested_norm not in {name_norm, menu_key, provider_key_norm, provider_menu_key}: continue result = { "name": name.strip(), "base_url": base_url.strip(), "api_key": str(entry.get("api_key", "") or "").strip(), } + key_env = str(entry.get("key_env", "") or "").strip() + if key_env: + result["key_env"] = key_env + if provider_key: + result["provider_key"] = provider_key api_mode = _parse_api_mode(entry.get("api_mode")) if api_mode: result["api_mode"] = api_mode @@ -381,6 +392,7 @@ def _resolve_named_custom_runtime( api_key_candidates = [ (explicit_api_key or "").strip(), str(custom_provider.get("api_key", "") or "").strip(), + os.getenv(str(custom_provider.get("key_env", "") or "").strip(), "").strip(), os.getenv("OPENAI_API_KEY", "").strip(), os.getenv("OPENROUTER_API_KEY", "").strip(), ] diff --git a/run_agent.py b/run_agent.py index 89526320ec..64daad4c8b 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1267,24 +1267,29 @@ class AIAgent: # Check custom_providers per-model context_length if _config_context_length is None: - _custom_providers = _agent_cfg.get("custom_providers") - if isinstance(_custom_providers, list): - for _cp_entry in _custom_providers: - if not isinstance(_cp_entry, dict): - continue - _cp_url = (_cp_entry.get("base_url") or "").rstrip("/") - if _cp_url and _cp_url == self.base_url.rstrip("/"): - _cp_models = _cp_entry.get("models", {}) - if isinstance(_cp_models, dict): - _cp_model_cfg = _cp_models.get(self.model, {}) - if isinstance(_cp_model_cfg, dict): - _cp_ctx = _cp_model_cfg.get("context_length") - if _cp_ctx is not None: - try: - _config_context_length = int(_cp_ctx) - except (TypeError, ValueError): - pass - break + try: + from hermes_cli.config import get_compatible_custom_providers + _custom_providers = get_compatible_custom_providers(_agent_cfg) + except Exception: + _custom_providers = _agent_cfg.get("custom_providers") + if not isinstance(_custom_providers, list): + _custom_providers = [] + for _cp_entry in _custom_providers: + if not isinstance(_cp_entry, dict): + continue + _cp_url = (_cp_entry.get("base_url") or "").rstrip("/") + if _cp_url and _cp_url == self.base_url.rstrip("/"): + _cp_models = _cp_entry.get("models", {}) + if isinstance(_cp_models, dict): + _cp_model_cfg = _cp_models.get(self.model, {}) + if isinstance(_cp_model_cfg, dict): + _cp_ctx = _cp_model_cfg.get("context_length") + if _cp_ctx is not None: + try: + _config_context_length = int(_cp_ctx) + except (TypeError, ValueError): + pass + break # Select context engine: config-driven (like memory providers). # 1. Check config.yaml context.engine setting diff --git a/tests/hermes_cli/test_config.py b/tests/hermes_cli/test_config.py index d934a80125..397027d3a9 100644 --- a/tests/hermes_cli/test_config.py +++ b/tests/hermes_cli/test_config.py @@ -10,6 +10,7 @@ from hermes_cli.config import ( DEFAULT_CONFIG, get_hermes_home, ensure_hermes_home, + get_compatible_custom_providers, load_config, load_env, migrate_config, @@ -424,6 +425,146 @@ class TestAnthropicTokenMigration: assert load_env().get("ANTHROPIC_TOKEN") == "current-token" +class TestCustomProviderCompatibility: + """Custom provider compatibility across legacy and v12+ config schemas.""" + + def test_v11_upgrade_moves_custom_providers_into_providers(self, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "_config_version": 11, + "model": { + "default": "openai/gpt-5.4", + "provider": "openrouter", + }, + "custom_providers": [ + { + "name": "OpenAI Direct", + "base_url": "https://api.openai.com/v1", + "api_key": "test-key", + "api_mode": "codex_responses", + "model": "gpt-5-mini", + } + ], + "fallback_providers": [ + {"provider": "openai-direct", "model": "gpt-5-mini"} + ], + } + ), + encoding="utf-8", + ) + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + migrate_config(interactive=False, quiet=True) + raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) + + assert raw["_config_version"] == 17 + assert raw["providers"]["openai-direct"] == { + "api": "https://api.openai.com/v1", + "api_key": "test-key", + "default_model": "gpt-5-mini", + "name": "OpenAI Direct", + "transport": "codex_responses", + } + # custom_providers removed by migration — runtime reads via compat layer + assert "custom_providers" not in raw + + def test_providers_dict_resolves_at_runtime(self, tmp_path): + """After migration deleted custom_providers, get_compatible_custom_providers + still finds entries from the providers dict.""" + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "_config_version": 17, + "providers": { + "openai-direct": { + "api": "https://api.openai.com/v1", + "api_key": "test-key", + "default_model": "gpt-5-mini", + "name": "OpenAI Direct", + "transport": "codex_responses", + } + }, + } + ), + encoding="utf-8", + ) + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + compatible = get_compatible_custom_providers() + + assert len(compatible) == 1 + assert compatible[0]["name"] == "OpenAI Direct" + assert compatible[0]["base_url"] == "https://api.openai.com/v1" + assert compatible[0]["provider_key"] == "openai-direct" + assert compatible[0]["api_mode"] == "codex_responses" + + def test_compatible_custom_providers_prefers_api_then_url_then_base_url(self, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "_config_version": 17, + "providers": { + "my-provider": { + "name": "My Provider", + "api": "https://api.example.com/v1", + "url": "https://url.example.com/v1", + "base_url": "https://base.example.com/v1", + } + }, + } + ), + encoding="utf-8", + ) + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + compatible = get_compatible_custom_providers() + + assert compatible == [ + { + "name": "My Provider", + "base_url": "https://api.example.com/v1", + "provider_key": "my-provider", + } + ] + + def test_dedup_across_legacy_and_providers(self, tmp_path): + """Same name+url in both schemas should not produce duplicates.""" + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "_config_version": 17, + "custom_providers": [ + { + "name": "OpenAI Direct", + "base_url": "https://api.openai.com/v1", + "api_key": "legacy-key", + } + ], + "providers": { + "openai-direct": { + "api": "https://api.openai.com/v1", + "api_key": "new-key", + "name": "OpenAI Direct", + } + }, + } + ), + encoding="utf-8", + ) + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + compatible = get_compatible_custom_providers() + + assert len(compatible) == 1 + # Legacy entry wins (read first) + assert compatible[0]["api_key"] == "legacy-key" + + class TestInterimAssistantMessageConfig: """Test the explicit gateway interim-message config gate.""" @@ -441,6 +582,6 @@ class TestInterimAssistantMessageConfig: migrate_config(interactive=False, quiet=True) raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) - assert raw["_config_version"] == 16 + assert raw["_config_version"] == 17 assert raw["display"]["tool_progress"] == "off" assert raw["display"]["interim_assistant_messages"] is True diff --git a/tests/hermes_cli/test_runtime_provider_resolution.py b/tests/hermes_cli/test_runtime_provider_resolution.py index 20486a805b..c7510a55b8 100644 --- a/tests/hermes_cli/test_runtime_provider_resolution.py +++ b/tests/hermes_cli/test_runtime_provider_resolution.py @@ -119,6 +119,11 @@ def test_resolve_runtime_provider_falls_back_when_pool_empty(monkeypatch): def test_resolve_runtime_provider_codex(monkeypatch): + monkeypatch.setattr( + rp, + "load_pool", + lambda provider: type("P", (), {"has_credentials": lambda self: False})(), + ) monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex") monkeypatch.setattr( rp, @@ -567,6 +572,87 @@ def test_named_custom_provider_uses_saved_credentials(monkeypatch): assert resolved["source"] == "custom_provider:Local" +def test_named_custom_provider_uses_providers_dict_when_list_missing(monkeypatch): + """After v11→v12 migration deletes custom_providers, resolution should + still find entries in the providers dict via get_compatible_custom_providers.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.setattr( + rp, + "load_config", + lambda: { + "providers": { + "openai-direct-primary": { + "api": "https://api.openai.com/v1", + "api_key": "dir-key", + "default_model": "gpt-5-mini", + "name": "OpenAI Direct (Primary)", + "transport": "codex_responses", + } + } + }, + ) + monkeypatch.setattr( + rp, + "resolve_provider", + lambda *a, **k: (_ for _ in ()).throw( + AssertionError( + "resolve_provider should not be called for named custom providers" + ) + ), + ) + + resolved = rp.resolve_runtime_provider(requested="openai-direct-primary") + + assert resolved["provider"] == "custom" + assert resolved["api_mode"] == "codex_responses" + assert resolved["base_url"] == "https://api.openai.com/v1" + assert resolved["api_key"] == "dir-key" + assert resolved["requested_provider"] == "openai-direct-primary" + assert resolved["source"] == "custom_provider:OpenAI Direct (Primary)" + assert resolved["model"] == "gpt-5-mini" + + +def test_named_custom_provider_uses_key_env_from_providers_dict(monkeypatch): + """providers dict entries with key_env should resolve API key from env var.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.setenv("MYCORP_API_KEY", "env-secret") + monkeypatch.setattr( + rp, + "load_config", + lambda: { + "providers": { + "mycorp-proxy": { + "base_url": "https://proxy.example.com/v1", + "default_model": "acme-large", + "key_env": "MYCORP_API_KEY", + "name": "MyCorp Proxy", + } + } + }, + ) + monkeypatch.setattr( + rp, + "resolve_provider", + lambda *a, **k: (_ for _ in ()).throw( + AssertionError( + "resolve_provider should not be called for named custom providers" + ) + ), + ) + + resolved = rp.resolve_runtime_provider(requested="mycorp-proxy") + + assert resolved["provider"] == "custom" + assert resolved["api_mode"] == "chat_completions" + assert resolved["base_url"] == "https://proxy.example.com/v1" + assert resolved["api_key"] == "env-secret" + assert resolved["requested_provider"] == "mycorp-proxy" + assert resolved["source"] == "custom_provider:MyCorp Proxy" + assert resolved["model"] == "acme-large" + + def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "env-openai-key") monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)