mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 16:01:49 +08:00
Compare commits
2 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a844513608 | ||
|
|
da4ee6d193 |
22
cli.py
22
cli.py
@@ -4125,6 +4125,16 @@ class HermesCLI:
|
|||||||
# Parse --provider and --global flags
|
# Parse --provider and --global flags
|
||||||
model_input, explicit_provider, persist_global = parse_model_flags(raw_args)
|
model_input, explicit_provider, persist_global = parse_model_flags(raw_args)
|
||||||
|
|
||||||
|
user_provs = None
|
||||||
|
custom_provs = None
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
cfg = load_config()
|
||||||
|
user_provs = cfg.get("providers")
|
||||||
|
custom_provs = cfg.get("custom_providers")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# No args at all: show available providers + models
|
# No args at all: show available providers + models
|
||||||
if not model_input and not explicit_provider:
|
if not model_input and not explicit_provider:
|
||||||
model_display = self.model or "unknown"
|
model_display = self.model or "unknown"
|
||||||
@@ -4134,18 +4144,10 @@ class HermesCLI:
|
|||||||
|
|
||||||
# Show authenticated providers with top models
|
# Show authenticated providers with top models
|
||||||
try:
|
try:
|
||||||
# Load user providers from config
|
|
||||||
user_provs = None
|
|
||||||
try:
|
|
||||||
from hermes_cli.config import load_config
|
|
||||||
cfg = load_config()
|
|
||||||
user_provs = cfg.get("providers")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
providers = list_authenticated_providers(
|
providers = list_authenticated_providers(
|
||||||
current_provider=self.provider or "",
|
current_provider=self.provider or "",
|
||||||
user_providers=user_provs,
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
max_models=6,
|
max_models=6,
|
||||||
)
|
)
|
||||||
if providers:
|
if providers:
|
||||||
@@ -4186,6 +4188,8 @@ class HermesCLI:
|
|||||||
current_api_key=self.api_key or "",
|
current_api_key=self.api_key or "",
|
||||||
is_global=persist_global,
|
is_global=persist_global,
|
||||||
explicit_provider=explicit_provider,
|
explicit_provider=explicit_provider,
|
||||||
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.success:
|
if not result.success:
|
||||||
|
|||||||
@@ -3546,6 +3546,7 @@ class GatewayRunner:
|
|||||||
current_base_url = ""
|
current_base_url = ""
|
||||||
current_api_key = ""
|
current_api_key = ""
|
||||||
user_provs = None
|
user_provs = None
|
||||||
|
custom_provs = None
|
||||||
config_path = _hermes_home / "config.yaml"
|
config_path = _hermes_home / "config.yaml"
|
||||||
try:
|
try:
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
@@ -3557,6 +3558,7 @@ class GatewayRunner:
|
|||||||
current_provider = model_cfg.get("provider", current_provider)
|
current_provider = model_cfg.get("provider", current_provider)
|
||||||
current_base_url = model_cfg.get("base_url", "")
|
current_base_url = model_cfg.get("base_url", "")
|
||||||
user_provs = cfg.get("providers")
|
user_provs = cfg.get("providers")
|
||||||
|
custom_provs = cfg.get("custom_providers")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -3584,6 +3586,7 @@ class GatewayRunner:
|
|||||||
providers = list_authenticated_providers(
|
providers = list_authenticated_providers(
|
||||||
current_provider=current_provider,
|
current_provider=current_provider,
|
||||||
user_providers=user_provs,
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
max_models=50,
|
max_models=50,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -3611,6 +3614,8 @@ class GatewayRunner:
|
|||||||
current_api_key=_cur_api_key,
|
current_api_key=_cur_api_key,
|
||||||
is_global=False,
|
is_global=False,
|
||||||
explicit_provider=provider_slug,
|
explicit_provider=provider_slug,
|
||||||
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
)
|
)
|
||||||
if not result.success:
|
if not result.success:
|
||||||
return f"Error: {result.error_message}"
|
return f"Error: {result.error_message}"
|
||||||
@@ -3689,6 +3694,7 @@ class GatewayRunner:
|
|||||||
providers = list_authenticated_providers(
|
providers = list_authenticated_providers(
|
||||||
current_provider=current_provider,
|
current_provider=current_provider,
|
||||||
user_providers=user_provs,
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
max_models=5,
|
max_models=5,
|
||||||
)
|
)
|
||||||
for p in providers:
|
for p in providers:
|
||||||
@@ -3718,6 +3724,8 @@ class GatewayRunner:
|
|||||||
current_api_key=current_api_key,
|
current_api_key=current_api_key,
|
||||||
is_global=persist_global,
|
is_global=persist_global,
|
||||||
explicit_provider=explicit_provider,
|
explicit_provider=explicit_provider,
|
||||||
|
user_providers=user_provs,
|
||||||
|
custom_providers=custom_provs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not result.success:
|
if not result.success:
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from dataclasses import dataclass
|
|||||||
from typing import List, NamedTuple, Optional
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
from hermes_cli.providers import (
|
from hermes_cli.providers import (
|
||||||
|
custom_provider_slug,
|
||||||
determine_api_mode,
|
determine_api_mode,
|
||||||
get_label,
|
get_label,
|
||||||
is_aggregator,
|
is_aggregator,
|
||||||
@@ -336,6 +337,7 @@ def resolve_alias(
|
|||||||
def get_authenticated_provider_slugs(
|
def get_authenticated_provider_slugs(
|
||||||
current_provider: str = "",
|
current_provider: str = "",
|
||||||
user_providers: dict = None,
|
user_providers: dict = None,
|
||||||
|
custom_providers: list | None = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Return slugs of providers that have credentials.
|
"""Return slugs of providers that have credentials.
|
||||||
|
|
||||||
@@ -346,6 +348,7 @@ def get_authenticated_provider_slugs(
|
|||||||
providers = list_authenticated_providers(
|
providers = list_authenticated_providers(
|
||||||
current_provider=current_provider,
|
current_provider=current_provider,
|
||||||
user_providers=user_providers,
|
user_providers=user_providers,
|
||||||
|
custom_providers=custom_providers,
|
||||||
max_models=0,
|
max_models=0,
|
||||||
)
|
)
|
||||||
return [p["slug"] for p in providers]
|
return [p["slug"] for p in providers]
|
||||||
@@ -383,6 +386,7 @@ def switch_model(
|
|||||||
is_global: bool = False,
|
is_global: bool = False,
|
||||||
explicit_provider: str = "",
|
explicit_provider: str = "",
|
||||||
user_providers: dict = None,
|
user_providers: dict = None,
|
||||||
|
custom_providers: list | None = None,
|
||||||
) -> ModelSwitchResult:
|
) -> ModelSwitchResult:
|
||||||
"""Core model-switching pipeline shared between CLI and gateway.
|
"""Core model-switching pipeline shared between CLI and gateway.
|
||||||
|
|
||||||
@@ -416,6 +420,7 @@ def switch_model(
|
|||||||
is_global: Whether to persist the switch.
|
is_global: Whether to persist the switch.
|
||||||
explicit_provider: From --provider flag (empty = no explicit provider).
|
explicit_provider: From --provider flag (empty = no explicit provider).
|
||||||
user_providers: The ``providers:`` dict from config.yaml (for user endpoints).
|
user_providers: The ``providers:`` dict from config.yaml (for user endpoints).
|
||||||
|
custom_providers: The ``custom_providers:`` list from config.yaml.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelSwitchResult with all information the caller needs.
|
ModelSwitchResult with all information the caller needs.
|
||||||
@@ -436,7 +441,11 @@ def switch_model(
|
|||||||
# =================================================================
|
# =================================================================
|
||||||
if explicit_provider:
|
if explicit_provider:
|
||||||
# Resolve the provider
|
# Resolve the provider
|
||||||
pdef = resolve_provider_full(explicit_provider, user_providers)
|
pdef = resolve_provider_full(
|
||||||
|
explicit_provider,
|
||||||
|
user_providers,
|
||||||
|
custom_providers,
|
||||||
|
)
|
||||||
if pdef is None:
|
if pdef is None:
|
||||||
_switch_err = (
|
_switch_err = (
|
||||||
f"Unknown provider '{explicit_provider}'. "
|
f"Unknown provider '{explicit_provider}'. "
|
||||||
@@ -516,6 +525,7 @@ def switch_model(
|
|||||||
authed = get_authenticated_provider_slugs(
|
authed = get_authenticated_provider_slugs(
|
||||||
current_provider=current_provider,
|
current_provider=current_provider,
|
||||||
user_providers=user_providers,
|
user_providers=user_providers,
|
||||||
|
custom_providers=custom_providers,
|
||||||
)
|
)
|
||||||
fallback_result = _resolve_alias_fallback(raw_input, authed)
|
fallback_result = _resolve_alias_fallback(raw_input, authed)
|
||||||
if fallback_result is not None:
|
if fallback_result is not None:
|
||||||
@@ -590,6 +600,14 @@ def switch_model(
|
|||||||
|
|
||||||
provider_changed = target_provider != current_provider
|
provider_changed = target_provider != current_provider
|
||||||
provider_label = get_label(target_provider)
|
provider_label = get_label(target_provider)
|
||||||
|
if target_provider.startswith("custom:"):
|
||||||
|
custom_pdef = resolve_provider_full(
|
||||||
|
target_provider,
|
||||||
|
user_providers,
|
||||||
|
custom_providers,
|
||||||
|
)
|
||||||
|
if custom_pdef is not None:
|
||||||
|
provider_label = custom_pdef.name
|
||||||
|
|
||||||
# --- Resolve credentials ---
|
# --- Resolve credentials ---
|
||||||
api_key = current_api_key
|
api_key = current_api_key
|
||||||
@@ -708,6 +726,7 @@ def switch_model(
|
|||||||
def list_authenticated_providers(
|
def list_authenticated_providers(
|
||||||
current_provider: str = "",
|
current_provider: str = "",
|
||||||
user_providers: dict = None,
|
user_providers: dict = None,
|
||||||
|
custom_providers: list | None = None,
|
||||||
max_models: int = 8,
|
max_models: int = 8,
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""Detect which providers have credentials and list their curated models.
|
"""Detect which providers have credentials and list their curated models.
|
||||||
@@ -853,6 +872,43 @@ def list_authenticated_providers(
|
|||||||
"api_url": api_url,
|
"api_url": api_url,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# --- 4. Saved custom providers from config ---
|
||||||
|
if custom_providers and isinstance(custom_providers, list):
|
||||||
|
for entry in custom_providers:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
display_name = (entry.get("name") or "").strip()
|
||||||
|
api_url = (
|
||||||
|
entry.get("base_url", "")
|
||||||
|
or entry.get("url", "")
|
||||||
|
or entry.get("api", "")
|
||||||
|
or ""
|
||||||
|
).strip()
|
||||||
|
if not display_name or not api_url:
|
||||||
|
continue
|
||||||
|
|
||||||
|
slug = custom_provider_slug(display_name)
|
||||||
|
if slug in seen_slugs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
models_list = []
|
||||||
|
default_model = (entry.get("model") or "").strip()
|
||||||
|
if default_model:
|
||||||
|
models_list.append(default_model)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"slug": slug,
|
||||||
|
"name": display_name,
|
||||||
|
"is_current": slug == current_provider,
|
||||||
|
"is_user_defined": True,
|
||||||
|
"models": models_list,
|
||||||
|
"total_models": len(models_list),
|
||||||
|
"source": "user-config",
|
||||||
|
"api_url": api_url,
|
||||||
|
})
|
||||||
|
seen_slugs.add(slug)
|
||||||
|
|
||||||
# Sort: current provider first, then by model count descending
|
# Sort: current provider first, then by model count descending
|
||||||
results.sort(key=lambda r: (not r["is_current"], -r["total_models"]))
|
results.sort(key=lambda r: (not r["is_current"], -r["total_models"]))
|
||||||
|
|
||||||
|
|||||||
@@ -452,9 +452,64 @@ def resolve_user_provider(name: str, user_config: Dict[str, Any]) -> Optional[Pr
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def custom_provider_slug(display_name: str) -> str:
|
||||||
|
"""Build a canonical slug for a custom_providers entry.
|
||||||
|
|
||||||
|
Matches the convention used by runtime_provider and credential_pool
|
||||||
|
(``custom:<normalized-name>``). Centralised here so all call-sites
|
||||||
|
produce identical slugs.
|
||||||
|
"""
|
||||||
|
return "custom:" + display_name.strip().lower().replace(" ", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_custom_provider(
|
||||||
|
name: str,
|
||||||
|
custom_providers: Optional[List[Dict[str, Any]]],
|
||||||
|
) -> Optional[ProviderDef]:
|
||||||
|
"""Resolve a provider from the user's config.yaml ``custom_providers`` list."""
|
||||||
|
if not custom_providers or not isinstance(custom_providers, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
requested = (name or "").strip().lower()
|
||||||
|
if not requested:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for entry in custom_providers:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
display_name = (entry.get("name") or "").strip()
|
||||||
|
api_url = (
|
||||||
|
entry.get("base_url", "")
|
||||||
|
or entry.get("url", "")
|
||||||
|
or entry.get("api", "")
|
||||||
|
or ""
|
||||||
|
).strip()
|
||||||
|
if not display_name or not api_url:
|
||||||
|
continue
|
||||||
|
|
||||||
|
slug = custom_provider_slug(display_name)
|
||||||
|
if requested not in {display_name.lower(), slug}:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return ProviderDef(
|
||||||
|
id=slug,
|
||||||
|
name=display_name,
|
||||||
|
transport="openai_chat",
|
||||||
|
api_key_env_vars=(),
|
||||||
|
base_url=api_url,
|
||||||
|
is_aggregator=False,
|
||||||
|
auth_type="api_key",
|
||||||
|
source="user-config",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def resolve_provider_full(
|
def resolve_provider_full(
|
||||||
name: str,
|
name: str,
|
||||||
user_providers: Optional[Dict[str, Any]] = None,
|
user_providers: Optional[Dict[str, Any]] = None,
|
||||||
|
custom_providers: Optional[List[Dict[str, Any]]] = None,
|
||||||
) -> Optional[ProviderDef]:
|
) -> Optional[ProviderDef]:
|
||||||
"""Full resolution chain: built-in → models.dev → user config.
|
"""Full resolution chain: built-in → models.dev → user config.
|
||||||
|
|
||||||
@@ -463,6 +518,7 @@ def resolve_provider_full(
|
|||||||
Args:
|
Args:
|
||||||
name: Provider name or alias.
|
name: Provider name or alias.
|
||||||
user_providers: The ``providers:`` dict from config.yaml (optional).
|
user_providers: The ``providers:`` dict from config.yaml (optional).
|
||||||
|
custom_providers: The ``custom_providers:`` list from config.yaml (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ProviderDef if found, else None.
|
ProviderDef if found, else None.
|
||||||
@@ -485,6 +541,11 @@ def resolve_provider_full(
|
|||||||
if user_pdef is not None:
|
if user_pdef is not None:
|
||||||
return user_pdef
|
return user_pdef
|
||||||
|
|
||||||
|
# 2b. Saved custom providers from config
|
||||||
|
custom_pdef = resolve_custom_provider(name, custom_providers)
|
||||||
|
if custom_pdef is not None:
|
||||||
|
return custom_pdef
|
||||||
|
|
||||||
# 3. Try models.dev directly (for providers not in our ALIASES)
|
# 3. Try models.dev directly (for providers not in our ALIASES)
|
||||||
try:
|
try:
|
||||||
from agent.models_dev import get_provider_info as _mdev_provider
|
from agent.models_dev import get_provider_info as _mdev_provider
|
||||||
|
|||||||
63
tests/gateway/test_model_command_custom_providers.py
Normal file
63
tests/gateway/test_model_command_custom_providers.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Regression tests for gateway /model support of config.yaml custom_providers."""
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gateway.config import Platform
|
||||||
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner():
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner.adapters = {}
|
||||||
|
runner._voice_mode = {}
|
||||||
|
runner._session_model_overrides = {}
|
||||||
|
return runner
|
||||||
|
|
||||||
|
|
||||||
|
def _make_event(text="/model"):
|
||||||
|
return MessageEvent(
|
||||||
|
text=text,
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_model_command_lists_saved_custom_provider(tmp_path, monkeypatch):
|
||||||
|
hermes_home = tmp_path / ".hermes"
|
||||||
|
hermes_home.mkdir()
|
||||||
|
(hermes_home / "config.yaml").write_text(
|
||||||
|
yaml.safe_dump(
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"default": "gpt-5.4",
|
||||||
|
"provider": "openai-codex",
|
||||||
|
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||||
|
},
|
||||||
|
"providers": {},
|
||||||
|
"custom_providers": [
|
||||||
|
{
|
||||||
|
"name": "Local (127.0.0.1:4141)",
|
||||||
|
"base_url": "http://127.0.0.1:4141/v1",
|
||||||
|
"model": "rotator-openrouter-coding",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
import gateway.run as gateway_run
|
||||||
|
|
||||||
|
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||||
|
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||||
|
|
||||||
|
result = await _make_runner()._handle_model_command(_make_event())
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "Local (127.0.0.1:4141)" in result
|
||||||
|
assert "custom:local-(127.0.0.1:4141)" in result
|
||||||
|
assert "rotator-openrouter-coding" in result
|
||||||
104
tests/hermes_cli/test_model_switch_custom_providers.py
Normal file
104
tests/hermes_cli/test_model_switch_custom_providers.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Regression tests for /model support of config.yaml custom_providers.
|
||||||
|
|
||||||
|
The terminal `hermes model` flow already exposes `custom_providers`, but the
|
||||||
|
shared slash-command pipeline (`/model` in CLI/gateway/Telegram) historically
|
||||||
|
only looked at `providers:`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hermes_cli.providers as providers_mod
|
||||||
|
from hermes_cli.model_switch import list_authenticated_providers, switch_model
|
||||||
|
from hermes_cli.providers import resolve_provider_full
|
||||||
|
|
||||||
|
|
||||||
|
_MOCK_VALIDATION = {
|
||||||
|
"accepted": True,
|
||||||
|
"persist": True,
|
||||||
|
"recognized": True,
|
||||||
|
"message": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_authenticated_providers_includes_custom_providers(monkeypatch):
|
||||||
|
"""No-args /model menus should include saved custom_providers entries."""
|
||||||
|
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||||
|
monkeypatch.setattr(providers_mod, "HERMES_OVERLAYS", {})
|
||||||
|
|
||||||
|
providers = list_authenticated_providers(
|
||||||
|
current_provider="openai-codex",
|
||||||
|
user_providers={},
|
||||||
|
custom_providers=[
|
||||||
|
{
|
||||||
|
"name": "Local (127.0.0.1:4141)",
|
||||||
|
"base_url": "http://127.0.0.1:4141/v1",
|
||||||
|
"model": "rotator-openrouter-coding",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
max_models=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert any(
|
||||||
|
p["slug"] == "custom:local-(127.0.0.1:4141)"
|
||||||
|
and p["name"] == "Local (127.0.0.1:4141)"
|
||||||
|
and p["models"] == ["rotator-openrouter-coding"]
|
||||||
|
and p["api_url"] == "http://127.0.0.1:4141/v1"
|
||||||
|
for p in providers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_provider_full_finds_named_custom_provider():
|
||||||
|
"""Explicit /model --provider should resolve saved custom_providers entries."""
|
||||||
|
resolved = resolve_provider_full(
|
||||||
|
"custom:local-(127.0.0.1:4141)",
|
||||||
|
user_providers={},
|
||||||
|
custom_providers=[
|
||||||
|
{
|
||||||
|
"name": "Local (127.0.0.1:4141)",
|
||||||
|
"base_url": "http://127.0.0.1:4141/v1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved is not None
|
||||||
|
assert resolved.id == "custom:local-(127.0.0.1:4141)"
|
||||||
|
assert resolved.name == "Local (127.0.0.1:4141)"
|
||||||
|
assert resolved.base_url == "http://127.0.0.1:4141/v1"
|
||||||
|
assert resolved.source == "user-config"
|
||||||
|
|
||||||
|
|
||||||
|
def test_switch_model_accepts_explicit_named_custom_provider(monkeypatch):
|
||||||
|
"""Shared /model switch pipeline should accept --provider for custom_providers."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||||
|
lambda requested: {
|
||||||
|
"api_key": "no-key-required",
|
||||||
|
"base_url": "http://127.0.0.1:4141/v1",
|
||||||
|
"api_mode": "chat_completions",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("hermes_cli.models.validate_requested_model", lambda *a, **k: _MOCK_VALIDATION)
|
||||||
|
monkeypatch.setattr("hermes_cli.model_switch.get_model_info", lambda *a, **k: None)
|
||||||
|
monkeypatch.setattr("hermes_cli.model_switch.get_model_capabilities", lambda *a, **k: None)
|
||||||
|
|
||||||
|
result = switch_model(
|
||||||
|
raw_input="rotator-openrouter-coding",
|
||||||
|
current_provider="openai-codex",
|
||||||
|
current_model="gpt-5.4",
|
||||||
|
current_base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
current_api_key="",
|
||||||
|
explicit_provider="custom:local-(127.0.0.1:4141)",
|
||||||
|
user_providers={},
|
||||||
|
custom_providers=[
|
||||||
|
{
|
||||||
|
"name": "Local (127.0.0.1:4141)",
|
||||||
|
"base_url": "http://127.0.0.1:4141/v1",
|
||||||
|
"model": "rotator-openrouter-coding",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.target_provider == "custom:local-(127.0.0.1:4141)"
|
||||||
|
assert result.provider_label == "Local (127.0.0.1:4141)"
|
||||||
|
assert result.new_model == "rotator-openrouter-coding"
|
||||||
|
assert result.base_url == "http://127.0.0.1:4141/v1"
|
||||||
|
assert result.api_key == "no-key-required"
|
||||||
Reference in New Issue
Block a user