From e48a497d166bfc32743d503fd00136c9140af5fe Mon Sep 17 00:00:00 2001 From: Brooklyn Nicholson Date: Sat, 25 Apr 2026 13:56:16 -0500 Subject: [PATCH] fix(tui): share static model detection --- hermes_cli/models.py | 99 +++++++++++++++----------------- tests/test_tui_gateway_server.py | 11 +++- tui_gateway/server.py | 58 ++++--------------- 3 files changed, 66 insertions(+), 102 deletions(-) diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 3a902ffdf5..494c3d2a9a 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -1379,27 +1379,35 @@ def curated_models_for_provider( return [(m, "") for m in models] -def detect_provider_for_model( +def _provider_keys(provider: str) -> set[str]: + key = (provider or "").strip().lower() + normalized = normalize_provider(provider) + return {k for k in (key, normalized) if k} + + +def _model_in_provider_catalog(name_lower: str, providers: set[str]) -> bool: + return any( + name_lower == model.lower() + for provider in providers + for model in _PROVIDER_MODELS.get(provider, []) + ) + + +def detect_static_provider_for_model( model_name: str, current_provider: str, ) -> Optional[tuple[str, str]]: - """Auto-detect the best provider for a model name. + """Auto-detect a provider from static catalogs only. Returns ``(provider_id, model_name)`` — the model name may be remapped - (e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter). Returns ``None`` when no confident match is found. - - Priority: - 0. Bare provider name → switch to that provider's default model - 1. Direct provider with credentials (highest) - 2. Direct provider without credentials → remap to OpenRouter slug - 3. OpenRouter catalog match """ name = (model_name or "").strip() if not name: return None name_lower = name.lower() + current_keys = _provider_keys(current_provider) # --- Step 0: bare provider name typed as model --- # If someone types `/model nous` or `/model anthropic`, treat it as a @@ -1412,7 +1420,7 @@ def detect_provider_for_model( if ( resolved_provider in _PROVIDER_LABELS and default_models - and resolved_provider != normalize_provider(current_provider) + and resolved_provider not in current_keys ): return (resolved_provider, default_models[0]) @@ -1420,56 +1428,43 @@ def detect_provider_for_model( _AGGREGATORS = {"nous", "openrouter", "ai-gateway", "copilot", "kilocode"} # If the model belongs to the current provider's catalog, don't suggest switching - current_models = _PROVIDER_MODELS.get(current_provider, []) - if any(name_lower == m.lower() for m in current_models): + if _model_in_provider_catalog(name_lower, current_keys): return None # --- Step 1: check static provider catalogs for a direct match --- - direct_match: Optional[str] = None for pid, models in _PROVIDER_MODELS.items(): - if pid == current_provider or pid in _AGGREGATORS: + if pid in current_keys or pid in _AGGREGATORS: continue if any(name_lower == m.lower() for m in models): - direct_match = pid - break + return (pid, name) - if direct_match: - # Check if we have credentials for this provider — env vars, - # credential pool, or auth store entries. - has_creds = False - try: - from hermes_cli.auth import PROVIDER_REGISTRY - pconfig = PROVIDER_REGISTRY.get(direct_match) - if pconfig: - for env_var in pconfig.api_key_env_vars: - if os.getenv(env_var, "").strip(): - has_creds = True - break - except Exception: - pass - # Also check credential pool and auth store — covers OAuth, - # Claude Code tokens, and other non-env-var credentials (#10300). - if not has_creds: - try: - from agent.credential_pool import load_pool - pool = load_pool(direct_match) - if pool.has_credentials(): - has_creds = True - except Exception: - pass - if not has_creds: - try: - from hermes_cli.auth import _load_auth_store - store = _load_auth_store() - if direct_match in store.get("providers", {}) or direct_match in store.get("credential_pool", {}): - has_creds = True - except Exception: - pass + return None - # Always return the direct provider match. If credentials are - # missing, the client init will give a clear error rather than - # silently routing through the wrong provider (#10300). - return (direct_match, name) + +def detect_provider_for_model( + model_name: str, + current_provider: str, +) -> Optional[tuple[str, str]]: + """Auto-detect the best provider for a model name. + + Returns ``(provider_id, model_name)`` — the model name may be remapped + (e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter). + Returns ``None`` when no confident match is found. + + Priority: + 0. Bare provider name → switch to that provider's default model + 1. Direct provider static catalog match + 2. OpenRouter catalog match + """ + name = (model_name or "").strip() + if not name: + return None + + static_match = detect_static_provider_for_model(name, current_provider) + if static_match: + return static_match + if _model_in_provider_catalog(name.lower(), _provider_keys(current_provider)): + return None # --- Step 2: check OpenRouter catalog --- # First try exact match (handles provider/model format) diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 1610c52939..8bb6f003bf 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -113,7 +113,8 @@ def test_startup_runtime_does_not_treat_inference_provider_as_explicit(monkeypat monkeypatch.delenv("HERMES_TUI_PROVIDER", raising=False) monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous") monkeypatch.setattr( - server, "_detect_static_provider_for_model", lambda model, provider: None + "hermes_cli.models.detect_static_provider_for_model", + lambda model, provider: None, ) assert server._resolve_startup_runtime() == ("nous/hermes-test", None) @@ -130,7 +131,9 @@ def test_startup_runtime_detects_provider_for_model_env(monkeypatch): assert current_provider == "auto" return "anthropic", "anthropic/claude-sonnet-4.6" - monkeypatch.setattr(server, "_detect_static_provider_for_model", fake_detect) + monkeypatch.setattr( + "hermes_cli.models.detect_static_provider_for_model", fake_detect + ) assert server._resolve_startup_runtime() == ( "anthropic/claude-sonnet-4.6", @@ -145,7 +148,9 @@ def test_startup_runtime_does_not_call_network_detector(monkeypatch): monkeypatch.setattr(server, "_load_cfg", lambda: {"model": {"provider": "auto"}}) monkeypatch.setattr( "hermes_cli.models.detect_provider_for_model", - lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("network detector called")), + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("network detector called") + ), ) model, provider = server._resolve_startup_runtime() diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 557fec19c2..7f981663a3 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -574,48 +574,6 @@ def _resolve_model() -> str: return "anthropic/claude-sonnet-4" -def _detect_static_provider_for_model(model_name: str, current_provider: str) -> tuple[str, str] | None: - """Startup-safe provider detection: static catalogs only, no network fetches.""" - name = (model_name or "").strip() - if not name: - return None - - try: - from hermes_cli.models import ( - _PROVIDER_ALIASES, - _PROVIDER_LABELS, - _PROVIDER_MODELS, - normalize_provider, - ) - except Exception: - return None - - name_lower = name.lower() - normalized_current = normalize_provider(current_provider) - resolved_provider = _PROVIDER_ALIASES.get(name_lower, name_lower) - if resolved_provider not in {"custom", "openrouter"}: - default_models = _PROVIDER_MODELS.get(resolved_provider, []) - if ( - resolved_provider in _PROVIDER_LABELS - and default_models - and resolved_provider != normalized_current - ): - return resolved_provider, default_models[0] - - aggregators = {"nous", "openrouter", "ai-gateway", "copilot", "kilocode"} - current_models = _PROVIDER_MODELS.get(normalized_current, []) - if any(name_lower == m.lower() for m in current_models): - return None - - for provider, models in _PROVIDER_MODELS.items(): - if provider == normalized_current or provider in aggregators: - continue - if any(name_lower == m.lower() for m in models): - return provider, name - - return None - - def _resolve_startup_runtime() -> tuple[str, str | None]: model = _resolve_model() explicit_provider = os.environ.get("HERMES_TUI_PROVIDER", "").strip() @@ -630,13 +588,19 @@ def _resolve_startup_runtime() -> tuple[str, str | None]: return model, None try: + from hermes_cli.models import detect_static_provider_for_model + cfg = _load_cfg().get("model") or {} current_provider = ( - str(cfg.get("provider") or "").strip().lower() - if isinstance(cfg, dict) - else "" - ) or os.environ.get("HERMES_INFERENCE_PROVIDER", "").strip().lower() or "auto" - detected = _detect_static_provider_for_model(explicit_model, current_provider) + ( + str(cfg.get("provider") or "").strip().lower() + if isinstance(cfg, dict) + else "" + ) + or os.environ.get("HERMES_INFERENCE_PROVIDER", "").strip().lower() + or "auto" + ) + detected = detect_static_provider_for_model(explicit_model, current_provider) if detected: provider, detected_model = detected return detected_model, provider