mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 23:41:35 +08:00
Compare commits
8 Commits
fix/docker
...
fix/status
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2aff451ed | ||
|
|
4e91b0240b | ||
|
|
5e92a4ce5a | ||
|
|
471c663fdf | ||
|
|
64d333204b | ||
|
|
c44af43840 | ||
|
|
b117bbc125 | ||
|
|
b59da08730 |
61
cli.py
61
cli.py
@@ -3484,6 +3484,56 @@ class HermesCLI:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Error generating insights: {e}")
|
print(f" Error generating insights: {e}")
|
||||||
|
|
||||||
|
def _check_config_mcp_changes(self) -> None:
|
||||||
|
"""Detect mcp_servers changes in config.yaml and auto-reload MCP connections.
|
||||||
|
|
||||||
|
Called from process_loop every CONFIG_WATCH_INTERVAL seconds.
|
||||||
|
Compares config.yaml mtime + mcp_servers section against the last
|
||||||
|
known state. When a change is detected, triggers _reload_mcp() and
|
||||||
|
informs the user so they know the tool list has been refreshed.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import yaml as _yaml
|
||||||
|
|
||||||
|
CONFIG_WATCH_INTERVAL = 5.0 # seconds between config.yaml stat() calls
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._last_config_check < CONFIG_WATCH_INTERVAL:
|
||||||
|
return
|
||||||
|
self._last_config_check = now
|
||||||
|
|
||||||
|
from hermes_cli.config import get_config_path as _get_config_path
|
||||||
|
cfg_path = _get_config_path()
|
||||||
|
if not cfg_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
mtime = cfg_path.stat().st_mtime
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
|
||||||
|
if mtime == self._config_mtime:
|
||||||
|
return # File unchanged — fast path
|
||||||
|
|
||||||
|
# File changed — check whether mcp_servers section changed
|
||||||
|
self._config_mtime = mtime
|
||||||
|
try:
|
||||||
|
with open(cfg_path, encoding="utf-8") as f:
|
||||||
|
new_cfg = _yaml.safe_load(f) or {}
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_mcp = new_cfg.get("mcp_servers") or {}
|
||||||
|
if new_mcp == self._config_mcp_servers:
|
||||||
|
return # mcp_servers unchanged (some other section was edited)
|
||||||
|
|
||||||
|
self._config_mcp_servers = new_mcp
|
||||||
|
# Notify user and reload
|
||||||
|
print()
|
||||||
|
print("🔄 MCP server config changed — reloading connections...")
|
||||||
|
with self._busy_command(self._slow_command_status("/reload-mcp")):
|
||||||
|
self._reload_mcp()
|
||||||
|
|
||||||
def _reload_mcp(self):
|
def _reload_mcp(self):
|
||||||
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
|
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
|
||||||
|
|
||||||
@@ -4749,6 +4799,12 @@ class HermesCLI:
|
|||||||
self._interrupt_queue = queue.Queue() # For messages typed while agent is running
|
self._interrupt_queue = queue.Queue() # For messages typed while agent is running
|
||||||
self._should_exit = False
|
self._should_exit = False
|
||||||
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
||||||
|
# Config file watcher — detect mcp_servers changes and auto-reload
|
||||||
|
from hermes_cli.config import get_config_path as _get_config_path
|
||||||
|
_cfg_path = _get_config_path()
|
||||||
|
self._config_mtime: float = _cfg_path.stat().st_mtime if _cfg_path.exists() else 0.0
|
||||||
|
self._config_mcp_servers: dict = self.config.get("mcp_servers") or {}
|
||||||
|
self._last_config_check: float = 0.0 # monotonic time of last check
|
||||||
|
|
||||||
# Clarify tool state: interactive question/answer with the user.
|
# Clarify tool state: interactive question/answer with the user.
|
||||||
# When the agent calls the clarify tool, _clarify_state is set and
|
# When the agent calls the clarify tool, _clarify_state is set and
|
||||||
@@ -4797,7 +4853,7 @@ class HermesCLI:
|
|||||||
# Ensure tirith security scanner is available (downloads if needed)
|
# Ensure tirith security scanner is available (downloads if needed)
|
||||||
try:
|
try:
|
||||||
from tools.tirith_security import ensure_installed
|
from tools.tirith_security import ensure_installed
|
||||||
ensure_installed()
|
ensure_installed(log_failures=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Non-fatal — fail-open at scan time if unavailable
|
pass # Non-fatal — fail-open at scan time if unavailable
|
||||||
|
|
||||||
@@ -5682,6 +5738,9 @@ class HermesCLI:
|
|||||||
try:
|
try:
|
||||||
user_input = self._pending_input.get(timeout=0.1)
|
user_input = self._pending_input.get(timeout=0.1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
|
# Periodic config watcher — auto-reload MCP on mcp_servers change
|
||||||
|
if not self._agent_running:
|
||||||
|
self._check_config_mcp_changes()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not user_input:
|
if not user_input:
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ class GatewayRunner:
|
|||||||
# Ensure tirith security scanner is available (downloads if needed)
|
# Ensure tirith security scanner is available (downloads if needed)
|
||||||
try:
|
try:
|
||||||
from tools.tirith_security import ensure_installed
|
from tools.tirith_security import ensure_installed
|
||||||
ensure_installed()
|
ensure_installed(log_failures=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Non-fatal — fail-open at scan time if unavailable
|
pass # Non-fatal — fail-open at scan time if unavailable
|
||||||
|
|
||||||
@@ -1114,6 +1114,9 @@ class GatewayRunner:
|
|||||||
# let the adapter-level batching/queueing logic absorb them.
|
# let the adapter-level batching/queueing logic absorb them.
|
||||||
_quick_key = build_session_key(source)
|
_quick_key = build_session_key(source)
|
||||||
if _quick_key in self._running_agents:
|
if _quick_key in self._running_agents:
|
||||||
|
if event.get_command() == "status":
|
||||||
|
return await self._handle_status_command(event)
|
||||||
|
|
||||||
if event.message_type == MessageType.PHOTO:
|
if event.message_type == MessageType.PHOTO:
|
||||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
@@ -1822,6 +1825,8 @@ class GatewayRunner:
|
|||||||
# Update session with actual prompt token count and model from the agent
|
# Update session with actual prompt token count and model from the agent
|
||||||
self.session_store.update_session(
|
self.session_store.update_session(
|
||||||
session_entry.session_key,
|
session_entry.session_key,
|
||||||
|
input_tokens=agent_result.get("input_tokens", 0),
|
||||||
|
output_tokens=agent_result.get("output_tokens", 0),
|
||||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||||
model=agent_result.get("model"),
|
model=agent_result.get("model"),
|
||||||
)
|
)
|
||||||
@@ -4171,11 +4176,15 @@ class GatewayRunner:
|
|||||||
# Return final response, or a message if something went wrong
|
# Return final response, or a message if something went wrong
|
||||||
final_response = result.get("final_response")
|
final_response = result.get("final_response")
|
||||||
|
|
||||||
# Extract last actual prompt token count from the agent's compressor
|
# Extract actual token counts from the agent instance used for this run
|
||||||
_last_prompt_toks = 0
|
_last_prompt_toks = 0
|
||||||
|
_input_toks = 0
|
||||||
|
_output_toks = 0
|
||||||
_agent = agent_holder[0]
|
_agent = agent_holder[0]
|
||||||
if _agent and hasattr(_agent, "context_compressor"):
|
if _agent and hasattr(_agent, "context_compressor"):
|
||||||
_last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
|
_last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
|
||||||
|
_input_toks = getattr(_agent, "session_prompt_tokens", 0)
|
||||||
|
_output_toks = getattr(_agent, "session_completion_tokens", 0)
|
||||||
_resolved_model = getattr(_agent, "model", None) if _agent else None
|
_resolved_model = getattr(_agent, "model", None) if _agent else None
|
||||||
|
|
||||||
if not final_response:
|
if not final_response:
|
||||||
@@ -4187,6 +4196,8 @@ class GatewayRunner:
|
|||||||
"tools": tools_holder[0] or [],
|
"tools": tools_holder[0] or [],
|
||||||
"history_offset": len(agent_history),
|
"history_offset": len(agent_history),
|
||||||
"last_prompt_tokens": _last_prompt_toks,
|
"last_prompt_tokens": _last_prompt_toks,
|
||||||
|
"input_tokens": _input_toks,
|
||||||
|
"output_tokens": _output_toks,
|
||||||
"model": _resolved_model,
|
"model": _resolved_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4250,6 +4261,8 @@ class GatewayRunner:
|
|||||||
"tools": tools_holder[0] or [],
|
"tools": tools_holder[0] or [],
|
||||||
"history_offset": len(agent_history),
|
"history_offset": len(agent_history),
|
||||||
"last_prompt_tokens": _last_prompt_toks,
|
"last_prompt_tokens": _last_prompt_toks,
|
||||||
|
"input_tokens": _input_toks,
|
||||||
|
"output_tokens": _output_toks,
|
||||||
"model": _resolved_model,
|
"model": _resolved_model,
|
||||||
"session_id": effective_session_id,
|
"session_id": effective_session_id,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -927,6 +927,11 @@ class HonchoSessionManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||||
|
honcho_session = self._sessions_cache.get(session.honcho_session_id)
|
||||||
|
if not honcho_session:
|
||||||
|
logger.warning("No Honcho session cached for '%s', skipping AI seed", session_key)
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wrapped = (
|
wrapped = (
|
||||||
f"<ai_identity_seed>\n"
|
f"<ai_identity_seed>\n"
|
||||||
@@ -935,7 +940,7 @@ class HonchoSessionManager:
|
|||||||
f"{content.strip()}\n"
|
f"{content.strip()}\n"
|
||||||
f"</ai_identity_seed>"
|
f"</ai_identity_seed>"
|
||||||
)
|
)
|
||||||
assistant_peer.add_message("assistant", wrapped)
|
honcho_session.add_messages([assistant_peer.message(wrapped)])
|
||||||
logger.info("Seeded AI identity from '%s' into %s", source, session_key)
|
logger.info("Seeded AI identity from '%s' into %s", source, session_key)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
133
tests/gateway/test_status_command.py
Normal file
133
tests/gateway/test_status_command.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for gateway /status behavior and token persistence."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||||
|
from gateway.platforms.base import MessageEvent
|
||||||
|
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||||
|
|
||||||
|
|
||||||
|
def _make_source() -> SessionSource:
|
||||||
|
return SessionSource(
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
user_id="u1",
|
||||||
|
chat_id="c1",
|
||||||
|
user_name="tester",
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_event(text: str) -> MessageEvent:
|
||||||
|
return MessageEvent(
|
||||||
|
text=text,
|
||||||
|
source=_make_source(),
|
||||||
|
message_id="m1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(session_entry: SessionEntry):
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner.config = GatewayConfig(
|
||||||
|
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||||
|
)
|
||||||
|
adapter = MagicMock()
|
||||||
|
adapter.send = AsyncMock()
|
||||||
|
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||||
|
runner._voice_mode = {}
|
||||||
|
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||||
|
runner.session_store = MagicMock()
|
||||||
|
runner.session_store.get_or_create_session.return_value = session_entry
|
||||||
|
runner.session_store.load_transcript.return_value = []
|
||||||
|
runner.session_store.has_any_sessions.return_value = True
|
||||||
|
runner.session_store.append_to_transcript = MagicMock()
|
||||||
|
runner.session_store.rewrite_transcript = MagicMock()
|
||||||
|
runner.session_store.update_session = MagicMock()
|
||||||
|
runner._running_agents = {}
|
||||||
|
runner._pending_messages = {}
|
||||||
|
runner._pending_approvals = {}
|
||||||
|
runner._session_db = None
|
||||||
|
runner._reasoning_config = None
|
||||||
|
runner._provider_routing = {}
|
||||||
|
runner._fallback_model = None
|
||||||
|
runner._show_reasoning = False
|
||||||
|
runner._is_user_authorized = lambda _source: True
|
||||||
|
runner._set_session_env = lambda _context: None
|
||||||
|
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
|
||||||
|
runner._send_voice_reply = AsyncMock()
|
||||||
|
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
|
||||||
|
runner._emit_gateway_run_progress = AsyncMock()
|
||||||
|
return runner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
|
||||||
|
session_entry = SessionEntry(
|
||||||
|
session_key=build_session_key(_make_source()),
|
||||||
|
session_id="sess-1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
total_tokens=321,
|
||||||
|
)
|
||||||
|
runner = _make_runner(session_entry)
|
||||||
|
running_agent = MagicMock()
|
||||||
|
runner._running_agents[build_session_key(_make_source())] = running_agent
|
||||||
|
|
||||||
|
result = await runner._handle_message(_make_event("/status"))
|
||||||
|
|
||||||
|
assert "**Tokens:** 321" in result
|
||||||
|
assert "**Agent Running:** Yes ⚡" in result
|
||||||
|
running_agent.interrupt.assert_not_called()
|
||||||
|
assert runner._pending_messages == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||||
|
import gateway.run as gateway_run
|
||||||
|
|
||||||
|
session_entry = SessionEntry(
|
||||||
|
session_key=build_session_key(_make_source()),
|
||||||
|
session_id="sess-1",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
runner = _make_runner(session_entry)
|
||||||
|
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||||
|
runner._run_agent = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"final_response": "ok",
|
||||||
|
"messages": [],
|
||||||
|
"tools": [],
|
||||||
|
"history_offset": 0,
|
||||||
|
"last_prompt_tokens": 80,
|
||||||
|
"input_tokens": 120,
|
||||||
|
"output_tokens": 45,
|
||||||
|
"model": "openai/test-model",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"agent.model_metadata.get_model_context_length",
|
||||||
|
lambda *_args, **_kwargs: 100000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await runner._handle_message(_make_event("hello"))
|
||||||
|
|
||||||
|
assert result == "ok"
|
||||||
|
runner.session_store.update_session.assert_called_once_with(
|
||||||
|
session_entry.session_key,
|
||||||
|
input_tokens=120,
|
||||||
|
output_tokens=45,
|
||||||
|
last_prompt_tokens=80,
|
||||||
|
model="openai/test-model",
|
||||||
|
)
|
||||||
@@ -68,6 +68,22 @@ class TestAtomicJsonWrite:
|
|||||||
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||||
assert len(tmp_files) == 0
|
assert len(tmp_files) == 0
|
||||||
|
|
||||||
|
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
|
||||||
|
class SimulatedAbort(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
target = tmp_path / "data.json"
|
||||||
|
original = {"preserved": True}
|
||||||
|
target.write_text(json.dumps(original), encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("utils.json.dump", side_effect=SimulatedAbort):
|
||||||
|
with pytest.raises(SimulatedAbort):
|
||||||
|
atomic_json_write(target, {"new": True})
|
||||||
|
|
||||||
|
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||||
|
assert len(tmp_files) == 0
|
||||||
|
assert json.loads(target.read_text(encoding="utf-8")) == original
|
||||||
|
|
||||||
def test_accepts_string_path(self, tmp_path):
|
def test_accepts_string_path(self, tmp_path):
|
||||||
target = str(tmp_path / "string_path.json")
|
target = str(tmp_path / "string_path.json")
|
||||||
atomic_json_write(target, {"string": True})
|
atomic_json_write(target, {"string": True})
|
||||||
|
|||||||
44
tests/test_atomic_yaml_write.py
Normal file
44
tests/test_atomic_yaml_write.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Tests for utils.atomic_yaml_write — crash-safe YAML file writes."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from utils import atomic_yaml_write
|
||||||
|
|
||||||
|
|
||||||
|
class TestAtomicYamlWrite:
|
||||||
|
def test_writes_valid_yaml(self, tmp_path):
|
||||||
|
target = tmp_path / "data.yaml"
|
||||||
|
data = {"key": "value", "nested": {"a": 1}}
|
||||||
|
|
||||||
|
atomic_yaml_write(target, data)
|
||||||
|
|
||||||
|
assert yaml.safe_load(target.read_text(encoding="utf-8")) == data
|
||||||
|
|
||||||
|
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
|
||||||
|
class SimulatedAbort(BaseException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
target = tmp_path / "data.yaml"
|
||||||
|
original = {"preserved": True}
|
||||||
|
target.write_text(yaml.safe_dump(original), encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("utils.yaml.dump", side_effect=SimulatedAbort):
|
||||||
|
with pytest.raises(SimulatedAbort):
|
||||||
|
atomic_yaml_write(target, {"new": True})
|
||||||
|
|
||||||
|
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||||
|
assert len(tmp_files) == 0
|
||||||
|
assert yaml.safe_load(target.read_text(encoding="utf-8")) == original
|
||||||
|
|
||||||
|
def test_appends_extra_content(self, tmp_path):
|
||||||
|
target = tmp_path / "data.yaml"
|
||||||
|
|
||||||
|
atomic_yaml_write(target, {"key": "value"}, extra_content="\n# comment\n")
|
||||||
|
|
||||||
|
text = target.read_text(encoding="utf-8")
|
||||||
|
assert "key: value" in text
|
||||||
|
assert "# comment" in text
|
||||||
103
tests/test_cli_mcp_config_watch.py
Normal file
103
tests/test_cli_mcp_config_watch.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Tests for automatic MCP reload when config.yaml mcp_servers section changes."""
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cli(tmp_path, mcp_servers=None):
|
||||||
|
"""Create a minimal HermesCLI instance with mocked config."""
|
||||||
|
import cli as cli_mod
|
||||||
|
obj = object.__new__(cli_mod.HermesCLI)
|
||||||
|
obj.config = {"mcp_servers": mcp_servers or {}}
|
||||||
|
obj._agent_running = False
|
||||||
|
obj._last_config_check = 0.0
|
||||||
|
obj._config_mcp_servers = mcp_servers or {}
|
||||||
|
|
||||||
|
cfg_file = tmp_path / "config.yaml"
|
||||||
|
cfg_file.write_text("mcp_servers: {}\n")
|
||||||
|
obj._config_mtime = cfg_file.stat().st_mtime
|
||||||
|
|
||||||
|
obj._reload_mcp = MagicMock()
|
||||||
|
obj._busy_command = MagicMock()
|
||||||
|
obj._busy_command.return_value.__enter__ = MagicMock(return_value=None)
|
||||||
|
obj._busy_command.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
obj._slow_command_status = MagicMock(return_value="reloading...")
|
||||||
|
|
||||||
|
return obj, cfg_file
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPConfigWatch:
|
||||||
|
|
||||||
|
def test_no_change_does_not_reload(self, tmp_path):
|
||||||
|
"""If mtime and mcp_servers unchanged, _reload_mcp is NOT called."""
|
||||||
|
obj, cfg_file = _make_cli(tmp_path)
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||||
|
obj._check_config_mcp_changes()
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_not_called()
|
||||||
|
|
||||||
|
def test_mtime_change_with_same_mcp_servers_does_not_reload(self, tmp_path):
|
||||||
|
"""If file mtime changes but mcp_servers is identical, no reload."""
|
||||||
|
import yaml
|
||||||
|
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"fs": {"command": "npx"}})
|
||||||
|
|
||||||
|
# Write same mcp_servers but touch the file
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": {"fs": {"command": "npx"}}}))
|
||||||
|
# Force mtime to appear changed
|
||||||
|
obj._config_mtime = 0.0
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||||
|
obj._check_config_mcp_changes()
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_not_called()
|
||||||
|
|
||||||
|
def test_new_mcp_server_triggers_reload(self, tmp_path):
|
||||||
|
"""Adding a new MCP server to config triggers auto-reload."""
|
||||||
|
import yaml
|
||||||
|
obj, cfg_file = _make_cli(tmp_path, mcp_servers={})
|
||||||
|
|
||||||
|
# Simulate user adding a new MCP server to config.yaml
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": {"github": {"url": "https://mcp.github.com"}}}))
|
||||||
|
obj._config_mtime = 0.0 # force stale mtime
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||||
|
obj._check_config_mcp_changes()
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_called_once()
|
||||||
|
|
||||||
|
def test_removed_mcp_server_triggers_reload(self, tmp_path):
|
||||||
|
"""Removing an MCP server from config triggers auto-reload."""
|
||||||
|
import yaml
|
||||||
|
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"github": {"url": "https://mcp.github.com"}})
|
||||||
|
|
||||||
|
# Simulate user removing the server
|
||||||
|
cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
|
||||||
|
obj._config_mtime = 0.0
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||||
|
obj._check_config_mcp_changes()
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_called_once()
|
||||||
|
|
||||||
|
def test_interval_throttle_skips_check(self, tmp_path):
|
||||||
|
"""If called within CONFIG_WATCH_INTERVAL, stat() is skipped."""
|
||||||
|
obj, cfg_file = _make_cli(tmp_path)
|
||||||
|
obj._last_config_check = time.monotonic() # just checked
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
|
||||||
|
patch.object(Path, "stat") as mock_stat:
|
||||||
|
obj._check_config_mcp_changes()
|
||||||
|
mock_stat.assert_not_called()
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_not_called()
|
||||||
|
|
||||||
|
def test_missing_config_file_does_not_crash(self, tmp_path):
|
||||||
|
"""If config.yaml doesn't exist, _check_config_mcp_changes is a no-op."""
|
||||||
|
obj, cfg_file = _make_cli(tmp_path)
|
||||||
|
missing = tmp_path / "nonexistent.yaml"
|
||||||
|
|
||||||
|
with patch("hermes_cli.config.get_config_path", return_value=missing):
|
||||||
|
obj._check_config_mcp_changes() # should not raise
|
||||||
|
|
||||||
|
obj._reload_mcp.assert_not_called()
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -143,6 +145,12 @@ class TestTakeCheckpoint:
|
|||||||
result = mgr.ensure_checkpoint(str(work_dir), "initial")
|
result = mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
def test_successful_checkpoint_does_not_log_expected_diff_exit(self, mgr, work_dir, caplog):
|
||||||
|
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||||
|
result = mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||||
|
assert result is True
|
||||||
|
assert not any("diff --cached --quiet" in r.getMessage() for r in caplog.records)
|
||||||
|
|
||||||
def test_dedup_same_turn(self, mgr, work_dir):
|
def test_dedup_same_turn(self, mgr, work_dir):
|
||||||
r1 = mgr.ensure_checkpoint(str(work_dir), "first")
|
r1 = mgr.ensure_checkpoint(str(work_dir), "first")
|
||||||
r2 = mgr.ensure_checkpoint(str(work_dir), "second")
|
r2 = mgr.ensure_checkpoint(str(work_dir), "second")
|
||||||
@@ -375,6 +383,26 @@ class TestErrorResilience:
|
|||||||
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
|
||||||
|
completed = subprocess.CompletedProcess(
|
||||||
|
args=["git", "diff", "--cached", "--quiet"],
|
||||||
|
returncode=1,
|
||||||
|
stdout="",
|
||||||
|
stderr="",
|
||||||
|
)
|
||||||
|
with patch("tools.checkpoint_manager.subprocess.run", return_value=completed):
|
||||||
|
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||||
|
ok, stdout, stderr = _run_git(
|
||||||
|
["diff", "--cached", "--quiet"],
|
||||||
|
tmp_path / "shadow",
|
||||||
|
str(tmp_path / "work"),
|
||||||
|
allowed_returncodes={1},
|
||||||
|
)
|
||||||
|
assert ok is False
|
||||||
|
assert stdout == ""
|
||||||
|
assert stderr == ""
|
||||||
|
assert not caplog.records
|
||||||
|
|
||||||
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
||||||
"""Checkpoint failures should never raise — they're silently logged."""
|
"""Checkpoint failures should never raise — they're silently logged."""
|
||||||
def broken_run_git(*args, **kwargs):
|
def broken_run_git(*args, **kwargs):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ handling without requiring a running terminal environment.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from tools.file_tools import (
|
from tools.file_tools import (
|
||||||
@@ -87,13 +88,26 @@ class TestWriteFileHandler:
|
|||||||
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
|
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
|
||||||
|
|
||||||
@patch("tools.file_tools._get_file_ops")
|
@patch("tools.file_tools._get_file_ops")
|
||||||
def test_exception_returns_error_json(self, mock_get):
|
def test_permission_error_returns_error_json_without_error_log(self, mock_get, caplog):
|
||||||
mock_get.side_effect = PermissionError("read-only filesystem")
|
mock_get.side_effect = PermissionError("read-only filesystem")
|
||||||
|
|
||||||
from tools.file_tools import write_file_tool
|
from tools.file_tools import write_file_tool
|
||||||
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
with caplog.at_level(logging.DEBUG, logger="tools.file_tools"):
|
||||||
|
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||||
assert "error" in result
|
assert "error" in result
|
||||||
assert "read-only" in result["error"]
|
assert "read-only" in result["error"]
|
||||||
|
assert any("write_file expected denial" in r.getMessage() for r in caplog.records)
|
||||||
|
assert not any(r.levelno >= logging.ERROR for r in caplog.records)
|
||||||
|
|
||||||
|
@patch("tools.file_tools._get_file_ops")
|
||||||
|
def test_unexpected_exception_still_logs_error(self, mock_get, caplog):
|
||||||
|
mock_get.side_effect = RuntimeError("boom")
|
||||||
|
|
||||||
|
from tools.file_tools import write_file_tool
|
||||||
|
with caplog.at_level(logging.ERROR, logger="tools.file_tools"):
|
||||||
|
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||||
|
assert result["error"] == "boom"
|
||||||
|
assert any("write_file error" in r.getMessage() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
class TestPatchHandler:
|
class TestPatchHandler:
|
||||||
|
|||||||
@@ -315,6 +315,23 @@ class TestEnsureInstalled:
|
|||||||
mock_thread.start.assert_called_once()
|
mock_thread.start.assert_called_once()
|
||||||
_tirith_mod._resolved_path = None
|
_tirith_mod._resolved_path = None
|
||||||
|
|
||||||
|
@patch("tools.tirith_security._load_security_config")
|
||||||
|
def test_startup_prefetch_can_suppress_install_failure_logs(self, mock_cfg):
|
||||||
|
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
|
||||||
|
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||||
|
_tirith_mod._resolved_path = None
|
||||||
|
with patch("tools.tirith_security.shutil.which", return_value=None), \
|
||||||
|
patch("tools.tirith_security._hermes_bin_dir", return_value="/nonexistent"), \
|
||||||
|
patch("tools.tirith_security._is_install_failed_on_disk", return_value=False), \
|
||||||
|
patch("tools.tirith_security.threading.Thread") as MockThread:
|
||||||
|
mock_thread = MagicMock()
|
||||||
|
MockThread.return_value = mock_thread
|
||||||
|
result = ensure_installed(log_failures=False)
|
||||||
|
assert result is None
|
||||||
|
assert MockThread.call_args.kwargs["kwargs"] == {"log_failures": False}
|
||||||
|
mock_thread.start.assert_called_once()
|
||||||
|
_tirith_mod._resolved_path = None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Failed download caches the miss (Finding #1)
|
# Failed download caches the miss (Finding #1)
|
||||||
@@ -516,6 +533,22 @@ class TestCosignVerification:
|
|||||||
assert path is None
|
assert path is None
|
||||||
assert reason == "cosign_missing"
|
assert reason == "cosign_missing"
|
||||||
|
|
||||||
|
@patch("tools.tirith_security.logger.debug")
|
||||||
|
@patch("tools.tirith_security.logger.warning")
|
||||||
|
@patch("tools.tirith_security.shutil.which", return_value=None)
|
||||||
|
@patch("tools.tirith_security._download_file")
|
||||||
|
@patch("tools.tirith_security._detect_target", return_value="aarch64-apple-darwin")
|
||||||
|
def test_install_quiet_mode_downgrades_cosign_missing_log(self, mock_target, mock_dl,
|
||||||
|
mock_which, mock_warning,
|
||||||
|
mock_debug):
|
||||||
|
"""Startup prefetch should not surface cosign-missing as a warning."""
|
||||||
|
from tools.tirith_security import _install_tirith
|
||||||
|
path, reason = _install_tirith(log_failures=False)
|
||||||
|
assert path is None
|
||||||
|
assert reason == "cosign_missing"
|
||||||
|
mock_warning.assert_not_called()
|
||||||
|
mock_debug.assert_called()
|
||||||
|
|
||||||
@patch("tools.tirith_security._verify_cosign", return_value=None)
|
@patch("tools.tirith_security._verify_cosign", return_value=None)
|
||||||
@patch("tools.tirith_security.shutil.which", return_value="/usr/local/bin/cosign")
|
@patch("tools.tirith_security.shutil.which", return_value="/usr/local/bin/cosign")
|
||||||
@patch("tools.tirith_security._download_file")
|
@patch("tools.tirith_security._download_file")
|
||||||
|
|||||||
@@ -92,10 +92,17 @@ def _run_git(
|
|||||||
shadow_repo: Path,
|
shadow_repo: Path,
|
||||||
working_dir: str,
|
working_dir: str,
|
||||||
timeout: int = _GIT_TIMEOUT,
|
timeout: int = _GIT_TIMEOUT,
|
||||||
|
allowed_returncodes: Optional[Set[int]] = None,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr)."""
|
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr).
|
||||||
|
|
||||||
|
``allowed_returncodes`` suppresses error logging for known/expected non-zero
|
||||||
|
exits while preserving the normal ``ok = (returncode == 0)`` contract.
|
||||||
|
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
|
||||||
|
"""
|
||||||
env = _git_env(shadow_repo, working_dir)
|
env = _git_env(shadow_repo, working_dir)
|
||||||
cmd = ["git"] + list(args)
|
cmd = ["git"] + list(args)
|
||||||
|
allowed_returncodes = allowed_returncodes or set()
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
cmd,
|
cmd,
|
||||||
@@ -108,7 +115,7 @@ def _run_git(
|
|||||||
ok = result.returncode == 0
|
ok = result.returncode == 0
|
||||||
stdout = result.stdout.strip()
|
stdout = result.stdout.strip()
|
||||||
stderr = result.stderr.strip()
|
stderr = result.stderr.strip()
|
||||||
if not ok:
|
if not ok and result.returncode not in allowed_returncodes:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Git command failed: %s (rc=%d) stderr=%s",
|
"Git command failed: %s (rc=%d) stderr=%s",
|
||||||
" ".join(cmd), result.returncode, stderr,
|
" ".join(cmd), result.returncode, stderr,
|
||||||
@@ -381,7 +388,10 @@ class CheckpointManager:
|
|||||||
|
|
||||||
# Check if there's anything to commit
|
# Check if there's anything to commit
|
||||||
ok_diff, diff_out, _ = _run_git(
|
ok_diff, diff_out, _ = _run_git(
|
||||||
["diff", "--cached", "--quiet"], shadow, working_dir,
|
["diff", "--cached", "--quiet"],
|
||||||
|
shadow,
|
||||||
|
working_dir,
|
||||||
|
allowed_returncodes={1},
|
||||||
)
|
)
|
||||||
if ok_diff:
|
if ok_diff:
|
||||||
# No changes to commit
|
# No changes to commit
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""File Tools Module - LLM agent file manipulation tools."""
|
"""File Tools Module - LLM agent file manipulation tools."""
|
||||||
|
|
||||||
|
import errno
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -11,6 +12,18 @@ from agent.redact import redact_sensitive_text
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_expected_write_exception(exc: Exception) -> bool:
|
||||||
|
"""Return True for expected write denials that should not hit error logs."""
|
||||||
|
if isinstance(exc, PermissionError):
|
||||||
|
return True
|
||||||
|
if isinstance(exc, OSError) and exc.errno in _EXPECTED_WRITE_ERRNOS:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
_file_ops_lock = threading.Lock()
|
_file_ops_lock = threading.Lock()
|
||||||
_file_ops_cache: dict = {}
|
_file_ops_cache: dict = {}
|
||||||
|
|
||||||
@@ -238,7 +251,10 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
|||||||
result = file_ops.write_file(path, content)
|
result = file_ops.write_file(path, content)
|
||||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("write_file error: %s: %s", type(e).__name__, e)
|
if _is_expected_write_exception(e):
|
||||||
|
logger.debug("write_file expected denial: %s: %s", type(e).__name__, e)
|
||||||
|
else:
|
||||||
|
logger.error("write_file error: %s: %s", type(e).__name__, e, exc_info=True)
|
||||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ def _verify_checksum(archive_path: str, checksums_path: str, archive_name: str)
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _install_tirith() -> tuple[str | None, str]:
|
def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]:
|
||||||
"""Download and install tirith to $HERMES_HOME/bin/tirith.
|
"""Download and install tirith to $HERMES_HOME/bin/tirith.
|
||||||
|
|
||||||
Verifies provenance via cosign and SHA-256 checksum.
|
Verifies provenance via cosign and SHA-256 checksum.
|
||||||
@@ -287,6 +287,8 @@ def _install_tirith() -> tuple[str | None, str]:
|
|||||||
failure_reason is a short tag used by the disk marker to decide if the
|
failure_reason is a short tag used by the disk marker to decide if the
|
||||||
failure is retryable (e.g. "cosign_missing" clears when cosign appears).
|
failure is retryable (e.g. "cosign_missing" clears when cosign appears).
|
||||||
"""
|
"""
|
||||||
|
log = logger.warning if log_failures else logger.debug
|
||||||
|
|
||||||
target = _detect_target()
|
target = _detect_target()
|
||||||
if not target:
|
if not target:
|
||||||
logger.info("tirith auto-install: unsupported platform %s/%s",
|
logger.info("tirith auto-install: unsupported platform %s/%s",
|
||||||
@@ -309,7 +311,7 @@ def _install_tirith() -> tuple[str | None, str]:
|
|||||||
_download_file(f"{base_url}/{archive_name}", archive_path)
|
_download_file(f"{base_url}/{archive_name}", archive_path)
|
||||||
_download_file(f"{base_url}/checksums.txt", checksums_path)
|
_download_file(f"{base_url}/checksums.txt", checksums_path)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("tirith download failed: %s", exc)
|
log("tirith download failed: %s", exc)
|
||||||
return None, "download_failed"
|
return None, "download_failed"
|
||||||
|
|
||||||
# Cosign provenance verification is mandatory for auto-install.
|
# Cosign provenance verification is mandatory for auto-install.
|
||||||
@@ -320,25 +322,25 @@ def _install_tirith() -> tuple[str | None, str]:
|
|||||||
_download_file(f"{base_url}/checksums.txt.sig", sig_path)
|
_download_file(f"{base_url}/checksums.txt.sig", sig_path)
|
||||||
_download_file(f"{base_url}/checksums.txt.pem", cert_path)
|
_download_file(f"{base_url}/checksums.txt.pem", cert_path)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("tirith install skipped: cosign artifacts unavailable (%s). "
|
log("tirith install skipped: cosign artifacts unavailable (%s). "
|
||||||
"Install tirith manually or install cosign for auto-install.", exc)
|
"Install tirith manually or install cosign for auto-install.", exc)
|
||||||
return None, "cosign_artifacts_unavailable"
|
return None, "cosign_artifacts_unavailable"
|
||||||
|
|
||||||
# Check cosign availability before attempting verification so we can
|
# Check cosign availability before attempting verification so we can
|
||||||
# distinguish "not installed" (retryable) from "installed but broken."
|
# distinguish "not installed" (retryable) from "installed but broken."
|
||||||
if not shutil.which("cosign"):
|
if not shutil.which("cosign"):
|
||||||
logger.warning("tirith install skipped: cosign not found on PATH. "
|
log("tirith install skipped: cosign not found on PATH. "
|
||||||
"Install cosign for auto-install, or install tirith manually.")
|
"Install cosign for auto-install, or install tirith manually.")
|
||||||
return None, "cosign_missing"
|
return None, "cosign_missing"
|
||||||
|
|
||||||
cosign_result = _verify_cosign(checksums_path, sig_path, cert_path)
|
cosign_result = _verify_cosign(checksums_path, sig_path, cert_path)
|
||||||
if cosign_result is not True:
|
if cosign_result is not True:
|
||||||
# False = verification rejected, None = execution failure (timeout/OSError)
|
# False = verification rejected, None = execution failure (timeout/OSError)
|
||||||
if cosign_result is None:
|
if cosign_result is None:
|
||||||
logger.warning("tirith install aborted: cosign execution failed")
|
log("tirith install aborted: cosign execution failed")
|
||||||
return None, "cosign_exec_failed"
|
return None, "cosign_exec_failed"
|
||||||
else:
|
else:
|
||||||
logger.warning("tirith install aborted: cosign provenance verification failed")
|
log("tirith install aborted: cosign provenance verification failed")
|
||||||
return None, "cosign_verification_failed"
|
return None, "cosign_verification_failed"
|
||||||
|
|
||||||
if not _verify_checksum(archive_path, checksums_path, archive_name):
|
if not _verify_checksum(archive_path, checksums_path, archive_name):
|
||||||
@@ -354,7 +356,7 @@ def _install_tirith() -> tuple[str | None, str]:
|
|||||||
tar.extract(member, tmpdir)
|
tar.extract(member, tmpdir)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("tirith binary not found in archive")
|
log("tirith binary not found in archive")
|
||||||
return None, "binary_not_in_archive"
|
return None, "binary_not_in_archive"
|
||||||
|
|
||||||
src = os.path.join(tmpdir, "tirith")
|
src = os.path.join(tmpdir, "tirith")
|
||||||
@@ -473,7 +475,7 @@ def _resolve_tirith_path(configured_path: str) -> str:
|
|||||||
return expanded
|
return expanded
|
||||||
|
|
||||||
|
|
||||||
def _background_install():
|
def _background_install(*, log_failures: bool = True):
|
||||||
"""Background thread target: download and install tirith."""
|
"""Background thread target: download and install tirith."""
|
||||||
global _resolved_path, _install_failure_reason
|
global _resolved_path, _install_failure_reason
|
||||||
with _install_lock:
|
with _install_lock:
|
||||||
@@ -494,7 +496,7 @@ def _background_install():
|
|||||||
_install_failure_reason = ""
|
_install_failure_reason = ""
|
||||||
return
|
return
|
||||||
|
|
||||||
installed, reason = _install_tirith()
|
installed, reason = _install_tirith(log_failures=log_failures)
|
||||||
if installed:
|
if installed:
|
||||||
_resolved_path = installed
|
_resolved_path = installed
|
||||||
_install_failure_reason = ""
|
_install_failure_reason = ""
|
||||||
@@ -505,7 +507,7 @@ def _background_install():
|
|||||||
_mark_install_failed(reason)
|
_mark_install_failed(reason)
|
||||||
|
|
||||||
|
|
||||||
def ensure_installed():
|
def ensure_installed(*, log_failures: bool = True):
|
||||||
"""Ensure tirith is available, downloading in background if needed.
|
"""Ensure tirith is available, downloading in background if needed.
|
||||||
|
|
||||||
Quick PATH/local checks are synchronous; network download runs in a
|
Quick PATH/local checks are synchronous; network download runs in a
|
||||||
@@ -578,7 +580,10 @@ def ensure_installed():
|
|||||||
# Need to download — launch background thread so startup doesn't block
|
# Need to download — launch background thread so startup doesn't block
|
||||||
if _install_thread is None or not _install_thread.is_alive():
|
if _install_thread is None or not _install_thread.is_alive():
|
||||||
_install_thread = threading.Thread(
|
_install_thread = threading.Thread(
|
||||||
target=_background_install, daemon=True)
|
target=_background_install,
|
||||||
|
kwargs={"log_failures": log_failures},
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
_install_thread.start()
|
_install_thread.start()
|
||||||
|
|
||||||
return None # Not available yet; commands will fail-open until ready
|
return None # Not available yet; commands will fail-open until ready
|
||||||
|
|||||||
4
utils.py
4
utils.py
@@ -50,6 +50,8 @@ def atomic_json_write(
|
|||||||
os.fsync(f.fileno())
|
os.fsync(f.fileno())
|
||||||
os.replace(tmp_path, path)
|
os.replace(tmp_path, path)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
|
# Intentionally catch BaseException so temp-file cleanup still runs for
|
||||||
|
# KeyboardInterrupt/SystemExit before re-raising the original signal.
|
||||||
try:
|
try:
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
except OSError:
|
except OSError:
|
||||||
@@ -96,6 +98,8 @@ def atomic_yaml_write(
|
|||||||
os.fsync(f.fileno())
|
os.fsync(f.fileno())
|
||||||
os.replace(tmp_path, path)
|
os.replace(tmp_path, path)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
|
# Match atomic_json_write: cleanup must also happen for process-level
|
||||||
|
# interruptions before we re-raise them.
|
||||||
try:
|
try:
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|||||||
Reference in New Issue
Block a user