diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 586962d102e..1d77fffa92f 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -43,7 +43,7 @@ import yaml from hermes_cli.config import get_hermes_home, get_config_path, read_raw_config from hermes_constants import OPENROUTER_BASE_URL -from utils import atomic_replace, is_truthy_value +from utils import atomic_replace, atomic_yaml_write, is_truthy_value logger = logging.getLogger(__name__) @@ -3653,7 +3653,7 @@ def _update_config_for_provider( config["model"] = model_cfg - config_path.write_text(yaml.safe_dump(config, sort_keys=False)) + atomic_yaml_write(config_path, config, sort_keys=False) return config_path @@ -3712,7 +3712,7 @@ def _reset_config_provider() -> Path: model["provider"] = "auto" if "base_url" in model: model["base_url"] = OPENROUTER_BASE_URL - config_path.write_text(yaml.safe_dump(config, sort_keys=False)) + atomic_yaml_write(config_path, config, sort_keys=False) return config_path diff --git a/tests/hermes_cli/test_auth_commands.py b/tests/hermes_cli/test_auth_commands.py index 824d0608c07..50f639d08ac 100644 --- a/tests/hermes_cli/test_auth_commands.py +++ b/tests/hermes_cli/test_auth_commands.py @@ -5,8 +5,10 @@ from __future__ import annotations import base64 import json from datetime import datetime, timezone +from unittest.mock import patch import pytest +import yaml def _write_auth_store(tmp_path, payload: dict) -> None: @@ -589,6 +591,39 @@ def test_logout_clears_stale_active_codex_without_provider_credentials(tmp_path, assert "provider: auto" in config_text +def test_reset_config_provider_uses_atomic_yaml_write(tmp_path, monkeypatch): + """Logout config reset should delegate the YAML write atomically.""" + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + config_path = hermes_home / "config.yaml" + original = { + "model": { + "default": "gpt-5.3-codex", + "provider": "openai-codex", + "base_url": "https://chatgpt.com/backend-api/codex", + } + } + config_path.write_text(yaml.safe_dump(original, sort_keys=False), encoding="utf-8") + original_text = config_path.read_text(encoding="utf-8") + + from hermes_cli.auth import _reset_config_provider + + def _boom(path, data, **kwargs): + assert path == config_path + assert data["model"]["provider"] == "auto" + assert data["model"]["base_url"] == "https://openrouter.ai/api/v1" + assert kwargs["sort_keys"] is False + raise OSError("simulated atomic write failure") + + with patch("hermes_cli.auth.atomic_yaml_write", side_effect=_boom) as mock_write: + with pytest.raises(OSError, match="simulated atomic write failure"): + _reset_config_provider() + + assert mock_write.call_count == 1 + assert config_path.read_text(encoding="utf-8") == original_text + + def test_auth_list_does_not_call_mutating_select(monkeypatch, capsys): from hermes_cli.auth_commands import auth_list_command diff --git a/tests/hermes_cli/test_model_provider_persistence.py b/tests/hermes_cli/test_model_provider_persistence.py index 2a827ca7ef2..8808e009b4a 100644 --- a/tests/hermes_cli/test_model_provider_persistence.py +++ b/tests/hermes_cli/test_model_provider_persistence.py @@ -71,6 +71,32 @@ class TestSaveModelChoiceAlwaysDict: class TestProviderPersistsAfterModelSave: + def test_update_config_for_provider_uses_atomic_yaml_write(self, config_home): + """Provider switches should delegate config writes to atomic_yaml_write.""" + from hermes_cli.auth import _update_config_for_provider + + config_path = config_home / "config.yaml" + original_text = config_path.read_text(encoding="utf-8") + + def _boom(path, data, **kwargs): + assert path == config_path + assert data["model"]["provider"] == "nous" + assert data["model"]["base_url"] == "https://inference.example.com/v1" + assert data["model"]["default"] == "some-old-model" + assert kwargs["sort_keys"] is False + raise OSError("simulated atomic write failure") + + with patch("hermes_cli.auth.atomic_yaml_write", side_effect=_boom) as mock_write: + with pytest.raises(OSError, match="simulated atomic write failure"): + _update_config_for_provider( + "nous", + "https://inference.example.com/v1/", + default_model="llama-3.3", + ) + + assert mock_write.call_count == 1 + assert config_path.read_text(encoding="utf-8") == original_text + def test_api_key_provider_saved_when_model_was_string(self, config_home, monkeypatch): """_model_flow_api_key_provider must persist the provider even when config.model started as a plain string."""