mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix(tui): share static model detection
This commit is contained in:
@@ -1379,27 +1379,35 @@ def curated_models_for_provider(
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def detect_provider_for_model(
|
||||
def _provider_keys(provider: str) -> set[str]:
|
||||
key = (provider or "").strip().lower()
|
||||
normalized = normalize_provider(provider)
|
||||
return {k for k in (key, normalized) if k}
|
||||
|
||||
|
||||
def _model_in_provider_catalog(name_lower: str, providers: set[str]) -> bool:
|
||||
return any(
|
||||
name_lower == model.lower()
|
||||
for provider in providers
|
||||
for model in _PROVIDER_MODELS.get(provider, [])
|
||||
)
|
||||
|
||||
|
||||
def detect_static_provider_for_model(
|
||||
model_name: str,
|
||||
current_provider: str,
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Auto-detect the best provider for a model name.
|
||||
"""Auto-detect a provider from static catalogs only.
|
||||
|
||||
Returns ``(provider_id, model_name)`` — the model name may be remapped
|
||||
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
|
||||
Returns ``None`` when no confident match is found.
|
||||
|
||||
Priority:
|
||||
0. Bare provider name → switch to that provider's default model
|
||||
1. Direct provider with credentials (highest)
|
||||
2. Direct provider without credentials → remap to OpenRouter slug
|
||||
3. OpenRouter catalog match
|
||||
"""
|
||||
name = (model_name or "").strip()
|
||||
if not name:
|
||||
return None
|
||||
|
||||
name_lower = name.lower()
|
||||
current_keys = _provider_keys(current_provider)
|
||||
|
||||
# --- Step 0: bare provider name typed as model ---
|
||||
# If someone types `/model nous` or `/model anthropic`, treat it as a
|
||||
@@ -1412,7 +1420,7 @@ def detect_provider_for_model(
|
||||
if (
|
||||
resolved_provider in _PROVIDER_LABELS
|
||||
and default_models
|
||||
and resolved_provider != normalize_provider(current_provider)
|
||||
and resolved_provider not in current_keys
|
||||
):
|
||||
return (resolved_provider, default_models[0])
|
||||
|
||||
@@ -1420,56 +1428,43 @@ def detect_provider_for_model(
|
||||
_AGGREGATORS = {"nous", "openrouter", "ai-gateway", "copilot", "kilocode"}
|
||||
|
||||
# If the model belongs to the current provider's catalog, don't suggest switching
|
||||
current_models = _PROVIDER_MODELS.get(current_provider, [])
|
||||
if any(name_lower == m.lower() for m in current_models):
|
||||
if _model_in_provider_catalog(name_lower, current_keys):
|
||||
return None
|
||||
|
||||
# --- Step 1: check static provider catalogs for a direct match ---
|
||||
direct_match: Optional[str] = None
|
||||
for pid, models in _PROVIDER_MODELS.items():
|
||||
if pid == current_provider or pid in _AGGREGATORS:
|
||||
if pid in current_keys or pid in _AGGREGATORS:
|
||||
continue
|
||||
if any(name_lower == m.lower() for m in models):
|
||||
direct_match = pid
|
||||
break
|
||||
return (pid, name)
|
||||
|
||||
if direct_match:
|
||||
# Check if we have credentials for this provider — env vars,
|
||||
# credential pool, or auth store entries.
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
pconfig = PROVIDER_REGISTRY.get(direct_match)
|
||||
if pconfig:
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
has_creds = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
# Also check credential pool and auth store — covers OAuth,
|
||||
# Claude Code tokens, and other non-env-var credentials (#10300).
|
||||
if not has_creds:
|
||||
try:
|
||||
from agent.credential_pool import load_pool
|
||||
pool = load_pool(direct_match)
|
||||
if pool.has_credentials():
|
||||
has_creds = True
|
||||
except Exception:
|
||||
pass
|
||||
if not has_creds:
|
||||
try:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if direct_match in store.get("providers", {}) or direct_match in store.get("credential_pool", {}):
|
||||
has_creds = True
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Always return the direct provider match. If credentials are
|
||||
# missing, the client init will give a clear error rather than
|
||||
# silently routing through the wrong provider (#10300).
|
||||
return (direct_match, name)
|
||||
|
||||
def detect_provider_for_model(
|
||||
model_name: str,
|
||||
current_provider: str,
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Auto-detect the best provider for a model name.
|
||||
|
||||
Returns ``(provider_id, model_name)`` — the model name may be remapped
|
||||
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
|
||||
Returns ``None`` when no confident match is found.
|
||||
|
||||
Priority:
|
||||
0. Bare provider name → switch to that provider's default model
|
||||
1. Direct provider static catalog match
|
||||
2. OpenRouter catalog match
|
||||
"""
|
||||
name = (model_name or "").strip()
|
||||
if not name:
|
||||
return None
|
||||
|
||||
static_match = detect_static_provider_for_model(name, current_provider)
|
||||
if static_match:
|
||||
return static_match
|
||||
if _model_in_provider_catalog(name.lower(), _provider_keys(current_provider)):
|
||||
return None
|
||||
|
||||
# --- Step 2: check OpenRouter catalog ---
|
||||
# First try exact match (handles provider/model format)
|
||||
|
||||
@@ -113,7 +113,8 @@ def test_startup_runtime_does_not_treat_inference_provider_as_explicit(monkeypat
|
||||
monkeypatch.delenv("HERMES_TUI_PROVIDER", raising=False)
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous")
|
||||
monkeypatch.setattr(
|
||||
server, "_detect_static_provider_for_model", lambda model, provider: None
|
||||
"hermes_cli.models.detect_static_provider_for_model",
|
||||
lambda model, provider: None,
|
||||
)
|
||||
|
||||
assert server._resolve_startup_runtime() == ("nous/hermes-test", None)
|
||||
@@ -130,7 +131,9 @@ def test_startup_runtime_detects_provider_for_model_env(monkeypatch):
|
||||
assert current_provider == "auto"
|
||||
return "anthropic", "anthropic/claude-sonnet-4.6"
|
||||
|
||||
monkeypatch.setattr(server, "_detect_static_provider_for_model", fake_detect)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.detect_static_provider_for_model", fake_detect
|
||||
)
|
||||
|
||||
assert server._resolve_startup_runtime() == (
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
@@ -145,7 +148,9 @@ def test_startup_runtime_does_not_call_network_detector(monkeypatch):
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"provider": "auto"}})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.detect_provider_for_model",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("network detector called")),
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
AssertionError("network detector called")
|
||||
),
|
||||
)
|
||||
|
||||
model, provider = server._resolve_startup_runtime()
|
||||
|
||||
@@ -574,48 +574,6 @@ def _resolve_model() -> str:
|
||||
return "anthropic/claude-sonnet-4"
|
||||
|
||||
|
||||
def _detect_static_provider_for_model(model_name: str, current_provider: str) -> tuple[str, str] | None:
|
||||
"""Startup-safe provider detection: static catalogs only, no network fetches."""
|
||||
name = (model_name or "").strip()
|
||||
if not name:
|
||||
return None
|
||||
|
||||
try:
|
||||
from hermes_cli.models import (
|
||||
_PROVIDER_ALIASES,
|
||||
_PROVIDER_LABELS,
|
||||
_PROVIDER_MODELS,
|
||||
normalize_provider,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
name_lower = name.lower()
|
||||
normalized_current = normalize_provider(current_provider)
|
||||
resolved_provider = _PROVIDER_ALIASES.get(name_lower, name_lower)
|
||||
if resolved_provider not in {"custom", "openrouter"}:
|
||||
default_models = _PROVIDER_MODELS.get(resolved_provider, [])
|
||||
if (
|
||||
resolved_provider in _PROVIDER_LABELS
|
||||
and default_models
|
||||
and resolved_provider != normalized_current
|
||||
):
|
||||
return resolved_provider, default_models[0]
|
||||
|
||||
aggregators = {"nous", "openrouter", "ai-gateway", "copilot", "kilocode"}
|
||||
current_models = _PROVIDER_MODELS.get(normalized_current, [])
|
||||
if any(name_lower == m.lower() for m in current_models):
|
||||
return None
|
||||
|
||||
for provider, models in _PROVIDER_MODELS.items():
|
||||
if provider == normalized_current or provider in aggregators:
|
||||
continue
|
||||
if any(name_lower == m.lower() for m in models):
|
||||
return provider, name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_startup_runtime() -> tuple[str, str | None]:
|
||||
model = _resolve_model()
|
||||
explicit_provider = os.environ.get("HERMES_TUI_PROVIDER", "").strip()
|
||||
@@ -630,13 +588,19 @@ def _resolve_startup_runtime() -> tuple[str, str | None]:
|
||||
return model, None
|
||||
|
||||
try:
|
||||
from hermes_cli.models import detect_static_provider_for_model
|
||||
|
||||
cfg = _load_cfg().get("model") or {}
|
||||
current_provider = (
|
||||
str(cfg.get("provider") or "").strip().lower()
|
||||
if isinstance(cfg, dict)
|
||||
else ""
|
||||
) or os.environ.get("HERMES_INFERENCE_PROVIDER", "").strip().lower() or "auto"
|
||||
detected = _detect_static_provider_for_model(explicit_model, current_provider)
|
||||
(
|
||||
str(cfg.get("provider") or "").strip().lower()
|
||||
if isinstance(cfg, dict)
|
||||
else ""
|
||||
)
|
||||
or os.environ.get("HERMES_INFERENCE_PROVIDER", "").strip().lower()
|
||||
or "auto"
|
||||
)
|
||||
detected = detect_static_provider_for_model(explicit_model, current_provider)
|
||||
if detected:
|
||||
provider, detected_model = detected
|
||||
return detected_model, provider
|
||||
|
||||
Reference in New Issue
Block a user