mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix(config): preserve custom provider api key refs
This commit is contained in:
@@ -1527,6 +1527,22 @@ def select_provider_and_model(args=None):
|
|||||||
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
|
all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS]
|
||||||
|
|
||||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||||
|
from hermes_cli.config import read_raw_config
|
||||||
|
|
||||||
|
def _identity(entry):
|
||||||
|
return (
|
||||||
|
str(entry.get("provider_key", "") or "").strip(),
|
||||||
|
str(entry.get("name", "") or "").strip(),
|
||||||
|
str(entry.get("base_url", "") or "").strip().rstrip("/"),
|
||||||
|
str(entry.get("model", "") or "").strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_api_key_refs = {}
|
||||||
|
for raw_entry in get_compatible_custom_providers(read_raw_config()):
|
||||||
|
raw_api_key = str(raw_entry.get("api_key", "") or "").strip()
|
||||||
|
if "${" in raw_api_key:
|
||||||
|
raw_api_key_refs[_identity(raw_entry)] = raw_api_key
|
||||||
|
|
||||||
custom_provider_map = {}
|
custom_provider_map = {}
|
||||||
for entry in get_compatible_custom_providers(cfg):
|
for entry in get_compatible_custom_providers(cfg):
|
||||||
if not isinstance(entry, dict):
|
if not isinstance(entry, dict):
|
||||||
@@ -1550,6 +1566,7 @@ def select_provider_and_model(args=None):
|
|||||||
"model": entry.get("model", ""),
|
"model": entry.get("model", ""),
|
||||||
"api_mode": entry.get("api_mode", ""),
|
"api_mode": entry.get("api_mode", ""),
|
||||||
"provider_key": provider_key,
|
"provider_key": provider_key,
|
||||||
|
"api_key_ref": raw_api_key_refs.get(_identity(entry), ""),
|
||||||
}
|
}
|
||||||
return custom_provider_map
|
return custom_provider_map
|
||||||
|
|
||||||
@@ -2782,6 +2799,19 @@ def _auto_provider_name(base_url: str) -> str:
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""):
|
||||||
|
"""Return the value that should be persisted for a custom provider key."""
|
||||||
|
api_key_ref = str(provider_info.get("api_key_ref", "") or "").strip()
|
||||||
|
if api_key_ref:
|
||||||
|
return api_key_ref
|
||||||
|
|
||||||
|
key_env = str(provider_info.get("key_env", "") or "").strip()
|
||||||
|
if key_env and not str(provider_info.get("api_key", "") or "").strip():
|
||||||
|
return f"${{{key_env}}}"
|
||||||
|
|
||||||
|
return str(resolved_api_key or "").strip()
|
||||||
|
|
||||||
|
|
||||||
def _save_custom_provider(
|
def _save_custom_provider(
|
||||||
base_url, api_key="", model="", context_length=None, name=None
|
base_url, api_key="", model="", context_length=None, name=None
|
||||||
):
|
):
|
||||||
@@ -2923,6 +2953,7 @@ def _model_flow_named_custom(config, provider_info):
|
|||||||
# Resolve key from env var if api_key not set directly
|
# Resolve key from env var if api_key not set directly
|
||||||
if not api_key and key_env:
|
if not api_key and key_env:
|
||||||
api_key = os.environ.get(key_env, "")
|
api_key = os.environ.get(key_env, "")
|
||||||
|
config_api_key = _custom_provider_api_key_config_value(provider_info, api_key)
|
||||||
|
|
||||||
print(f" Provider: {name}")
|
print(f" Provider: {name}")
|
||||||
print(f" URL: {base_url}")
|
print(f" URL: {base_url}")
|
||||||
@@ -3019,8 +3050,8 @@ def _model_flow_named_custom(config, provider_info):
|
|||||||
else:
|
else:
|
||||||
model["provider"] = "custom"
|
model["provider"] = "custom"
|
||||||
model["base_url"] = base_url
|
model["base_url"] = base_url
|
||||||
if api_key:
|
if config_api_key:
|
||||||
model["api_key"] = api_key
|
model["api_key"] = config_api_key
|
||||||
# Apply api_mode from custom_providers entry, or clear stale value
|
# Apply api_mode from custom_providers entry, or clear stale value
|
||||||
custom_api_mode = provider_info.get("api_mode", "")
|
custom_api_mode = provider_info.get("api_mode", "")
|
||||||
if custom_api_mode:
|
if custom_api_mode:
|
||||||
@@ -3038,15 +3069,15 @@ def _model_flow_named_custom(config, provider_info):
|
|||||||
provider_entry = providers_cfg.get(provider_key)
|
provider_entry = providers_cfg.get(provider_key)
|
||||||
if isinstance(provider_entry, dict):
|
if isinstance(provider_entry, dict):
|
||||||
provider_entry["default_model"] = model_name
|
provider_entry["default_model"] = model_name
|
||||||
if api_key and not str(provider_entry.get("api_key", "") or "").strip():
|
if config_api_key and not str(provider_entry.get("api_key", "") or "").strip():
|
||||||
provider_entry["api_key"] = api_key
|
provider_entry["api_key"] = config_api_key
|
||||||
if key_env and not str(provider_entry.get("key_env", "") or "").strip():
|
if key_env and not str(provider_entry.get("key_env", "") or "").strip():
|
||||||
provider_entry["key_env"] = key_env
|
provider_entry["key_env"] = key_env
|
||||||
cfg["providers"] = providers_cfg
|
cfg["providers"] = providers_cfg
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
else:
|
else:
|
||||||
# Save model name to the custom_providers entry for next time
|
# Save model name to the custom_providers entry for next time
|
||||||
_save_custom_provider(base_url, api_key, model_name)
|
_save_custom_provider(base_url, config_api_key, model_name)
|
||||||
|
|
||||||
print(f"\n✅ Model set to: {model_name}")
|
print(f"\n✅ Model set to: {model_name}")
|
||||||
print(f" Provider: {name} ({base_url})")
|
print(f" Provider: {name} ({base_url})")
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ class TestCustomProviderModelSwitch:
|
|||||||
_model_flow_named_custom({}, provider_info)
|
_model_flow_named_custom({}, provider_info)
|
||||||
|
|
||||||
# fetch_api_models MUST be called even though model was saved
|
# fetch_api_models MUST be called even though model was saved
|
||||||
mock_fetch.assert_called_once_with("sk-test", "https://vllm.example.com/v1", timeout=8.0)
|
mock_fetch.assert_called_once_with(
|
||||||
|
"sk-test",
|
||||||
|
"https://vllm.example.com/v1",
|
||||||
|
timeout=8.0,
|
||||||
|
api_mode=None,
|
||||||
|
)
|
||||||
|
|
||||||
def test_can_switch_to_different_model(self, config_home):
|
def test_can_switch_to_different_model(self, config_home):
|
||||||
"""User selects a different model than the saved one."""
|
"""User selects a different model than the saved one."""
|
||||||
@@ -173,3 +178,82 @@ class TestCustomProviderModelSwitch:
|
|||||||
model = config.get("model")
|
model = config.get("model")
|
||||||
assert isinstance(model, dict)
|
assert isinstance(model, dict)
|
||||||
assert "api_mode" not in model, "Stale api_mode should be removed"
|
assert "api_mode" not in model, "Stale api_mode should be removed"
|
||||||
|
|
||||||
|
def test_env_template_api_key_is_preserved_in_model_config(self, config_home, monkeypatch):
|
||||||
|
"""Selecting an env-backed custom provider must not inline the secret."""
|
||||||
|
import yaml
|
||||||
|
from hermes_cli.main import _model_flow_named_custom
|
||||||
|
|
||||||
|
config_path = config_home / "config.yaml"
|
||||||
|
config_path.write_text(
|
||||||
|
"model:\n"
|
||||||
|
" default: old-model\n"
|
||||||
|
" provider: openrouter\n"
|
||||||
|
"custom_providers:\n"
|
||||||
|
"- name: Example Provider\n"
|
||||||
|
" base_url: https://api.example-provider.test/v1\n"
|
||||||
|
" api_key: ${EXAMPLE_PROVIDER_API_KEY}\n"
|
||||||
|
" model: qwen3.6-35b-fast\n"
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
|
||||||
|
|
||||||
|
provider_info = {
|
||||||
|
"name": "Example Provider",
|
||||||
|
"base_url": "https://api.example-provider.test/v1",
|
||||||
|
"api_key": "sk-live-example-provider",
|
||||||
|
"api_key_ref": "${EXAMPLE_PROVIDER_API_KEY}",
|
||||||
|
"model": "qwen3.6-35b-fast",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]) as mock_fetch, \
|
||||||
|
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||||
|
patch("builtins.input", return_value="1"), \
|
||||||
|
patch("builtins.print"):
|
||||||
|
_model_flow_named_custom({}, provider_info)
|
||||||
|
|
||||||
|
mock_fetch.assert_called_once_with(
|
||||||
|
"sk-live-example-provider",
|
||||||
|
"https://api.example-provider.test/v1",
|
||||||
|
timeout=8.0,
|
||||||
|
api_mode=None,
|
||||||
|
)
|
||||||
|
config = yaml.safe_load(config_path.read_text()) or {}
|
||||||
|
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||||
|
assert config["custom_providers"][0]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||||
|
assert "sk-live-example-provider" not in config_path.read_text()
|
||||||
|
|
||||||
|
def test_key_env_custom_provider_persists_reference_not_secret(self, config_home, monkeypatch):
|
||||||
|
"""key_env custom providers should also avoid writing plaintext keys."""
|
||||||
|
import yaml
|
||||||
|
from hermes_cli.main import _model_flow_named_custom
|
||||||
|
|
||||||
|
config_path = config_home / "config.yaml"
|
||||||
|
config_path.write_text(
|
||||||
|
"model:\n"
|
||||||
|
" default: old-model\n"
|
||||||
|
"custom_providers:\n"
|
||||||
|
"- name: Example Provider\n"
|
||||||
|
" base_url: https://api.example-provider.test/v1\n"
|
||||||
|
" key_env: EXAMPLE_PROVIDER_API_KEY\n"
|
||||||
|
" model: qwen3.6-35b-fast\n"
|
||||||
|
)
|
||||||
|
monkeypatch.setenv("EXAMPLE_PROVIDER_API_KEY", "sk-live-example-provider")
|
||||||
|
|
||||||
|
provider_info = {
|
||||||
|
"name": "Example Provider",
|
||||||
|
"base_url": "https://api.example-provider.test/v1",
|
||||||
|
"api_key": "",
|
||||||
|
"key_env": "EXAMPLE_PROVIDER_API_KEY",
|
||||||
|
"model": "qwen3.6-35b-fast",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("hermes_cli.models.fetch_api_models", return_value=["qwen3.6-35b-fast"]), \
|
||||||
|
patch.dict("sys.modules", {"simple_term_menu": None}), \
|
||||||
|
patch("builtins.input", return_value="1"), \
|
||||||
|
patch("builtins.print"):
|
||||||
|
_model_flow_named_custom({}, provider_info)
|
||||||
|
|
||||||
|
config = yaml.safe_load(config_path.read_text()) or {}
|
||||||
|
assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}"
|
||||||
|
assert config["custom_providers"][0]["key_env"] == "EXAMPLE_PROVIDER_API_KEY"
|
||||||
|
assert "sk-live-example-provider" not in config_path.read_text()
|
||||||
|
|||||||
Reference in New Issue
Block a user