mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix(config): preserve custom provider api key refs
This commit is contained in:
@@ -1527,6 +1527,22 @@ def select_provider_and_model(args=None):
|
||||
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
|
||||
|
||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
def _identity(entry):
|
||||
return (
|
||||
str(entry.get("provider_key", "") or "").strip(),
|
||||
str(entry.get("name", "") or "").strip(),
|
||||
str(entry.get("base_url", "") or "").strip().rstrip("/"),
|
||||
str(entry.get("model", "") or "").strip(),
|
||||
)
|
||||
|
||||
raw_api_key_refs = {}
|
||||
for raw_entry in get_compatible_custom_providers(read_raw_config()):
|
||||
raw_api_key = str(raw_entry.get("api_key", "") or "").strip()
|
||||
if "${" in raw_api_key:
|
||||
raw_api_key_refs[_identity(raw_entry)] = raw_api_key
|
||||
|
||||
custom_provider_map = {}
|
||||
for entry in get_compatible_custom_providers(cfg):
|
||||
if not isinstance(entry, dict):
|
||||
@@ -1550,6 +1566,7 @@ def select_provider_and_model(args=None):
|
||||
"model": entry.get("model", ""),
|
||||
"api_mode": entry.get("api_mode", ""),
|
||||
"provider_key": provider_key,
|
||||
"api_key_ref": raw_api_key_refs.get(_identity(entry), ""),
|
||||
}
|
||||
return custom_provider_map
|
||||
|
||||
@@ -2782,6 +2799,19 @@ def _auto_provider_name(base_url: str) -> str:
|
||||
return name
|
||||
|
||||
|
||||
def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""):
|
||||
"""Return the value that should be persisted for a custom provider key."""
|
||||
api_key_ref = str(provider_info.get("api_key_ref", "") or "").strip()
|
||||
if api_key_ref:
|
||||
return api_key_ref
|
||||
|
||||
key_env = str(provider_info.get("key_env", "") or "").strip()
|
||||
if key_env and not str(provider_info.get("api_key", "") or "").strip():
|
||||
return f"${{{key_env}}}"
|
||||
|
||||
return str(resolved_api_key or "").strip()
|
||||
|
||||
|
||||
def _save_custom_provider(
|
||||
base_url, api_key="", model="", context_length=None, name=None
|
||||
):
|
||||
@@ -2923,6 +2953,7 @@ def _model_flow_named_custom(config, provider_info):
|
||||
# Resolve key from env var if api_key not set directly
|
||||
if not api_key and key_env:
|
||||
api_key = os.environ.get(key_env, "")
|
||||
config_api_key = _custom_provider_api_key_config_value(provider_info, api_key)
|
||||
|
||||
print(f" Provider: {name}")
|
||||
print(f" URL: {base_url}")
|
||||
@@ -3019,8 +3050,8 @@ def _model_flow_named_custom(config, provider_info):
|
||||
else:
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
if api_key:
|
||||
model["api_key"] = api_key
|
||||
if config_api_key:
|
||||
model["api_key"] = config_api_key
|
||||
# Apply api_mode from custom_providers entry, or clear stale value
|
||||
custom_api_mode = provider_info.get("api_mode", "")
|
||||
if custom_api_mode:
|
||||
@@ -3038,15 +3069,15 @@ def _model_flow_named_custom(config, provider_info):
|
||||
provider_entry = providers_cfg.get(provider_key)
|
||||
if isinstance(provider_entry, dict):
|
||||
provider_entry["default_model"] = model_name
|
||||
if api_key and not str(provider_entry.get("api_key", "") or "").strip():
|
||||
provider_entry["api_key"] = api_key
|
||||
if config_api_key and not str(provider_entry.get("api_key", "") or "").strip():
|
||||
provider_entry["api_key"] = config_api_key
|
||||
if key_env and not str(provider_entry.get("key_env", "") or "").strip():
|
||||
provider_entry["key_env"] = key_env
|
||||
cfg["providers"] = providers_cfg
|
||||
save_config(cfg)
|
||||
else:
|
||||
# Save model name to the custom_providers entry for next time
|
||||
_save_custom_provider(base_url, api_key, model_name)
|
||||
_save_custom_provider(base_url, config_api_key, model_name)
|
||||
|
||||
print(f"\n✅ Model set to: {model_name}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
|
||||
@@ -52,7 +52,12 @@ class TestCustomProviderModelSwitch:
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
# fetch_api_models MUST be called even though model was saved
|
||||
mock_fetch.assert_called_once_with("sk-test", "https://vllm.example.com/v1", timeout=8.0)
|
||||
mock_fetch.assert_called_once_with(
|
||||
"sk-test",
|
||||
"https://vllm.example.com/v1",
|
||||
timeout=8.0,
|
||||
api_mode=None,
|
||||
)
|
||||
|
||||
def test_can_switch_to_different_model(self, config_home):
|
||||
"""User selects a different model than the saved one."""
|
||||
@@ -173,3 +178,82 @@ class TestCustomProviderModelSwitch:
|
||||
model = config.get("model")
|
||||
assert isinstance(model, dict)
|
||||
assert "api_mode" not in model, "Stale api_mode should be removed"
|
||||
|
||||
def test_env_template_api_key_is_preserved_in_model_config(self, config_home, monkeypatch):
|
||||
"""Selecting an env-backed custom provider must not inline the secret."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
config_path = config_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"model:\n"
|
||||
" default: old-model\n"
|
||||
" provider: openrouter\n"
|
||||
"custom_providers:\n"
|
||||
"- name: Example Provider\n"
|
||||
" base_url: https://api.example-provider.test/v1\n"
|
||||
" api_key: ${EXAMPLE_PROVIDER_API_KEY}\n"
|
||||
" model: qwen3.6-35b-fast\n"
|
||||
)
|
||||
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
|
||||
|
||||
provider_info = {
|
||||
"name": "Example Provider",
|
||||
"base_url": "https://api.example-provider.test/v1",
|
||||
"api_key": "sk-live-example-provider",
|
||||
"api_key_ref": "${EXAMPLE_PROVIDER_API_KEY}",
|
||||
"model": "qwen3.6-35b-fast",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]) as mock_fetch, \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
mock_fetch.assert_called_once_with(
|
||||
"sk-live-example-provider",
|
||||
"https://api.example-provider.test/v1",
|
||||
timeout=8.0,
|
||||
api_mode=None,
|
||||
)
|
||||
config = yaml.safe_load(config_path.read_text()) or {}
|
||||
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||
assert config["custom_providers"][0]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||
assert "sk-live-example-provider" not in config_path.read_text()
|
||||
|
||||
def test_key_env_custom_provider_persists_reference_not_secret(self, config_home, monkeypatch):
|
||||
"""key_env custom providers should also avoid writing plaintext keys."""
|
||||
import yaml
|
||||
from hermes_cli.main import _model_flow_named_custom
|
||||
|
||||
config_path = config_home / "config.yaml"
|
||||
config_path.write_text(
|
||||
"model:\n"
|
||||
" default: old-model\n"
|
||||
"custom_providers:\n"
|
||||
"- name: Example Provider\n"
|
||||
" base_url: https://api.example-provider.test/v1\n"
|
||||
" key_env: EXAMPLE_PROVIDER_API_KEY\n"
|
||||
" model: qwen3.6-35b-fast\n"
|
||||
)
|
||||
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
|
||||
|
||||
provider_info = {
|
||||
"name": "Example Provider",
|
||||
"base_url": "https://api.example-provider.test/v1",
|
||||
"api_key": "",
|
||||
"key_env": "EXAMPLE_PROVIDER_API_KEY",
|
||||
"model": "qwen3.6-35b-fast",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]), \
|
||||
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||
patch("builtins.input", return_value="1"), \
|
||||
patch("builtins.print"):
|
||||
_model_flow_named_custom({}, provider_info)
|
||||
|
||||
config = yaml.safe_load(config_path.read_text()) or {}
|
||||
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||
assert config["custom_providers"][0]["key_env"] == "EXAMPLE_PROVIDER_API_KEY"
|
||||
assert "sk-live-example-provider" not in config_path.read_text()
|
||||
|
||||
Reference in New Issue
Block a user