fix(config): preserve custom provider api key refs

This commit is contained in:
helix4u
2026-04-25 18:21:53 -06:00
committed by Teknium
parent 2c56dce0ed
commit 1fdc31b214
2 changed files with 121 additions and 6 deletions

View File

@@ -1527,6 +1527,22 @@ def select_provider_and_model(args=None):
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS] all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]: def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
from hermes_cli.config import read_raw_config
def _identity(entry):
return (
str(entry.get("provider_key", "") or "").strip(),
str(entry.get("name", "") or "").strip(),
str(entry.get("base_url", "") or "").strip().rstrip("/"),
str(entry.get("model", "") or "").strip(),
)
raw_api_key_refs = {}
for raw_entry in get_compatible_custom_providers(read_raw_config()):
raw_api_key = str(raw_entry.get("api_key", "") or "").strip()
if "${" in raw_api_key:
raw_api_key_refs[_identity(raw_entry)] = raw_api_key
custom_provider_map = {} custom_provider_map = {}
for entry in get_compatible_custom_providers(cfg): for entry in get_compatible_custom_providers(cfg):
if not isinstance(entry, dict): if not isinstance(entry, dict):
@@ -1550,6 +1566,7 @@ def select_provider_and_model(args=None):
"model": entry.get("model", ""), "model": entry.get("model", ""),
"api_mode": entry.get("api_mode", ""), "api_mode": entry.get("api_mode", ""),
"provider_key": provider_key, "provider_key": provider_key,
"api_key_ref": raw_api_key_refs.get(_identity(entry), ""),
} }
return custom_provider_map return custom_provider_map
@@ -2782,6 +2799,19 @@ def _auto_provider_name(base_url: str) -> str:
return name return name
def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""):
"""Return the value that should be persisted for a custom provider key."""
api_key_ref = str(provider_info.get("api_key_ref", "") or "").strip()
if api_key_ref:
return api_key_ref
key_env = str(provider_info.get("key_env", "") or "").strip()
if key_env and not str(provider_info.get("api_key", "") or "").strip():
return f"${{{key_env}}}"
return str(resolved_api_key or "").strip()
def _save_custom_provider( def _save_custom_provider(
base_url, api_key="", model="", context_length=None, name=None base_url, api_key="", model="", context_length=None, name=None
): ):
@@ -2923,6 +2953,7 @@ def _model_flow_named_custom(config, provider_info):
# Resolve key from env var if api_key not set directly # Resolve key from env var if api_key not set directly
if not api_key and key_env: if not api_key and key_env:
api_key = os.environ.get(key_env, "") api_key = os.environ.get(key_env, "")
config_api_key = _custom_provider_api_key_config_value(provider_info, api_key)
print(f" Provider: {name}") print(f" Provider: {name}")
print(f" URL: {base_url}") print(f" URL: {base_url}")
@@ -3019,8 +3050,8 @@ def _model_flow_named_custom(config, provider_info):
else: else:
model["provider"] = "custom" model["provider"] = "custom"
model["base_url"] = base_url model["base_url"] = base_url
if api_key: if config_api_key:
model["api_key"] = api_key model["api_key"] = config_api_key
# Apply api_mode from custom_providers entry, or clear stale value # Apply api_mode from custom_providers entry, or clear stale value
custom_api_mode = provider_info.get("api_mode", "") custom_api_mode = provider_info.get("api_mode", "")
if custom_api_mode: if custom_api_mode:
@@ -3038,15 +3069,15 @@ def _model_flow_named_custom(config, provider_info):
provider_entry = providers_cfg.get(provider_key) provider_entry = providers_cfg.get(provider_key)
if isinstance(provider_entry, dict): if isinstance(provider_entry, dict):
provider_entry["default_model"] = model_name provider_entry["default_model"] = model_name
if api_key and not str(provider_entry.get("api_key", "") or "").strip(): if config_api_key and not str(provider_entry.get("api_key", "") or "").strip():
provider_entry["api_key"] = api_key provider_entry["api_key"] = config_api_key
if key_env and not str(provider_entry.get("key_env", "") or "").strip(): if key_env and not str(provider_entry.get("key_env", "") or "").strip():
provider_entry["key_env"] = key_env provider_entry["key_env"] = key_env
cfg["providers"] = providers_cfg cfg["providers"] = providers_cfg
save_config(cfg) save_config(cfg)
else: else:
# Save model name to the custom_providers entry for next time # Save model name to the custom_providers entry for next time
_save_custom_provider(base_url, api_key, model_name) _save_custom_provider(base_url, config_api_key, model_name)
print(f"\n✅ Model set to: {model_name}") print(f"\n✅ Model set to: {model_name}")
print(f" Provider: {name} ({base_url})") print(f" Provider: {name} ({base_url})")

View File

@@ -52,7 +52,12 @@ class TestCustomProviderModelSwitch:
_model_flow_named_custom({}, provider_info) _model_flow_named_custom({}, provider_info)
# fetch_api_models MUST be called even though model was saved # fetch_api_models MUST be called even though model was saved
mock_fetch.assert_called_once_with("sk-test", "https://vllm.example.com/v1", timeout=8.0) mock_fetch.assert_called_once_with(
"sk-test",
"https://vllm.example.com/v1",
timeout=8.0,
api_mode=None,
)
def test_can_switch_to_different_model(self, config_home): def test_can_switch_to_different_model(self, config_home):
"""User selects a different model than the saved one.""" """User selects a different model than the saved one."""
@@ -173,3 +178,82 @@ class TestCustomProviderModelSwitch:
model = config.get("model") model = config.get("model")
assert isinstance(model, dict) assert isinstance(model, dict)
assert "api_mode" not in model, "Stale api_mode should be removed" assert "api_mode" not in model, "Stale api_mode should be removed"
def test_env_template_api_key_is_preserved_in_model_config(self, config_home, monkeypatch):
"""Selecting an env-backed custom provider must not inline the secret."""
import yaml
from hermes_cli.main import _model_flow_named_custom
config_path = config_home / "config.yaml"
config_path.write_text(
"model:\n"
" default: old-model\n"
" provider: openrouter\n"
"custom_providers:\n"
"- name: Example Provider\n"
" base_url: https://api.example-provider.test/v1\n"
" api_key: ${EXAMPLE_PROVIDER_API_KEY}\n"
" model: qwen3.6-35b-fast\n"
)
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
provider_info = {
"name": "Example Provider",
"base_url": "https://api.example-provider.test/v1",
"api_key": "sk-live-example-provider",
"api_key_ref": "${EXAMPLE_PROVIDER_API_KEY}",
"model": "qwen3.6-35b-fast",
}
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]) as mock_fetch, \
patch.dict("sys.modules", {"simple_term_menu": None}), \
patch("builtins.input", return_value="1"), \
patch("builtins.print"):
_model_flow_named_custom({}, provider_info)
mock_fetch.assert_called_once_with(
"sk-live-example-provider",
"https://api.example-provider.test/v1",
timeout=8.0,
api_mode=None,
)
config = yaml.safe_load(config_path.read_text()) or {}
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
assert config["custom_providers"][0]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
assert "sk-live-example-provider" not in config_path.read_text()
def test_key_env_custom_provider_persists_reference_not_secret(self, config_home, monkeypatch):
"""key_env custom providers should also avoid writing plaintext keys."""
import yaml
from hermes_cli.main import _model_flow_named_custom
config_path = config_home / "config.yaml"
config_path.write_text(
"model:\n"
" default: old-model\n"
"custom_providers:\n"
"- name: Example Provider\n"
" base_url: https://api.example-provider.test/v1\n"
" key_env: EXAMPLE_PROVIDER_API_KEY\n"
" model: qwen3.6-35b-fast\n"
)
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
provider_info = {
"name": "Example Provider",
"base_url": "https://api.example-provider.test/v1",
"api_key": "",
"key_env": "EXAMPLE_PROVIDER_API_KEY",
"model": "qwen3.6-35b-fast",
}
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]), \
patch.dict("sys.modules", {"simple_term_menu": None}), \
patch("builtins.input", return_value="1"), \
patch("builtins.print"):
_model_flow_named_custom({}, provider_info)
config = yaml.safe_load(config_path.read_text()) or {}
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
assert config["custom_providers"][0]["key_env"] == "EXAMPLE_PROVIDER_API_KEY"
assert "sk-live-example-provider" not in config_path.read_text()