Compare commits

...

2 Commits

Author SHA1 Message Date
Teknium
a844513608 fix: extract custom_provider_slug() helper, harden gateway test
- Add custom_provider_slug() to hermes_cli/providers.py as the single
  source of truth for building 'custom:<name>' slugs.
- Use it in resolve_custom_provider() and list_authenticated_providers()
  instead of duplicated inline slug construction.
- Add _session_model_overrides and _voice_mode to gateway test runner
  for object.__new__() safety.
2026-04-10 03:00:39 -07:00
donrhmexe
da4ee6d193 fix: include custom_providers in /model command listings and resolution
Custom providers defined in config.yaml under  were
completely invisible to the /model command in both gateway (Telegram,
Discord, etc.) and CLI. The provider listing skipped them and explicit
switching via --provider failed with "Unknown provider".

Root cause: gateway/run.py, cli.py, and model_switch.py only read the
 dict from config, ignoring  entirely.

Changes:
- providers.py: add resolve_custom_provider() and extend
  resolve_provider_full() to check custom_providers after user_providers
- model_switch.py: propagate custom_providers through switch_model(),
  list_authenticated_providers(), and get_authenticated_provider_slugs();
  add custom provider section to provider listings
- gateway/run.py: read custom_providers from config, pass to all
  model-switch calls
- cli.py: hoist config loading, pass custom_providers to listing and
  switch calls

Tests: 4 new regression tests covering listing, resolution, and gateway
command handler. All 71 tests pass.
2026-04-10 02:49:25 -07:00
6 changed files with 306 additions and 10 deletions

22
cli.py
View File

@@ -4125,6 +4125,16 @@ class HermesCLI:
# Parse --provider and --global flags # Parse --provider and --global flags
model_input, explicit_provider, persist_global = parse_model_flags(raw_args) model_input, explicit_provider, persist_global = parse_model_flags(raw_args)
user_provs = None
custom_provs = None
try:
from hermes_cli.config import load_config
cfg = load_config()
user_provs = cfg.get("providers")
custom_provs = cfg.get("custom_providers")
except Exception:
pass
# No args at all: show available providers + models # No args at all: show available providers + models
if not model_input and not explicit_provider: if not model_input and not explicit_provider:
model_display = self.model or "unknown" model_display = self.model or "unknown"
@@ -4134,18 +4144,10 @@ class HermesCLI:
# Show authenticated providers with top models # Show authenticated providers with top models
try: try:
# Load user providers from config
user_provs = None
try:
from hermes_cli.config import load_config
cfg = load_config()
user_provs = cfg.get("providers")
except Exception:
pass
providers = list_authenticated_providers( providers = list_authenticated_providers(
current_provider=self.provider or "", current_provider=self.provider or "",
user_providers=user_provs, user_providers=user_provs,
custom_providers=custom_provs,
max_models=6, max_models=6,
) )
if providers: if providers:
@@ -4186,6 +4188,8 @@ class HermesCLI:
current_api_key=self.api_key or "", current_api_key=self.api_key or "",
is_global=persist_global, is_global=persist_global,
explicit_provider=explicit_provider, explicit_provider=explicit_provider,
user_providers=user_provs,
custom_providers=custom_provs,
) )
if not result.success: if not result.success:

View File

@@ -3546,6 +3546,7 @@ class GatewayRunner:
current_base_url = "" current_base_url = ""
current_api_key = "" current_api_key = ""
user_provs = None user_provs = None
custom_provs = None
config_path = _hermes_home / "config.yaml" config_path = _hermes_home / "config.yaml"
try: try:
if config_path.exists(): if config_path.exists():
@@ -3557,6 +3558,7 @@ class GatewayRunner:
current_provider = model_cfg.get("provider", current_provider) current_provider = model_cfg.get("provider", current_provider)
current_base_url = model_cfg.get("base_url", "") current_base_url = model_cfg.get("base_url", "")
user_provs = cfg.get("providers") user_provs = cfg.get("providers")
custom_provs = cfg.get("custom_providers")
except Exception: except Exception:
pass pass
@@ -3584,6 +3586,7 @@ class GatewayRunner:
providers = list_authenticated_providers( providers = list_authenticated_providers(
current_provider=current_provider, current_provider=current_provider,
user_providers=user_provs, user_providers=user_provs,
custom_providers=custom_provs,
max_models=50, max_models=50,
) )
except Exception: except Exception:
@@ -3611,6 +3614,8 @@ class GatewayRunner:
current_api_key=_cur_api_key, current_api_key=_cur_api_key,
is_global=False, is_global=False,
explicit_provider=provider_slug, explicit_provider=provider_slug,
user_providers=user_provs,
custom_providers=custom_provs,
) )
if not result.success: if not result.success:
return f"Error: {result.error_message}" return f"Error: {result.error_message}"
@@ -3689,6 +3694,7 @@ class GatewayRunner:
providers = list_authenticated_providers( providers = list_authenticated_providers(
current_provider=current_provider, current_provider=current_provider,
user_providers=user_provs, user_providers=user_provs,
custom_providers=custom_provs,
max_models=5, max_models=5,
) )
for p in providers: for p in providers:
@@ -3718,6 +3724,8 @@ class GatewayRunner:
current_api_key=current_api_key, current_api_key=current_api_key,
is_global=persist_global, is_global=persist_global,
explicit_provider=explicit_provider, explicit_provider=explicit_provider,
user_providers=user_provs,
custom_providers=custom_provs,
) )
if not result.success: if not result.success:

View File

@@ -25,6 +25,7 @@ from dataclasses import dataclass
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
from hermes_cli.providers import ( from hermes_cli.providers import (
custom_provider_slug,
determine_api_mode, determine_api_mode,
get_label, get_label,
is_aggregator, is_aggregator,
@@ -336,6 +337,7 @@ def resolve_alias(
def get_authenticated_provider_slugs( def get_authenticated_provider_slugs(
current_provider: str = "", current_provider: str = "",
user_providers: dict = None, user_providers: dict = None,
custom_providers: list | None = None,
) -> list[str]: ) -> list[str]:
"""Return slugs of providers that have credentials. """Return slugs of providers that have credentials.
@@ -346,6 +348,7 @@ def get_authenticated_provider_slugs(
providers = list_authenticated_providers( providers = list_authenticated_providers(
current_provider=current_provider, current_provider=current_provider,
user_providers=user_providers, user_providers=user_providers,
custom_providers=custom_providers,
max_models=0, max_models=0,
) )
return [p["slug"] for p in providers] return [p["slug"] for p in providers]
@@ -383,6 +386,7 @@ def switch_model(
is_global: bool = False, is_global: bool = False,
explicit_provider: str = "", explicit_provider: str = "",
user_providers: dict = None, user_providers: dict = None,
custom_providers: list | None = None,
) -> ModelSwitchResult: ) -> ModelSwitchResult:
"""Core model-switching pipeline shared between CLI and gateway. """Core model-switching pipeline shared between CLI and gateway.
@@ -416,6 +420,7 @@ def switch_model(
is_global: Whether to persist the switch. is_global: Whether to persist the switch.
explicit_provider: From --provider flag (empty = no explicit provider). explicit_provider: From --provider flag (empty = no explicit provider).
user_providers: The ``providers:`` dict from config.yaml (for user endpoints). user_providers: The ``providers:`` dict from config.yaml (for user endpoints).
custom_providers: The ``custom_providers:`` list from config.yaml.
Returns: Returns:
ModelSwitchResult with all information the caller needs. ModelSwitchResult with all information the caller needs.
@@ -436,7 +441,11 @@ def switch_model(
# ================================================================= # =================================================================
if explicit_provider: if explicit_provider:
# Resolve the provider # Resolve the provider
pdef = resolve_provider_full(explicit_provider, user_providers) pdef = resolve_provider_full(
explicit_provider,
user_providers,
custom_providers,
)
if pdef is None: if pdef is None:
_switch_err = ( _switch_err = (
f"Unknown provider '{explicit_provider}'. " f"Unknown provider '{explicit_provider}'. "
@@ -516,6 +525,7 @@ def switch_model(
authed = get_authenticated_provider_slugs( authed = get_authenticated_provider_slugs(
current_provider=current_provider, current_provider=current_provider,
user_providers=user_providers, user_providers=user_providers,
custom_providers=custom_providers,
) )
fallback_result = _resolve_alias_fallback(raw_input, authed) fallback_result = _resolve_alias_fallback(raw_input, authed)
if fallback_result is not None: if fallback_result is not None:
@@ -590,6 +600,14 @@ def switch_model(
provider_changed = target_provider != current_provider provider_changed = target_provider != current_provider
provider_label = get_label(target_provider) provider_label = get_label(target_provider)
if target_provider.startswith("custom:"):
custom_pdef = resolve_provider_full(
target_provider,
user_providers,
custom_providers,
)
if custom_pdef is not None:
provider_label = custom_pdef.name
# --- Resolve credentials --- # --- Resolve credentials ---
api_key = current_api_key api_key = current_api_key
@@ -708,6 +726,7 @@ def switch_model(
def list_authenticated_providers( def list_authenticated_providers(
current_provider: str = "", current_provider: str = "",
user_providers: dict = None, user_providers: dict = None,
custom_providers: list | None = None,
max_models: int = 8, max_models: int = 8,
) -> List[dict]: ) -> List[dict]:
"""Detect which providers have credentials and list their curated models. """Detect which providers have credentials and list their curated models.
@@ -853,6 +872,43 @@ def list_authenticated_providers(
"api_url": api_url, "api_url": api_url,
}) })
# --- 4. Saved custom providers from config ---
if custom_providers and isinstance(custom_providers, list):
for entry in custom_providers:
if not isinstance(entry, dict):
continue
display_name = (entry.get("name") or "").strip()
api_url = (
entry.get("base_url", "")
or entry.get("url", "")
or entry.get("api", "")
or ""
).strip()
if not display_name or not api_url:
continue
slug = custom_provider_slug(display_name)
if slug in seen_slugs:
continue
models_list = []
default_model = (entry.get("model") or "").strip()
if default_model:
models_list.append(default_model)
results.append({
"slug": slug,
"name": display_name,
"is_current": slug == current_provider,
"is_user_defined": True,
"models": models_list,
"total_models": len(models_list),
"source": "user-config",
"api_url": api_url,
})
seen_slugs.add(slug)
# Sort: current provider first, then by model count descending # Sort: current provider first, then by model count descending
results.sort(key=lambda r: (not r["is_current"], -r["total_models"])) results.sort(key=lambda r: (not r["is_current"], -r["total_models"]))

View File

@@ -452,9 +452,64 @@ def resolve_user_provider(name: str, user_config: Dict[str, Any]) -> Optional[Pr
) )
def custom_provider_slug(display_name: str) -> str:
"""Build a canonical slug for a custom_providers entry.
Matches the convention used by runtime_provider and credential_pool
(``custom:<normalized-name>``). Centralised here so all call-sites
produce identical slugs.
"""
return "custom:" + display_name.strip().lower().replace(" ", "-")
def resolve_custom_provider(
name: str,
custom_providers: Optional[List[Dict[str, Any]]],
) -> Optional[ProviderDef]:
"""Resolve a provider from the user's config.yaml ``custom_providers`` list."""
if not custom_providers or not isinstance(custom_providers, list):
return None
requested = (name or "").strip().lower()
if not requested:
return None
for entry in custom_providers:
if not isinstance(entry, dict):
continue
display_name = (entry.get("name") or "").strip()
api_url = (
entry.get("base_url", "")
or entry.get("url", "")
or entry.get("api", "")
or ""
).strip()
if not display_name or not api_url:
continue
slug = custom_provider_slug(display_name)
if requested not in {display_name.lower(), slug}:
continue
return ProviderDef(
id=slug,
name=display_name,
transport="openai_chat",
api_key_env_vars=(),
base_url=api_url,
is_aggregator=False,
auth_type="api_key",
source="user-config",
)
return None
def resolve_provider_full( def resolve_provider_full(
name: str, name: str,
user_providers: Optional[Dict[str, Any]] = None, user_providers: Optional[Dict[str, Any]] = None,
custom_providers: Optional[List[Dict[str, Any]]] = None,
) -> Optional[ProviderDef]: ) -> Optional[ProviderDef]:
"""Full resolution chain: built-in → models.dev → user config. """Full resolution chain: built-in → models.dev → user config.
@@ -463,6 +518,7 @@ def resolve_provider_full(
Args: Args:
name: Provider name or alias. name: Provider name or alias.
user_providers: The ``providers:`` dict from config.yaml (optional). user_providers: The ``providers:`` dict from config.yaml (optional).
custom_providers: The ``custom_providers:`` list from config.yaml (optional).
Returns: Returns:
ProviderDef if found, else None. ProviderDef if found, else None.
@@ -485,6 +541,11 @@ def resolve_provider_full(
if user_pdef is not None: if user_pdef is not None:
return user_pdef return user_pdef
# 2b. Saved custom providers from config
custom_pdef = resolve_custom_provider(name, custom_providers)
if custom_pdef is not None:
return custom_pdef
# 3. Try models.dev directly (for providers not in our ALIASES) # 3. Try models.dev directly (for providers not in our ALIASES)
try: try:
from agent.models_dev import get_provider_info as _mdev_provider from agent.models_dev import get_provider_info as _mdev_provider

View File

@@ -0,0 +1,63 @@
"""Regression tests for gateway /model support of config.yaml custom_providers."""
import yaml
import pytest
from gateway.config import Platform
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner
from gateway.session import SessionSource
def _make_runner():
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
runner._session_model_overrides = {}
return runner
def _make_event(text="/model"):
return MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"),
)
@pytest.mark.asyncio
async def test_handle_model_command_lists_saved_custom_provider(tmp_path, monkeypatch):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
(hermes_home / "config.yaml").write_text(
yaml.safe_dump(
{
"model": {
"default": "gpt-5.4",
"provider": "openai-codex",
"base_url": "https://chatgpt.com/backend-api/codex",
},
"providers": {},
"custom_providers": [
{
"name": "Local (127.0.0.1:4141)",
"base_url": "http://127.0.0.1:4141/v1",
"model": "rotator-openrouter-coding",
}
],
}
),
encoding="utf-8",
)
import gateway.run as gateway_run
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
result = await _make_runner()._handle_model_command(_make_event())
assert result is not None
assert "Local (127.0.0.1:4141)" in result
assert "custom:local-(127.0.0.1:4141)" in result
assert "rotator-openrouter-coding" in result

View File

@@ -0,0 +1,104 @@
"""Regression tests for /model support of config.yaml custom_providers.
The terminal `hermes model` flow already exposes `custom_providers`, but the
shared slash-command pipeline (`/model` in CLI/gateway/Telegram) historically
only looked at `providers:`.
"""
import hermes_cli.providers as providers_mod
from hermes_cli.model_switch import list_authenticated_providers, switch_model
from hermes_cli.providers import resolve_provider_full
_MOCK_VALIDATION = {
"accepted": True,
"persist": True,
"recognized": True,
"message": None,
}
def test_list_authenticated_providers_includes_custom_providers(monkeypatch):
"""No-args /model menus should include saved custom_providers entries."""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr(providers_mod, "HERMES_OVERLAYS", {})
providers = list_authenticated_providers(
current_provider="openai-codex",
user_providers={},
custom_providers=[
{
"name": "Local (127.0.0.1:4141)",
"base_url": "http://127.0.0.1:4141/v1",
"model": "rotator-openrouter-coding",
}
],
max_models=50,
)
assert any(
p["slug"] == "custom:local-(127.0.0.1:4141)"
and p["name"] == "Local (127.0.0.1:4141)"
and p["models"] == ["rotator-openrouter-coding"]
and p["api_url"] == "http://127.0.0.1:4141/v1"
for p in providers
)
def test_resolve_provider_full_finds_named_custom_provider():
"""Explicit /model --provider should resolve saved custom_providers entries."""
resolved = resolve_provider_full(
"custom:local-(127.0.0.1:4141)",
user_providers={},
custom_providers=[
{
"name": "Local (127.0.0.1:4141)",
"base_url": "http://127.0.0.1:4141/v1",
}
],
)
assert resolved is not None
assert resolved.id == "custom:local-(127.0.0.1:4141)"
assert resolved.name == "Local (127.0.0.1:4141)"
assert resolved.base_url == "http://127.0.0.1:4141/v1"
assert resolved.source == "user-config"
def test_switch_model_accepts_explicit_named_custom_provider(monkeypatch):
"""Shared /model switch pipeline should accept --provider for custom_providers."""
monkeypatch.setattr(
"hermes_cli.runtime_provider.resolve_runtime_provider",
lambda requested: {
"api_key": "no-key-required",
"base_url": "http://127.0.0.1:4141/v1",
"api_mode": "chat_completions",
},
)
monkeypatch.setattr("hermes_cli.models.validate_requested_model", lambda *a, **k: _MOCK_VALIDATION)
monkeypatch.setattr("hermes_cli.model_switch.get_model_info", lambda *a, **k: None)
monkeypatch.setattr("hermes_cli.model_switch.get_model_capabilities", lambda *a, **k: None)
result = switch_model(
raw_input="rotator-openrouter-coding",
current_provider="openai-codex",
current_model="gpt-5.4",
current_base_url="https://chatgpt.com/backend-api/codex",
current_api_key="",
explicit_provider="custom:local-(127.0.0.1:4141)",
user_providers={},
custom_providers=[
{
"name": "Local (127.0.0.1:4141)",
"base_url": "http://127.0.0.1:4141/v1",
"model": "rotator-openrouter-coding",
}
],
)
assert result.success is True
assert result.target_provider == "custom:local-(127.0.0.1:4141)"
assert result.provider_label == "Local (127.0.0.1:4141)"
assert result.new_model == "rotator-openrouter-coding"
assert result.base_url == "http://127.0.0.1:4141/v1"
assert result.api_key == "no-key-required"