diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 615325a135..860f74bb59 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1699,8 +1699,9 @@ def _remove_custom_provider(config): def _model_flow_named_custom(config, provider_info): """Handle a named custom provider from config.yaml custom_providers list. - If the entry has a saved model name, activates it immediately. - Otherwise probes the endpoint's /models API to let the user pick one. + Always probes the endpoint's /models API to let the user pick a model. + If a model was previously saved, it is pre-selected in the menu. + Falls back to the saved model if probing fails. """ from hermes_cli.auth import _save_model_choice, deactivate_provider from hermes_cli.config import load_config, save_config @@ -1711,40 +1712,29 @@ def _model_flow_named_custom(config, provider_info): api_key = provider_info.get("api_key", "") saved_model = provider_info.get("model", "") - # If a model is saved, just activate immediately — no probing needed - if saved_model: - _save_model_choice(saved_model) - - cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "custom" - model["base_url"] = base_url - if api_key: - model["api_key"] = api_key - save_config(cfg) - deactivate_provider() - - print(f"✅ Switched to: {saved_model}") - print(f" Provider: {name} ({base_url})") - return - - # No saved model — probe endpoint and let user pick print(f" Provider: {name}") print(f" URL: {base_url}") + if saved_model: + print(f" Current: {saved_model}") print() - print("No model saved for this provider. Fetching available models...") + + print("Fetching available models...") models = fetch_api_models(api_key, base_url, timeout=8.0) if models: + default_idx = 0 + if saved_model and saved_model in models: + default_idx = models.index(saved_model) + print(f"Found {len(models)} model(s):\n") try: from simple_term_menu import TerminalMenu - menu_items = [f" {m}" for m in models] + [" Cancel"] + menu_items = [ + f" {m} (current)" if m == saved_model else f" {m}" + for m in models + ] + [" Cancel"] menu = TerminalMenu( - menu_items, cursor_index=0, + menu_items, cursor_index=default_idx, menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"), menu_highlight_style=("fg_green",), cycle_cursor=True, clear_screen=False, @@ -1760,7 +1750,8 @@ def _model_flow_named_custom(config, provider_info): model_name = models[idx] except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError): for i, m in enumerate(models, 1): - print(f" {i}. {m}") + suffix = " (current)" if m == saved_model else "" + print(f" {i}. {m}{suffix}") print(f" {len(models) + 1}. Cancel") print() try: @@ -1776,6 +1767,13 @@ def _model_flow_named_custom(config, provider_info): except (ValueError, KeyboardInterrupt, EOFError): print("\nCancelled.") return + elif saved_model: + print("Could not fetch models from endpoint.") + try: + model_name = input(f"Model name [{saved_model}]: ").strip() or saved_model + except (KeyboardInterrupt, EOFError): + print("\nCancelled.") + return else: print("Could not fetch models from endpoint. Enter model name manually.") try: