mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
merge: resolve conflicts with main (show_cost, turn routing, docker docs)
This commit is contained in:
@@ -295,3 +295,97 @@ class TestOnConnect:
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
agent.on_connect(mock_conn)
|
||||
assert agent._conn is mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slash commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlashCommands:
|
||||
"""Test slash command dispatch in the ACP adapter."""
|
||||
|
||||
def _make_state(self, mock_manager):
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.model = "test-model"
|
||||
state.agent.provider = "openrouter"
|
||||
state.model = "test-model"
|
||||
return state
|
||||
|
||||
def test_help_lists_commands(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/help", state)
|
||||
assert result is not None
|
||||
assert "/help" in result
|
||||
assert "/model" in result
|
||||
assert "/tools" in result
|
||||
assert "/reset" in result
|
||||
|
||||
def test_model_shows_current(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/model", state)
|
||||
assert "test-model" in result
|
||||
|
||||
def test_context_empty(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = []
|
||||
result = agent._handle_slash_command("/context", state)
|
||||
assert "empty" in result.lower()
|
||||
|
||||
def test_context_with_messages(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
result = agent._handle_slash_command("/context", state)
|
||||
assert "2 messages" in result
|
||||
assert "user: 1" in result
|
||||
|
||||
def test_reset_clears_history(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = [{"role": "user", "content": "hello"}]
|
||||
result = agent._handle_slash_command("/reset", state)
|
||||
assert "cleared" in result.lower()
|
||||
assert len(state.history) == 0
|
||||
|
||||
def test_version(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/version", state)
|
||||
assert HERMES_VERSION in result
|
||||
|
||||
def test_unknown_command_returns_none(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/nonexistent", state)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_command_intercepted_in_prompt(self, agent, mock_manager):
|
||||
"""Slash commands should be handled without calling the LLM."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
mock_conn = AsyncMock(spec=acp.Client)
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="/help")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "end_turn"
|
||||
mock_conn.session_update.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_slash_falls_through_to_llm(self, agent, mock_manager):
|
||||
"""Unknown /commands should be sent to the LLM, not intercepted."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
mock_conn = AsyncMock(spec=acp.Client)
|
||||
agent._conn = mock_conn
|
||||
|
||||
# Mock run_in_executor to avoid actually running the agent
|
||||
with patch("asyncio.get_running_loop") as mock_loop:
|
||||
mock_loop.return_value.run_in_executor = AsyncMock(return_value={
|
||||
"final_response": "I processed /foo",
|
||||
"messages": [],
|
||||
})
|
||||
prompt = [TextContentBlock(type="text", text="/foo bar")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
61
tests/agent/test_smart_model_routing.py
Normal file
61
tests/agent/test_smart_model_routing.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from agent.smart_model_routing import choose_cheap_model_route
|
||||
|
||||
|
||||
_BASE_CONFIG = {
|
||||
"enabled": True,
|
||||
"cheap_model": {
|
||||
"provider": "openrouter",
|
||||
"model": "google/gemini-2.5-flash",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_returns_none_when_disabled():
|
||||
cfg = {**_BASE_CONFIG, "enabled": False}
|
||||
assert choose_cheap_model_route("what time is it in tokyo?", cfg) is None
|
||||
|
||||
|
||||
def test_routes_short_simple_prompt():
|
||||
result = choose_cheap_model_route("what time is it in tokyo?", _BASE_CONFIG)
|
||||
assert result is not None
|
||||
assert result["provider"] == "openrouter"
|
||||
assert result["model"] == "google/gemini-2.5-flash"
|
||||
assert result["routing_reason"] == "simple_turn"
|
||||
|
||||
|
||||
def test_skips_long_prompt():
|
||||
prompt = "please summarize this carefully " * 20
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_skips_code_like_prompt():
|
||||
prompt = "debug this traceback: ```python\nraise ValueError('bad')\n```"
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_skips_tool_heavy_prompt_keywords():
|
||||
prompt = "implement a patch for this docker error"
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_resolve_turn_route_falls_back_to_primary_when_route_runtime_cannot_be_resolved(monkeypatch):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda **kwargs: (_ for _ in ()).throw(RuntimeError("bad route")),
|
||||
)
|
||||
result = resolve_turn_route(
|
||||
"what time is it in tokyo?",
|
||||
_BASE_CONFIG,
|
||||
{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"provider": "openrouter",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_mode": "chat_completions",
|
||||
"api_key": "sk-primary",
|
||||
},
|
||||
)
|
||||
assert result["model"] == "anthropic/claude-sonnet-4"
|
||||
assert result["runtime"]["provider"] == "openrouter"
|
||||
assert result["label"] is None
|
||||
@@ -26,6 +26,12 @@ def _isolate_hermes_home(tmp_path, monkeypatch):
|
||||
(fake_home / "memories").mkdir()
|
||||
(fake_home / "skills").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_home))
|
||||
# Reset plugin singleton so tests don't leak plugins from ~/.hermes/plugins/
|
||||
try:
|
||||
import hermes_cli.plugins as _plugins_mod
|
||||
monkeypatch.setattr(_plugins_mod, "_plugin_manager", None)
|
||||
except Exception:
|
||||
pass
|
||||
# Tests should not inherit the agent's current gateway/messaging surface.
|
||||
# Individual tests that need gateway behavior set these explicitly.
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_enrichment_uses_athabasca_upload_guidance_without_stale_r2_warning():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool",
|
||||
return_value='{"success": true, "analysis": "A painted serpent warrior."}',
|
||||
):
|
||||
enriched = await runner._enrich_message_with_vision(
|
||||
"caption",
|
||||
["/tmp/test.jpg"],
|
||||
)
|
||||
|
||||
assert "R2 not configured" not in enriched
|
||||
assert "Gateway media URL available for reference" not in enriched
|
||||
assert "POST /api/uploads" in enriched
|
||||
assert "Do not store the local cache path" in enriched
|
||||
assert "caption" in enriched
|
||||
156
tests/gateway/test_pii_redaction.py
Normal file
156
tests/gateway/test_pii_redaction.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for PII redaction in gateway session context prompts."""
|
||||
|
||||
from gateway.session import (
|
||||
SessionContext,
|
||||
SessionSource,
|
||||
build_session_context_prompt,
|
||||
_hash_id,
|
||||
_hash_sender_id,
|
||||
_hash_chat_id,
|
||||
_looks_like_phone,
|
||||
)
|
||||
from gateway.config import Platform, HomeChannel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHashHelpers:
|
||||
def test_hash_id_deterministic(self):
|
||||
assert _hash_id("12345") == _hash_id("12345")
|
||||
|
||||
def test_hash_id_12_hex_chars(self):
|
||||
h = _hash_id("user-abc")
|
||||
assert len(h) == 12
|
||||
assert all(c in "0123456789abcdef" for c in h)
|
||||
|
||||
def test_hash_sender_id_prefix(self):
|
||||
assert _hash_sender_id("12345").startswith("user_")
|
||||
assert len(_hash_sender_id("12345")) == 17 # "user_" + 12
|
||||
|
||||
def test_hash_chat_id_preserves_prefix(self):
|
||||
result = _hash_chat_id("telegram:12345")
|
||||
assert result.startswith("telegram:")
|
||||
assert "12345" not in result
|
||||
|
||||
def test_hash_chat_id_no_prefix(self):
|
||||
result = _hash_chat_id("12345")
|
||||
assert len(result) == 12
|
||||
assert "12345" not in result
|
||||
|
||||
def test_looks_like_phone(self):
|
||||
assert _looks_like_phone("+15551234567")
|
||||
assert _looks_like_phone("15551234567")
|
||||
assert _looks_like_phone("+1-555-123-4567")
|
||||
assert not _looks_like_phone("alice")
|
||||
assert not _looks_like_phone("user-123")
|
||||
assert not _looks_like_phone("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: build_session_context_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_context(
|
||||
user_id="user-123",
|
||||
user_name=None,
|
||||
chat_id="telegram:99999",
|
||||
platform=Platform.TELEGRAM,
|
||||
home_channels=None,
|
||||
):
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
)
|
||||
return SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[platform],
|
||||
home_channels=home_channels or {},
|
||||
)
|
||||
|
||||
|
||||
class TestBuildSessionContextPromptRedaction:
|
||||
def test_no_redaction_by_default(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
assert "user-123" in prompt
|
||||
|
||||
def test_user_id_hashed_when_redact_pii(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "user-123" not in prompt
|
||||
assert "user_" in prompt # hashed ID present
|
||||
|
||||
def test_user_name_not_redacted(self):
|
||||
ctx = _make_context(user_id="user-123", user_name="Alice")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "Alice" in prompt
|
||||
# user_id should not appear when user_name is present (name takes priority)
|
||||
assert "user-123" not in prompt
|
||||
|
||||
def test_home_channel_id_hashed(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "99999" not in prompt
|
||||
assert "telegram:" in prompt # prefix preserved
|
||||
assert "Home Chat" in prompt # name not redacted
|
||||
|
||||
def test_home_channel_id_preserved_without_redaction(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=False)
|
||||
assert "99999" in prompt
|
||||
|
||||
def test_redaction_is_deterministic(self):
|
||||
ctx = _make_context(user_id="+15551234567")
|
||||
prompt1 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
prompt2 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert prompt1 == prompt2
|
||||
|
||||
def test_different_ids_produce_different_hashes(self):
|
||||
ctx1 = _make_context(user_id="user-A")
|
||||
ctx2 = _make_context(user_id="user-B")
|
||||
p1 = build_session_context_prompt(ctx1, redact_pii=True)
|
||||
p2 = build_session_context_prompt(ctx2, redact_pii=True)
|
||||
assert p1 != p2
|
||||
|
||||
def test_discord_ids_not_redacted_even_with_flag(self):
|
||||
"""Discord needs real IDs for <@user_id> mentions."""
|
||||
ctx = _make_context(user_id="123456789", platform=Platform.DISCORD)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "123456789" in prompt
|
||||
|
||||
def test_whatsapp_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.WHATSAPP)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_signal_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.SIGNAL)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_slack_ids_not_redacted(self):
|
||||
"""Slack may need IDs for mentions too."""
|
||||
ctx = _make_context(user_id="U12345ABC", platform=Platform.SLACK)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "U12345ABC" in prompt
|
||||
89
tests/gateway/test_runner_startup_failures.py
Normal file
89
tests/gateway/test_runner_startup_failures.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.status import read_runtime_status
|
||||
|
||||
|
||||
class _RetryableFailureAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._set_fatal_error(
|
||||
"telegram_connect_error",
|
||||
"Telegram startup failed: temporary DNS resolution failure.",
|
||||
retryable=True,
|
||||
)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _DisabledAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=False, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
raise AssertionError("connect should not be called for disabled platforms")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _RetryableFailureAdapter())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is False
|
||||
assert runner.should_exit_cleanly is False
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is False
|
||||
assert runner.adapters == {}
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
@@ -100,6 +100,39 @@ async def test_polling_conflict_stops_polling_and_notifies_handler(monkeypatch):
|
||||
fatal_handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
app = SimpleNamespace(
|
||||
bot=SimpleNamespace(),
|
||||
updater=SimpleNamespace(),
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(side_effect=RuntimeError("Temporary failure in name resolution")),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_connect_error"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
assert "Temporary failure in name resolution" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_inactive_updater_and_app(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
@@ -12,7 +12,8 @@ EXPECTED_COMMANDS = {
|
||||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit",
|
||||
"/reload-mcp", "/rollback", "/stop", "/background", "/skin", "/voice", "/browser", "/quit",
|
||||
"/plugins",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
|
||||
def fake_run(cmd, capture_output=False, text=False, check=False):
|
||||
if cmd[:4] == ["systemctl", "--user", "status", gateway.SERVICE_NAME]:
|
||||
if cmd[:4] == ["systemctl", "--user", "status", gateway.get_service_name()]:
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
if cmd[:3] == ["systemctl", "--user", "is-active"]:
|
||||
return SimpleNamespace(returncode=0, stdout="active\n", stderr="")
|
||||
@@ -76,7 +76,7 @@ def test_systemd_install_checks_linger_status(monkeypatch, tmp_path, capsys):
|
||||
assert unit_path.exists()
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
["systemctl", "--user", "enable", gateway.get_service_name()],
|
||||
]
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
@@ -110,7 +110,7 @@ def test_systemd_install_system_scope_skips_linger_and_uses_systemctl(monkeypatc
|
||||
assert unit_path.read_text(encoding="utf-8") == "scope=True user=alice\n"
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "daemon-reload"],
|
||||
["systemctl", "enable", gateway.SERVICE_NAME],
|
||||
["systemctl", "enable", gateway.get_service_name()],
|
||||
]
|
||||
assert helper_calls == []
|
||||
assert "Configured to run as: alice" not in out # generated test unit has no User= line
|
||||
|
||||
@@ -114,7 +114,7 @@ def test_systemd_install_calls_linger_helper(monkeypatch, tmp_path, capsys):
|
||||
assert unit_path.exists()
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
["systemctl", "--user", "enable", gateway.get_service_name()],
|
||||
]
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestSystemdServiceRefresh:
|
||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||
assert calls[:2] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "start", gateway_cli.SERVICE_NAME],
|
||||
["systemctl", "--user", "start", gateway_cli.get_service_name()],
|
||||
]
|
||||
|
||||
def test_systemd_restart_refreshes_outdated_unit(self, tmp_path, monkeypatch):
|
||||
@@ -49,10 +49,27 @@ class TestSystemdServiceRefresh:
|
||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||
assert calls[:2] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "restart", gateway_cli.SERVICE_NAME],
|
||||
["systemctl", "--user", "restart", gateway_cli.get_service_name()],
|
||||
]
|
||||
|
||||
|
||||
class TestGeneratedSystemdUnits:
|
||||
def test_user_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=False)
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
|
||||
def test_system_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=True)
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
assert "WantedBy=multi-user.target" in unit
|
||||
|
||||
|
||||
class TestGatewayStopCleanup:
|
||||
def test_stop_sweeps_manual_gateway_processes_after_service_stop(self, tmp_path, monkeypatch):
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
@@ -92,9 +109,9 @@ class TestGatewayServiceDetection:
|
||||
)
|
||||
|
||||
def fake_run(cmd, capture_output=True, text=True, **kwargs):
|
||||
if cmd == ["systemctl", "--user", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
if cmd == ["systemctl", "--user", "is-active", gateway_cli.get_service_name()]:
|
||||
return SimpleNamespace(returncode=0, stdout="inactive\n", stderr="")
|
||||
if cmd == ["systemctl", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
if cmd == ["systemctl", "is-active", gateway_cli.get_service_name()]:
|
||||
return SimpleNamespace(returncode=0, stdout="active\n", stderr="")
|
||||
raise AssertionError(f"Unexpected command: {cmd}")
|
||||
|
||||
|
||||
184
tests/hermes_cli/test_path_completion.py
Normal file
184
tests/hermes_cli/test_path_completion.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for file path autocomplete in the CLI completer."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.formatted_text import to_plain_text
|
||||
|
||||
from hermes_cli.commands import SlashCommandCompleter, _file_size_label
|
||||
|
||||
|
||||
def _display_names(completions):
|
||||
"""Extract plain-text display names from a list of Completion objects."""
|
||||
return [to_plain_text(c.display) for c in completions]
|
||||
|
||||
|
||||
def _display_metas(completions):
|
||||
"""Extract plain-text display_meta from a list of Completion objects."""
|
||||
return [to_plain_text(c.display_meta) if c.display_meta else "" for c in completions]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completer():
|
||||
return SlashCommandCompleter()
|
||||
|
||||
|
||||
class TestExtractPathWord:
|
||||
def test_relative_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("look at ./src/main.py") == "./src/main.py"
|
||||
|
||||
def test_home_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("edit ~/docs/") == "~/docs/"
|
||||
|
||||
def test_absolute_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("read /etc/hosts") == "/etc/hosts"
|
||||
|
||||
def test_parent_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("check ../config.yaml") == "../config.yaml"
|
||||
|
||||
def test_path_with_slash_in_middle(self):
|
||||
assert SlashCommandCompleter._extract_path_word("open src/utils/helpers.py") == "src/utils/helpers.py"
|
||||
|
||||
def test_plain_word_not_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("hello world") is None
|
||||
|
||||
def test_empty_string(self):
|
||||
assert SlashCommandCompleter._extract_path_word("") is None
|
||||
|
||||
def test_single_word_no_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("README.md") is None
|
||||
|
||||
def test_word_after_space(self):
|
||||
assert SlashCommandCompleter._extract_path_word("fix the bug in ./tools/") == "./tools/"
|
||||
|
||||
def test_just_dot_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("./") == "./"
|
||||
|
||||
def test_just_tilde_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("~/") == "~/"
|
||||
|
||||
|
||||
class TestPathCompletions:
|
||||
def test_lists_current_directory(self, tmp_path):
|
||||
(tmp_path / "file_a.py").touch()
|
||||
(tmp_path / "file_b.txt").touch()
|
||||
(tmp_path / "subdir").mkdir()
|
||||
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
completions = list(SlashCommandCompleter._path_completions("./"))
|
||||
names = _display_names(completions)
|
||||
assert "file_a.py" in names
|
||||
assert "file_b.txt" in names
|
||||
assert "subdir/" in names
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
def test_filters_by_prefix(self, tmp_path):
|
||||
(tmp_path / "alpha.py").touch()
|
||||
(tmp_path / "beta.py").touch()
|
||||
(tmp_path / "alpha_test.py").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/alpha"))
|
||||
names = _display_names(completions)
|
||||
assert "alpha.py" in names
|
||||
assert "alpha_test.py" in names
|
||||
assert "beta.py" not in names
|
||||
|
||||
def test_directories_have_trailing_slash(self, tmp_path):
|
||||
(tmp_path / "mydir").mkdir()
|
||||
(tmp_path / "myfile.txt").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/"))
|
||||
names = _display_names(completions)
|
||||
metas = _display_metas(completions)
|
||||
assert "mydir/" in names
|
||||
idx = names.index("mydir/")
|
||||
assert metas[idx] == "dir"
|
||||
|
||||
def test_home_expansion(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
(tmp_path / "testfile.md").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions("~/test"))
|
||||
names = _display_names(completions)
|
||||
assert "testfile.md" in names
|
||||
|
||||
def test_nonexistent_dir_returns_empty(self):
|
||||
completions = list(SlashCommandCompleter._path_completions("/nonexistent_dir_xyz/"))
|
||||
assert completions == []
|
||||
|
||||
def test_respects_limit(self, tmp_path):
|
||||
for i in range(50):
|
||||
(tmp_path / f"file_{i:03d}.txt").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/", limit=10))
|
||||
assert len(completions) == 10
|
||||
|
||||
def test_case_insensitive_prefix(self, tmp_path):
|
||||
(tmp_path / "README.md").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/read"))
|
||||
names = _display_names(completions)
|
||||
assert "README.md" in names
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Test the completer produces path completions via the prompt_toolkit API."""
|
||||
|
||||
def test_slash_commands_still_work(self, completer):
|
||||
doc = Document("/hel", cursor_position=4)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
assert "/help" in names
|
||||
|
||||
def test_path_completion_triggers_on_dot_slash(self, completer, tmp_path):
|
||||
(tmp_path / "test.py").touch()
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
doc = Document("edit ./te", cursor_position=9)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
assert "test.py" in names
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
def test_no_completion_for_plain_words(self, completer):
|
||||
doc = Document("hello world", cursor_position=11)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
assert completions == []
|
||||
|
||||
def test_absolute_path_triggers_completion(self, completer):
|
||||
doc = Document("check /etc/hos", cursor_position=14)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
# /etc/hosts should exist on Linux
|
||||
assert any("host" in n.lower() for n in names)
|
||||
|
||||
|
||||
class TestFileSizeLabel:
|
||||
def test_bytes(self, tmp_path):
|
||||
f = tmp_path / "small.txt"
|
||||
f.write_text("hi")
|
||||
assert _file_size_label(str(f)) == "2B"
|
||||
|
||||
def test_kilobytes(self, tmp_path):
|
||||
f = tmp_path / "medium.txt"
|
||||
f.write_bytes(b"x" * 2048)
|
||||
assert _file_size_label(str(f)) == "2K"
|
||||
|
||||
def test_megabytes(self, tmp_path):
|
||||
f = tmp_path / "large.bin"
|
||||
f.write_bytes(b"x" * (2 * 1024 * 1024))
|
||||
assert _file_size_label(str(f)) == "2.0M"
|
||||
|
||||
def test_nonexistent(self):
|
||||
assert _file_size_label("/nonexistent_xyz") == ""
|
||||
@@ -115,3 +115,13 @@ class TestConfigYamlRouting:
|
||||
set_config_value("terminal.docker_image", "python:3.12")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
assert "python:3.12" in config
|
||||
|
||||
def test_terminal_docker_cwd_mount_flag_goes_to_config_and_env(self, _isolated_hermes_home):
|
||||
set_config_value("terminal.docker_mount_cwd_to_workspace", "true")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
env_content = _read_env(_isolated_hermes_home)
|
||||
assert "docker_mount_cwd_to_workspace: 'true'" in config or "docker_mount_cwd_to_workspace: true" in config
|
||||
assert (
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=true" in env_content
|
||||
or "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=True" in env_content
|
||||
)
|
||||
|
||||
305
tests/hermes_cli/test_update_gateway_restart.py
Normal file
305
tests/hermes_cli/test_update_gateway_restart.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Tests for cmd_update gateway auto-restart — systemd + launchd coverage.
|
||||
|
||||
Ensures ``hermes update`` correctly detects running gateways managed by
|
||||
systemd (Linux) or launchd (macOS) and restarts/informs the user properly,
|
||||
rather than leaving zombie processes or telling users to manually restart
|
||||
when launchd will auto-respawn.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
from hermes_cli.main import cmd_update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_run_side_effect(
|
||||
branch="main",
|
||||
verify_ok=True,
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
launchctl_loaded=False,
|
||||
):
|
||||
"""Build a subprocess.run side_effect that simulates git + service commands."""
|
||||
|
||||
def side_effect(cmd, **kwargs):
|
||||
joined = " ".join(str(c) for c in cmd)
|
||||
|
||||
# git rev-parse --abbrev-ref HEAD
|
||||
if "rev-parse" in joined and "--abbrev-ref" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{branch}\n", stderr="")
|
||||
|
||||
# git rev-parse --verify origin/{branch}
|
||||
if "rev-parse" in joined and "--verify" in joined:
|
||||
rc = 0 if verify_ok else 128
|
||||
return subprocess.CompletedProcess(cmd, rc, stdout="", stderr="")
|
||||
|
||||
# git rev-list HEAD..origin/{branch} --count
|
||||
if "rev-list" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{commit_count}\n", stderr="")
|
||||
|
||||
# systemctl --user is-active
|
||||
if "systemctl" in joined and "is-active" in joined:
|
||||
if systemd_active:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="active\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 3, stdout="inactive\n", stderr="")
|
||||
|
||||
# systemctl --user restart
|
||||
if "systemctl" in joined and "restart" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
# launchctl list ai.hermes.gateway
|
||||
if "launchctl" in joined and "list" in joined:
|
||||
if launchctl_loaded:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="PID\tStatus\tLabel\n123\t0\tai.hermes.gateway\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 113, stdout="", stderr="Could not find service")
|
||||
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_args():
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Launchd plist includes --replace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLaunchdPlistReplace:
|
||||
"""The generated launchd plist must include --replace so respawned
|
||||
gateways kill stale instances."""
|
||||
|
||||
def test_plist_contains_replace_flag(self):
|
||||
plist = gateway_cli.generate_launchd_plist()
|
||||
assert "--replace" in plist
|
||||
|
||||
def test_plist_program_arguments_order(self):
|
||||
"""--replace comes after 'run' in the ProgramArguments."""
|
||||
plist = gateway_cli.generate_launchd_plist()
|
||||
lines = [line.strip() for line in plist.splitlines()]
|
||||
# Find 'run' and '--replace' in the string entries
|
||||
string_values = [
|
||||
line.replace("<string>", "").replace("</string>", "")
|
||||
for line in lines
|
||||
if "<string>" in line and "</string>" in line
|
||||
]
|
||||
assert "run" in string_values
|
||||
assert "--replace" in string_values
|
||||
run_idx = string_values.index("run")
|
||||
replace_idx = string_values.index("--replace")
|
||||
assert replace_idx == run_idx + 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — macOS launchd detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLaunchdPlistRefresh:
|
||||
"""refresh_launchd_plist_if_needed rewrites stale plists (like systemd's
|
||||
refresh_systemd_unit_if_needed)."""
|
||||
|
||||
def test_refresh_rewrites_stale_plist(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old content</plist>")
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
calls = []
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
|
||||
assert result is True
|
||||
# Plist should now contain the generated content (which includes --replace)
|
||||
assert "--replace" in plist_path.read_text()
|
||||
# Should have unloaded then reloaded
|
||||
assert any("unload" in str(c) for c in calls)
|
||||
assert any("load" in str(c) for c in calls)
|
||||
|
||||
def test_refresh_skips_when_current(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
# Write the current expected content
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist())
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
gateway_cli.subprocess, "run",
|
||||
lambda cmd, **kw: calls.append(cmd) or SimpleNamespace(returncode=0),
|
||||
)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
|
||||
assert result is False
|
||||
assert len(calls) == 0 # No launchctl calls needed
|
||||
|
||||
def test_refresh_skips_when_no_plist(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "nonexistent.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
assert result is False
|
||||
|
||||
def test_launchd_start_calls_refresh(self, tmp_path, monkeypatch):
|
||||
"""launchd_start refreshes the plist before starting."""
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old</plist>")
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
calls = []
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
gateway_cli.launchd_start()
|
||||
|
||||
# First calls should be refresh (unload/load), then start
|
||||
cmd_strs = [" ".join(c) for c in calls]
|
||||
assert any("unload" in s for s in cmd_strs)
|
||||
assert any("start" in s for s in cmd_strs)
|
||||
|
||||
|
||||
class TestCmdUpdateLaunchdRestart:
|
||||
"""cmd_update correctly detects and handles launchd on macOS."""
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_detects_launchd_and_skips_manual_restart_message(
|
||||
self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""When launchd is running the gateway, update should print
|
||||
'auto-restart via launchd' instead of 'Restart it with: hermes gateway run'."""
|
||||
# Create a fake launchd plist so is_macos + plist.exists() passes
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist/>")
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_launchd_plist_path", lambda: plist_path,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
launchctl_loaded=True,
|
||||
)
|
||||
|
||||
# Mock get_running_pid to return a PID
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Gateway restarted via launchd" in captured
|
||||
assert "Restart it with: hermes gateway run" not in captured
|
||||
# Verify launchctl stop + start were called (not manual SIGTERM)
|
||||
launchctl_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if len(c.args[0]) > 0 and c.args[0][0] == "launchctl"
|
||||
]
|
||||
stop_calls = [c for c in launchctl_calls if "stop" in c.args[0]]
|
||||
start_calls = [c for c in launchctl_calls if "start" in c.args[0]]
|
||||
assert len(stop_calls) >= 1
|
||||
assert len(start_calls) >= 1
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_without_launchd_shows_manual_restart(
|
||||
self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""When no service manager is running, update should show the manual restart hint."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: True,
|
||||
)
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
# plist does NOT exist — no launchd service
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_launchd_plist_path", lambda: plist_path,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
launchctl_loaded=False,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("os.kill"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Restart it with: hermes gateway run" in captured
|
||||
assert "Gateway restarted via launchd" not in captured
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_with_systemd_still_restarts_via_systemd(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""On Linux with systemd active, update should restart via systemctl."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: False,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=True,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("os.kill"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Gateway restarted" in captured
|
||||
# Verify systemctl restart was called
|
||||
restart_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if "restart" in " ".join(str(a) for a in c.args[0])
|
||||
and "systemctl" in " ".join(str(a) for a in c.args[0])
|
||||
]
|
||||
assert len(restart_calls) == 1
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_no_gateway_running_skips_restart(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""When no gateway is running, update should skip the restart section entirely."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: False,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=None):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Stopped gateway" not in captured
|
||||
assert "Gateway restarted" not in captured
|
||||
assert "Gateway restarted via launchd" not in captured
|
||||
@@ -162,6 +162,57 @@ def test_runtime_resolution_rebuilds_agent_on_routing_change(monkeypatch):
|
||||
assert shell.api_mode == "codex_responses"
|
||||
|
||||
|
||||
def test_cli_turn_routing_uses_primary_when_disabled(monkeypatch):
|
||||
cli = _import_cli()
|
||||
shell = cli.HermesCLI(model="gpt-5", compact=True, max_turns=1)
|
||||
shell.provider = "openrouter"
|
||||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "sk-primary"
|
||||
shell._smart_model_routing = {"enabled": False}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "gpt-5"
|
||||
assert result["runtime"]["provider"] == "openrouter"
|
||||
assert result["label"] is None
|
||||
|
||||
|
||||
def test_cli_turn_routing_uses_cheap_model_when_simple(monkeypatch):
|
||||
cli = _import_cli()
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
assert kwargs["requested"] == "zai"
|
||||
return {
|
||||
"provider": "zai",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://open.z.ai/api/v1",
|
||||
"api_key": "cheap-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
|
||||
shell = cli.HermesCLI(model="anthropic/claude-sonnet-4", compact=True, max_turns=1)
|
||||
shell.provider = "openrouter"
|
||||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "primary-key"
|
||||
shell._smart_model_routing = {
|
||||
"enabled": True,
|
||||
"cheap_model": {"provider": "zai", "model": "glm-5-air"},
|
||||
"max_simple_chars": 160,
|
||||
"max_simple_words": 28,
|
||||
}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "glm-5-air"
|
||||
assert result["runtime"]["provider"] == "zai"
|
||||
assert result["runtime"]["api_key"] == "cheap-key"
|
||||
assert result["label"] is not None
|
||||
|
||||
|
||||
def test_cli_prefers_config_provider_over_stale_env_override(monkeypatch):
|
||||
cli = _import_cli()
|
||||
|
||||
|
||||
@@ -65,24 +65,39 @@ class TestCLIStatusBar:
|
||||
assert "claude-sonnet-4-20250514" in text
|
||||
assert "12.4K/200K" in text
|
||||
assert "6%" in text
|
||||
assert "$0.06" in text
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
|
||||
def test_build_status_bar_text_shows_cost_when_enabled(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10000,
|
||||
completion_tokens=2400,
|
||||
total_tokens=12400,
|
||||
api_calls=7,
|
||||
context_tokens=12400,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj.show_cost = True
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=120)
|
||||
assert "$" in text # cost is shown when enabled
|
||||
|
||||
def test_build_status_bar_text_collapses_for_narrow_terminal(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10_230,
|
||||
completion_tokens=2_220,
|
||||
total_tokens=12_450,
|
||||
prompt_tokens=10000,
|
||||
completion_tokens=2400,
|
||||
total_tokens=12400,
|
||||
api_calls=7,
|
||||
context_tokens=12_450,
|
||||
context_tokens=12400,
|
||||
context_length=200_000,
|
||||
)
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=60)
|
||||
|
||||
assert "⚕" in text
|
||||
assert "$0.06" in text
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
assert "200K" not in text
|
||||
|
||||
|
||||
340
tests/test_plugins.py
Normal file
340
tests/test_plugins.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Tests for the Hermes plugin system (hermes_cli.plugins)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from hermes_cli.plugins import (
|
||||
ENTRY_POINTS_GROUP,
|
||||
VALID_HOOKS,
|
||||
LoadedPlugin,
|
||||
PluginContext,
|
||||
PluginManager,
|
||||
PluginManifest,
|
||||
get_plugin_manager,
|
||||
get_plugin_tool_names,
|
||||
discover_plugins,
|
||||
invoke_hook,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_plugin_dir(base: Path, name: str, *, register_body: str = "pass",
|
||||
manifest_extra: dict | None = None) -> Path:
|
||||
"""Create a minimal plugin directory with plugin.yaml + __init__.py."""
|
||||
plugin_dir = base / name
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
manifest = {"name": name, "version": "0.1.0", "description": f"Test plugin {name}"}
|
||||
if manifest_extra:
|
||||
manifest.update(manifest_extra)
|
||||
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
f"def register(ctx):\n {register_body}\n"
|
||||
)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
# ── TestPluginDiscovery ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginDiscovery:
|
||||
"""Tests for plugin discovery from directories and entry points."""
|
||||
|
||||
def test_discover_user_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ~/.hermes/plugins/ are discovered."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "hello_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "hello_plugin" in mgr._plugins
|
||||
assert mgr._plugins["hello_plugin"].enabled
|
||||
|
||||
def test_discover_project_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ./.hermes/plugins/ are discovered."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
monkeypatch.chdir(project_dir)
|
||||
plugins_dir = project_dir / ".hermes" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "proj_plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "proj_plugin" in mgr._plugins
|
||||
assert mgr._plugins["proj_plugin"].enabled
|
||||
|
||||
def test_discover_is_idempotent(self, tmp_path, monkeypatch):
|
||||
"""Calling discover_and_load() twice does not duplicate plugins."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "once_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
mgr.discover_and_load() # second call should no-op
|
||||
|
||||
assert len(mgr._plugins) == 1
|
||||
|
||||
def test_discover_skips_dir_without_manifest(self, tmp_path, monkeypatch):
|
||||
"""Directories without plugin.yaml are silently skipped."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
(plugins_dir / "no_manifest").mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert len(mgr._plugins) == 0
|
||||
|
||||
def test_entry_points_scanned(self, tmp_path, monkeypatch):
|
||||
"""Entry-point based plugins are discovered (mocked)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
fake_module = types.ModuleType("fake_ep_plugin")
|
||||
fake_module.register = lambda ctx: None # type: ignore[attr-defined]
|
||||
|
||||
fake_ep = MagicMock()
|
||||
fake_ep.name = "ep_plugin"
|
||||
fake_ep.value = "fake_ep_plugin:register"
|
||||
fake_ep.group = ENTRY_POINTS_GROUP
|
||||
fake_ep.load.return_value = fake_module
|
||||
|
||||
def fake_entry_points():
|
||||
result = MagicMock()
|
||||
result.select = MagicMock(return_value=[fake_ep])
|
||||
return result
|
||||
|
||||
with patch("importlib.metadata.entry_points", fake_entry_points):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "ep_plugin" in mgr._plugins
|
||||
|
||||
|
||||
# ── TestPluginLoading ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginLoading:
|
||||
"""Tests for plugin module loading."""
|
||||
|
||||
def test_load_missing_init(self, tmp_path, monkeypatch):
|
||||
"""Plugin dir without __init__.py records an error."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "bad_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "bad_plugin"}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "bad_plugin" in mgr._plugins
|
||||
assert not mgr._plugins["bad_plugin"].enabled
|
||||
assert mgr._plugins["bad_plugin"].error is not None
|
||||
|
||||
def test_load_missing_register_fn(self, tmp_path, monkeypatch):
|
||||
"""Plugin without register() function records an error."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "no_reg"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "no_reg"}))
|
||||
(plugin_dir / "__init__.py").write_text("# no register function\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "no_reg" in mgr._plugins
|
||||
assert not mgr._plugins["no_reg"].enabled
|
||||
assert "no register()" in mgr._plugins["no_reg"].error
|
||||
|
||||
def test_load_registers_namespace_module(self, tmp_path, monkeypatch):
|
||||
"""Directory plugins are importable under hermes_plugins.<name>."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "ns_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
# Clean up any prior namespace module
|
||||
sys.modules.pop("hermes_plugins.ns_plugin", None)
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "hermes_plugins.ns_plugin" in sys.modules
|
||||
|
||||
|
||||
# ── TestPluginHooks ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginHooks:
|
||||
"""Tests for lifecycle hook registration and invocation."""
|
||||
|
||||
def test_register_and_invoke_hook(self, tmp_path, monkeypatch):
|
||||
"""Registered hooks are called on invoke_hook()."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "hook_plugin",
|
||||
register_body='ctx.register_hook("pre_tool_call", lambda **kw: None)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Should not raise
|
||||
mgr.invoke_hook("pre_tool_call", tool_name="test", args={}, task_id="t1")
|
||||
|
||||
def test_hook_exception_does_not_propagate(self, tmp_path, monkeypatch):
|
||||
"""A hook callback that raises does NOT crash the caller."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "bad_hook",
|
||||
register_body='ctx.register_hook("post_tool_call", lambda **kw: 1/0)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Should not raise despite 1/0
|
||||
mgr.invoke_hook("post_tool_call", tool_name="x", args={}, result="r", task_id="")
|
||||
|
||||
def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Registering an unknown hook name logs a warning."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "warn_plugin",
|
||||
register_body='ctx.register_hook("on_banana", lambda **kw: None)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="hermes_cli.plugins"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert any("on_banana" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
# ── TestPluginContext ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginContext:
|
||||
"""Tests for the PluginContext facade."""
|
||||
|
||||
def test_register_tool_adds_to_registry(self, tmp_path, monkeypatch):
|
||||
"""PluginContext.register_tool() puts the tool in the global registry."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "tool_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "tool_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="plugin_echo",\n'
|
||||
' toolset="plugin_tool_plugin",\n'
|
||||
' schema={"name": "plugin_echo", "description": "Echo", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "echo",\n'
|
||||
' )\n'
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "plugin_echo" in mgr._plugin_tool_names
|
||||
|
||||
from tools.registry import registry
|
||||
assert "plugin_echo" in registry._tools
|
||||
|
||||
|
||||
# ── TestPluginToolVisibility ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginToolVisibility:
|
||||
"""Plugin-registered tools appear in get_tool_definitions()."""
|
||||
|
||||
def test_plugin_tools_in_definitions(self, tmp_path, monkeypatch):
|
||||
"""Tools from plugins bypass the toolset filter."""
|
||||
import hermes_cli.plugins as plugins_mod
|
||||
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "vis_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "vis_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="vis_tool",\n'
|
||||
' toolset="plugin_vis_plugin",\n'
|
||||
' schema={"name": "vis_tool", "description": "Visible", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "ok",\n'
|
||||
' )\n'
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
monkeypatch.setattr(plugins_mod, "_plugin_manager", mgr)
|
||||
|
||||
from model_tools import get_tool_definitions
|
||||
tools = get_tool_definitions(enabled_toolsets=["terminal"], quiet_mode=True)
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
assert "vis_tool" in tool_names
|
||||
|
||||
|
||||
# ── TestPluginManagerList ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginManagerList:
|
||||
"""Tests for PluginManager.list_plugins()."""
|
||||
|
||||
def test_list_empty(self):
|
||||
"""Empty manager returns empty list."""
|
||||
mgr = PluginManager()
|
||||
assert mgr.list_plugins() == []
|
||||
|
||||
def test_list_returns_sorted(self, tmp_path, monkeypatch):
|
||||
"""list_plugins() returns results sorted by name."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "zulu")
|
||||
_make_plugin_dir(plugins_dir, "alpha")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
listing = mgr.list_plugins()
|
||||
names = [p["name"] for p in listing]
|
||||
assert names == sorted(names)
|
||||
|
||||
def test_list_with_plugins(self, tmp_path, monkeypatch):
|
||||
"""list_plugins() returns info dicts for each discovered plugin."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "alpha")
|
||||
_make_plugin_dir(plugins_dir, "beta")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
listing = mgr.list_plugins()
|
||||
names = [p["name"] for p in listing]
|
||||
assert "alpha" in names
|
||||
assert "beta" in names
|
||||
for p in listing:
|
||||
assert "enabled" in p
|
||||
assert "tools" in p
|
||||
assert "hooks" in p
|
||||
@@ -1,11 +1,31 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import docker as docker_env
|
||||
|
||||
|
||||
def _install_fake_minisweagent(monkeypatch, captured_run_args):
|
||||
class MockInnerDocker:
|
||||
container_id = "fake-container"
|
||||
config = type("Config", (), {"executable": "/usr/bin/docker", "forward_env": [], "env": {}})()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
captured_run_args.extend(kwargs.get("run_args", []))
|
||||
|
||||
minisweagent_mod = types.ModuleType("minisweagent")
|
||||
environments_mod = types.ModuleType("minisweagent.environments")
|
||||
docker_mod = types.ModuleType("minisweagent.environments.docker")
|
||||
docker_mod.DockerEnvironment = MockInnerDocker
|
||||
|
||||
monkeypatch.setitem(sys.modules, "minisweagent", minisweagent_mod)
|
||||
monkeypatch.setitem(sys.modules, "minisweagent.environments", environments_mod)
|
||||
monkeypatch.setitem(sys.modules, "minisweagent.environments.docker", docker_mod)
|
||||
|
||||
|
||||
def _make_dummy_env(**kwargs):
|
||||
"""Helper to construct DockerEnvironment with minimal required args."""
|
||||
return docker_env.DockerEnvironment(
|
||||
@@ -19,6 +39,8 @@ def _make_dummy_env(**kwargs):
|
||||
task_id=kwargs.get("task_id", "test-task"),
|
||||
volumes=kwargs.get("volumes", []),
|
||||
network=kwargs.get("network", True),
|
||||
host_cwd=kwargs.get("host_cwd"),
|
||||
auto_mount_cwd=kwargs.get("auto_mount_cwd", False),
|
||||
)
|
||||
|
||||
|
||||
@@ -88,65 +110,10 @@ def test_ensure_docker_available_uses_resolved_executable(monkeypatch):
|
||||
|
||||
|
||||
def test_auto_mount_host_cwd_adds_volume(monkeypatch, tmp_path):
|
||||
"""When host_cwd is provided, it should be auto-mounted to /workspace."""
|
||||
import os
|
||||
|
||||
# Create a temp directory to simulate user's project directory
|
||||
"""Opt-in docker cwd mounting should bind the host cwd to /workspace."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
|
||||
# Mock Docker availability
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
def _run_docker_create(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 1, stdout="", stderr="storage-opt not supported")
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
# Mock the inner _Docker class to capture run_args
|
||||
captured_run_args = []
|
||||
|
||||
class MockInnerDocker:
|
||||
container_id = "mock-container-123"
|
||||
config = type("Config", (), {"executable": "/usr/bin/docker", "forward_env": [], "env": {}})()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
captured_run_args.extend(kwargs.get("run_args", []))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"minisweagent.environments.docker.DockerEnvironment",
|
||||
MockInnerDocker,
|
||||
)
|
||||
|
||||
# Create environment with host_cwd
|
||||
env = docker_env.DockerEnvironment(
|
||||
image="python:3.11",
|
||||
cwd="/workspace",
|
||||
timeout=60,
|
||||
persistent_filesystem=False, # Non-persistent mode uses tmpfs, should be overridden
|
||||
task_id="test-auto-mount",
|
||||
volumes=[],
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
)
|
||||
|
||||
# Check that the host_cwd was added as a volume mount
|
||||
volume_mount = f"-v {project_dir}:/workspace"
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" in run_args_str, f"Expected auto-mount in run_args: {run_args_str}"
|
||||
|
||||
|
||||
def test_auto_mount_disabled_via_env(monkeypatch, tmp_path):
|
||||
"""Auto-mount should be disabled when TERMINAL_DOCKER_NO_AUTO_MOUNT is set."""
|
||||
import os
|
||||
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv("TERMINAL_DOCKER_NO_AUTO_MOUNT", "true")
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
@@ -154,39 +121,44 @@ def test_auto_mount_disabled_via_env(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
class MockInnerDocker:
|
||||
container_id = "mock-container-456"
|
||||
config = type("Config", (), {"executable": "/usr/bin/docker", "forward_env": [], "env": {}})()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
captured_run_args.extend(kwargs.get("run_args", []))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"minisweagent.environments.docker.DockerEnvironment",
|
||||
MockInnerDocker,
|
||||
)
|
||||
|
||||
env = docker_env.DockerEnvironment(
|
||||
image="python:3.11",
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
timeout=60,
|
||||
persistent_filesystem=False,
|
||||
task_id="test-no-auto-mount",
|
||||
volumes=[],
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
)
|
||||
|
||||
# Check that the host_cwd was NOT added (because env var disabled it)
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" not in run_args_str, f"Auto-mount should be disabled: {run_args_str}"
|
||||
assert f"{project_dir}:/workspace" in run_args_str
|
||||
|
||||
|
||||
def test_auto_mount_disabled_by_default(monkeypatch, tmp_path):
|
||||
"""Host cwd should not be mounted unless the caller explicitly opts in."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/root",
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=False,
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" not in run_args_str
|
||||
|
||||
|
||||
def test_auto_mount_skipped_when_workspace_already_mounted(monkeypatch, tmp_path):
|
||||
"""Auto-mount should be skipped if /workspace is already mounted via user volumes."""
|
||||
import os
|
||||
|
||||
"""Explicit user volumes for /workspace should take precedence over cwd mount."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
other_dir = tmp_path / "other"
|
||||
@@ -199,35 +171,43 @@ def test_auto_mount_skipped_when_workspace_already_mounted(monkeypatch, tmp_path
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
class MockInnerDocker:
|
||||
container_id = "mock-container-789"
|
||||
config = type("Config", (), {"executable": "/usr/bin/docker", "forward_env": [], "env": {}})()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
captured_run_args.extend(kwargs.get("run_args", []))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"minisweagent.environments.docker.DockerEnvironment",
|
||||
MockInnerDocker,
|
||||
)
|
||||
|
||||
# User already configured a volume mount for /workspace
|
||||
env = docker_env.DockerEnvironment(
|
||||
image="python:3.11",
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
timeout=60,
|
||||
persistent_filesystem=False,
|
||||
task_id="test-workspace-exists",
|
||||
volumes=[f"{other_dir}:/workspace"], # User explicitly mounted something to /workspace
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
volumes=[f"{other_dir}:/workspace"],
|
||||
)
|
||||
|
||||
# The user's explicit mount should be present
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{other_dir}:/workspace" in run_args_str
|
||||
assert run_args_str.count(":/workspace") == 1
|
||||
|
||||
# But the auto-mount should NOT add a duplicate
|
||||
assert run_args_str.count(":/workspace") == 1, f"Should only have one /workspace mount: {run_args_str}"
|
||||
|
||||
def test_auto_mount_replaces_persistent_workspace_bind(monkeypatch, tmp_path):
|
||||
"""Persistent mode should still prefer the configured host cwd at /workspace."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
persistent_filesystem=True,
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
task_id="test-persistent-auto-mount",
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" in run_args_str
|
||||
assert "/sandboxes/docker/test-persistent-auto-mount/workspace:/workspace" not in run_args_str
|
||||
|
||||
|
||||
@@ -91,8 +91,8 @@ class TestCwdHandling:
|
||||
"/home/ paths should be replaced for modal backend."
|
||||
)
|
||||
|
||||
def test_users_path_replaced_for_docker(self):
|
||||
"""TERMINAL_CWD=/Users/... should be replaced with /root for docker."""
|
||||
def test_users_path_replaced_for_docker_by_default(self):
|
||||
"""Docker should keep host paths out of the sandbox unless explicitly enabled."""
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_CWD": "/Users/someone/projects",
|
||||
@@ -100,8 +100,22 @@ class TestCwdHandling:
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/root", (
|
||||
f"Expected /root, got {config['cwd']}. "
|
||||
"/Users/ paths should be replaced for docker backend."
|
||||
"Host paths should be discarded for docker backend by default."
|
||||
)
|
||||
assert config["host_cwd"] is None
|
||||
assert config["docker_mount_cwd_to_workspace"] is False
|
||||
|
||||
def test_users_path_maps_to_workspace_for_docker_when_enabled(self):
|
||||
"""Docker should map the host cwd into /workspace only when explicitly enabled."""
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_CWD": "/Users/someone/projects",
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE": "true",
|
||||
}):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/workspace"
|
||||
assert config["host_cwd"] == "/Users/someone/projects"
|
||||
assert config["docker_mount_cwd_to_workspace"] is True
|
||||
|
||||
def test_windows_path_replaced_for_modal(self):
|
||||
"""TERMINAL_CWD=C:\\Users\\... should be replaced for modal."""
|
||||
@@ -119,12 +133,27 @@ class TestCwdHandling:
|
||||
# Remove TERMINAL_CWD so it uses default
|
||||
env = os.environ.copy()
|
||||
env.pop("TERMINAL_CWD", None)
|
||||
env.pop("TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/root", (
|
||||
f"Backend {backend}: expected /root default, got {config['cwd']}"
|
||||
)
|
||||
|
||||
def test_docker_default_cwd_maps_current_directory_when_enabled(self):
|
||||
"""Docker should use /workspace when cwd mounting is explicitly enabled."""
|
||||
with patch("tools.terminal_tool.os.getcwd", return_value="/home/user/project"):
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE": "true",
|
||||
}, clear=False):
|
||||
env = os.environ.copy()
|
||||
env.pop("TERMINAL_CWD", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/workspace"
|
||||
assert config["host_cwd"] == "/home/user/project"
|
||||
|
||||
def test_local_backend_uses_getcwd(self):
|
||||
"""Local backend should use os.getcwd(), not /root."""
|
||||
with patch.dict(os.environ, {"TERMINAL_ENV": "local"}, clear=False):
|
||||
@@ -134,6 +163,31 @@ class TestCwdHandling:
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == os.getcwd()
|
||||
|
||||
def test_create_environment_passes_docker_host_cwd_and_flag(self, monkeypatch):
|
||||
"""Docker host cwd and mount flag should reach DockerEnvironment."""
|
||||
captured = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_docker_environment(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(_tt_mod, "_DockerEnvironment", _fake_docker_environment)
|
||||
|
||||
env = _tt_mod._create_environment(
|
||||
env_type="docker",
|
||||
image="python:3.11",
|
||||
cwd="/workspace",
|
||||
timeout=60,
|
||||
container_config={"docker_mount_cwd_to_workspace": True},
|
||||
host_cwd="/home/user/project",
|
||||
)
|
||||
|
||||
assert env is sentinel
|
||||
assert captured["cwd"] == "/workspace"
|
||||
assert captured["host_cwd"] == "/home/user/project"
|
||||
assert captured["auto_mount_cwd"] is True
|
||||
|
||||
def test_ssh_preserves_home_paths(self):
|
||||
"""SSH backend should NOT replace /home/ paths (they're valid remotely)."""
|
||||
with patch.dict(os.environ, {
|
||||
|
||||
Reference in New Issue
Block a user