Compare commits

...

1 Commits

Author SHA1 Message Date
Teknium
1fb4a46ad4 fix(tui): preserve custom provider identity on resume 2026-06-13 05:10:27 -07:00
4 changed files with 288 additions and 0 deletions

View File

@@ -660,6 +660,51 @@ def has_named_custom_provider(requested_provider: str) -> bool:
return False
def find_custom_provider_identity(base_url: str) -> Optional[str]:
"""Map an endpoint URL back to its canonical ``custom:<name>`` menu key.
Runtime agent state stores named custom endpoints as the resolved provider
``"custom"``. The session DB deliberately does not persist raw API keys,
so Desktop/TUI resume needs the original configured entry name to re-resolve
credentials from ``providers:`` / ``custom_providers:``.
"""
target = _normalize_base_url_for_match(base_url)
if not target:
return None
try:
config = load_config()
except Exception:
return None
providers = config.get("providers")
if isinstance(providers, dict):
for ep_name, entry in providers.items():
if not isinstance(entry, dict):
continue
entry_url = entry.get("api") or entry.get("url") or entry.get("base_url") or ""
if _normalize_base_url_for_match(entry_url) == target:
return f"custom:{_normalize_custom_provider_name(str(ep_name))}"
try:
custom_providers = get_compatible_custom_providers(config)
except Exception:
custom_providers = None
for entry in custom_providers or []:
if not isinstance(entry, dict):
continue
name = entry.get("name")
if not isinstance(name, str) or not name.strip():
continue
if _normalize_base_url_for_match(entry.get("base_url")) == target:
return f"custom:{_normalize_custom_provider_name(name)}"
return None
def _normalize_base_url_for_match(value) -> str:
return str(value or "").strip().rstrip("/").lower()
def _custom_provider_request_overrides(custom_provider: Dict[str, Any]) -> Dict[str, Any]:
extra_body = custom_provider.get("extra_body")
if not isinstance(extra_body, dict) or not extra_body:

View File

@@ -0,0 +1,60 @@
"""Custom provider entry identity lookup tests."""
import hermes_cli.runtime_provider as rp
def test_find_custom_provider_identity_matches_legacy_custom_providers_list(monkeypatch):
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"custom_providers": [
{"name": "MiMo v2.5 Pro", "base_url": "https://api.mimo.example/v1"}
]
},
)
assert (
rp.find_custom_provider_identity("https://api.mimo.example/v1/")
== "custom:mimo-v2.5-pro"
)
def test_find_custom_provider_identity_matches_providers_dict_key(monkeypatch):
monkeypatch.setattr(
rp,
"load_config",
lambda: {"providers": {"local gpu": {"api": "http://127.0.0.1:8000/v1"}}},
)
assert rp.find_custom_provider_identity("HTTP://127.0.0.1:8000/v1") == "custom:local-gpu"
def test_find_custom_provider_identity_matches_providers_dict_url_alias(monkeypatch):
monkeypatch.setattr(
rp,
"load_config",
lambda: {"providers": {"proxy": {"url": "https://proxy.example/anthropic"}}},
)
assert rp.find_custom_provider_identity("https://proxy.example/anthropic") == "custom:proxy"
def test_find_custom_provider_identity_returns_none_for_unknown_url(monkeypatch):
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"custom_providers": [
{"name": "known", "base_url": "https://known.example/v1"}
]
},
)
assert rp.find_custom_provider_identity("https://other.example/v1") is None
def test_find_custom_provider_identity_ignores_bad_config(monkeypatch):
monkeypatch.setattr(rp, "load_config", lambda: (_ for _ in ()).throw(RuntimeError("boom")))
assert rp.find_custom_provider_identity("https://known.example/v1") is None

View File

@@ -0,0 +1,159 @@
"""Session persistence must not strip a named custom provider's identity."""
import json
import types
from unittest.mock import MagicMock, patch
import hermes_cli.runtime_provider as rp
MIMO_URL = "https://token-plan-cn.xiaomimimo.com/v1"
MIMO_KEY = "sk-mimo-entry-key"
LEGACY_LIST_CONFIG = {
"custom_providers": [
{
"name": "mimo-v2.5-pro",
"base_url": MIMO_URL,
"api_key": MIMO_KEY,
"api_mode": "chat_completions",
}
]
}
PROVIDERS_DICT_CONFIG = {
"providers": {
"mimo-v2.5-pro": {
"api": MIMO_URL,
"api_key": MIMO_KEY,
}
}
}
def _custom_agent(base_url=MIMO_URL):
return types.SimpleNamespace(
model="mimo-v2.5-pro",
provider="custom",
base_url=base_url,
api_mode="chat_completions",
reasoning_config=None,
service_tier=None,
)
def _make_agent_with_override(override, monkeypatch, config):
"""Run _make_agent through real resolve_runtime_provider with patched config."""
monkeypatch.setattr(rp, "load_config", lambda: config)
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
fake_cfg = {"agent": {"system_prompt": ""}, "model": {"default": "unused"}}
with (
patch("tui_gateway.server._load_cfg", return_value=fake_cfg),
patch("tui_gateway.server._get_db", return_value=MagicMock()),
patch("tui_gateway.server._load_reasoning_config", return_value=None),
patch("tui_gateway.server._load_service_tier", return_value=None),
patch("tui_gateway.server._load_enabled_toolsets", return_value=None),
patch("run_agent.AIAgent") as mock_agent,
):
from tui_gateway.server import _make_agent
_make_agent("sid-custom", "key-custom", model_override=override)
return mock_agent.call_args.kwargs
def test_runtime_model_config_persists_menu_key_instead_of_resolved_custom(monkeypatch):
monkeypatch.setattr(rp, "load_config", lambda: LEGACY_LIST_CONFIG)
from tui_gateway.server import _runtime_model_config
config = _runtime_model_config(_custom_agent())
assert config["provider"] == "custom:mimo-v2.5-pro"
assert config["base_url"] == MIMO_URL
assert "api_key" not in config
def test_runtime_model_config_persists_menu_key_for_providers_dict_entry(monkeypatch):
monkeypatch.setattr(rp, "load_config", lambda: PROVIDERS_DICT_CONFIG)
from tui_gateway.server import _runtime_model_config
assert _runtime_model_config(_custom_agent())["provider"] == "custom:mimo-v2.5-pro"
def test_runtime_model_config_keeps_bare_custom_when_no_entry_matches(monkeypatch):
monkeypatch.setattr(rp, "load_config", lambda: {})
from tui_gateway.server import _runtime_model_config
assert _runtime_model_config(_custom_agent())["provider"] == "custom"
def test_runtime_model_config_leaves_non_custom_provider_untouched(monkeypatch):
def _boom():
raise AssertionError("identity lookup must not run for built-ins")
monkeypatch.setattr(rp, "load_config", _boom)
from tui_gateway.server import _runtime_model_config
agent = _custom_agent()
agent.provider = "anthropic"
agent.base_url = "https://api.anthropic.com"
assert _runtime_model_config(agent)["provider"] == "anthropic"
def test_persisted_session_round_trip_restores_entry_credentials(monkeypatch):
monkeypatch.setattr(rp, "load_config", lambda: LEGACY_LIST_CONFIG)
from tui_gateway.server import _runtime_model_config, _stored_session_runtime_overrides
model_config = _runtime_model_config(_custom_agent())
row = {
"model": "mimo-v2.5-pro",
"model_config": json.dumps(model_config),
}
overrides = _stored_session_runtime_overrides(row)
assert overrides["model_override"]["provider"] == "custom:mimo-v2.5-pro"
kwargs = _make_agent_with_override(overrides["model_override"], monkeypatch, LEGACY_LIST_CONFIG)
assert kwargs["provider"] == "custom"
assert kwargs["base_url"] == MIMO_URL
assert kwargs["api_key"] == MIMO_KEY
def test_legacy_row_with_bare_custom_heals_via_base_url(monkeypatch):
override = {
"model": "mimo-v2.5-pro",
"provider": "custom",
"base_url": MIMO_URL,
"api_mode": "chat_completions",
}
kwargs = _make_agent_with_override(override, monkeypatch, LEGACY_LIST_CONFIG)
assert kwargs["provider"] == "custom"
assert kwargs["base_url"] == MIMO_URL
assert kwargs["api_key"] == MIMO_KEY
def test_legacy_row_without_matching_entry_keeps_endpoint(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
override = {
"model": "local-model",
"provider": "custom",
"base_url": "http://127.0.0.1:8000/v1",
"api_mode": "chat_completions",
}
kwargs = _make_agent_with_override(override, monkeypatch, {})
assert kwargs["provider"] == "custom"
assert kwargs["base_url"] == "http://127.0.0.1:8000/v1"
assert kwargs["api_key"] == "no-key-required"

View File

@@ -1548,6 +1548,17 @@ def _runtime_model_config(agent, existing: dict | None = None) -> dict:
if model:
config["model"] = model
if provider:
if provider == "custom" and base_url:
# ``agent.provider`` is the resolved provider. For named custom
# endpoints that is just "custom", so persisting it loses the
# configured entry identity needed to resolve key_env/api_key on a
# future Desktop/TUI resume.
try:
from hermes_cli.runtime_provider import find_custom_provider_identity
provider = find_custom_provider_identity(base_url) or provider
except Exception:
logger.debug("custom provider identity lookup failed", exc_info=True)
config["provider"] = provider
if base_url:
config["base_url"] = base_url
@@ -3299,9 +3310,22 @@ def _make_agent(
override_base_url = model_override.get("base_url")
override_api_key = model_override.get("api_key")
override_api_mode = model_override.get("api_mode")
resolve_kwargs = {}
if override_base_url and str(requested_provider or "").strip().lower() == "custom":
# Heal rows persisted before _runtime_model_config kept the named
# custom provider identity. If no configured entry owns the URL,
# still pass it to the direct-custom path so local endpoints remain
# resumable instead of falling through to the global provider.
from hermes_cli.runtime_provider import find_custom_provider_identity
recovered = find_custom_provider_identity(override_base_url)
if recovered:
requested_provider = recovered
resolve_kwargs["explicit_base_url"] = override_base_url
runtime = resolve_runtime_provider(
requested=requested_provider,
target_model=model or None,
**resolve_kwargs,
)
# The switch already resolved concrete credentials/endpoint; honor them
# so a custom/named endpoint survives the rebuild even if global