mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-10 12:18:44 +08:00
Compare commits
4 Commits
feat/plugi
...
salvage/pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8758182ac0 | ||
|
|
a914c319b1 | ||
|
|
a6a2e16e3c | ||
|
|
f86a585e14 |
@@ -9009,7 +9009,9 @@ class GatewayRunner:
|
||||
context = build_session_context(source, self.config, session_entry)
|
||||
|
||||
# Set session context variables for tools (task-local, concurrency-safe)
|
||||
_session_env_tokens = self._set_session_env(context)
|
||||
_session_env_tokens = self._set_session_env(
|
||||
context, cwd=session_entry.cwd or ""
|
||||
)
|
||||
|
||||
# Read privacy.redact_pii from config (re-read per message)
|
||||
_redact_pii = False
|
||||
@@ -9621,6 +9623,39 @@ class GatewayRunner:
|
||||
source, session_entry, reason="agent-result-compression",
|
||||
)
|
||||
|
||||
# Persist the agent's current working directory so it survives
|
||||
# gateway restarts (#41128). Read the terminal environment's live
|
||||
# cwd (which reflects cd commands executed during the turn) rather
|
||||
# than the static ContextVar set at turn start.
|
||||
#
|
||||
# _active_environments is keyed by the terminal task_id (see
|
||||
# tools/terminal_tool.py), which for gateway turns is usually the
|
||||
# agent's task_id rather than the gateway session_id. Try both:
|
||||
# session_id first (covers backends that key on it), then the
|
||||
# task_id returned in the agent result.
|
||||
try:
|
||||
from tools.terminal_tool import _active_environments
|
||||
|
||||
_env = _active_environments.get(session_entry.session_id)
|
||||
if _env is None:
|
||||
_task_id = agent_result.get("task_id", "")
|
||||
if _task_id:
|
||||
_env = _active_environments.get(_task_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != session_entry.cwd:
|
||||
session_entry.cwd = _live_cwd
|
||||
self.session_store._save()
|
||||
except Exception:
|
||||
# Best-effort; never let cwd persistence break the turn.
|
||||
# Log at debug so a lookup-key regression is diagnosable
|
||||
# instead of failing silently.
|
||||
logger.debug(
|
||||
"Failed to persist session cwd for %s",
|
||||
getattr(session_entry, "session_id", "?"),
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Prepend reasoning/thinking if display is enabled (per-platform)
|
||||
try:
|
||||
from gateway.display_config import resolve_display_setting as _rds
|
||||
@@ -15728,7 +15763,7 @@ class GatewayRunner:
|
||||
|
||||
return delivered
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> list:
|
||||
def _set_session_env(self, context: SessionContext, cwd: str = "") -> list:
|
||||
"""Set session context variables for the current async task.
|
||||
|
||||
Uses ``contextvars`` instead of ``os.environ`` so that concurrent
|
||||
@@ -15747,6 +15782,7 @@ class GatewayRunner:
|
||||
user_name=str(context.source.user_name) if context.source.user_name else "",
|
||||
session_key=context.session_key,
|
||||
message_id=str(context.source.message_id) if context.source.message_id else "",
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
def _clear_session_env(self, tokens: list) -> None:
|
||||
|
||||
@@ -491,6 +491,12 @@ class SessionEntry:
|
||||
resume_reason: Optional[str] = None # e.g. "restart_timeout"
|
||||
last_resume_marked_at: Optional[datetime] = None
|
||||
|
||||
# Logical working directory for this session. Persisted across gateway
|
||||
# restarts so that long-running conversations don't lose their file-system
|
||||
# context when the agent changes directory via the terminal tool.
|
||||
# See issue #41128.
|
||||
cwd: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
@@ -521,6 +527,7 @@ class SessionEntry:
|
||||
"was_auto_reset": self.was_auto_reset,
|
||||
"auto_reset_reason": self.auto_reset_reason,
|
||||
"reset_had_activity": self.reset_had_activity,
|
||||
"cwd": self.cwd,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
@@ -573,6 +580,7 @@ class SessionEntry:
|
||||
was_auto_reset=data.get("was_auto_reset", False),
|
||||
auto_reset_reason=data.get("auto_reset_reason"),
|
||||
reset_had_activity=data.get("reset_had_activity", False),
|
||||
cwd=data.get("cwd"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ AUTHOR_MAP = {
|
||||
"129007007+HeLLGURD@users.noreply.github.com": "HeLLGURD",
|
||||
"290859878+synapsesx@users.noreply.github.com": "synapsesx",
|
||||
"dirtyren@users.noreply.github.com": "dirtyren",
|
||||
"islam666@users.noreply.github.com": "islam666",
|
||||
"zhaolei.vc@bytedance.com": "zhaoleibd",
|
||||
"jeffrobodie@gmail.com": "jeffrobodie-glitch",
|
||||
"kyssta-exe@users.noreply.github.com": "kyssta-exe",
|
||||
|
||||
273
tests/gateway/test_session_cwd_persistence.py
Normal file
273
tests/gateway/test_session_cwd_persistence.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for gateway session cwd persistence (issue #41128).
|
||||
|
||||
The gateway's SessionEntry now tracks the agent's working directory
|
||||
so that it survives gateway restarts. When the agent changes
|
||||
directory via the terminal tool, the new cwd is saved to the
|
||||
session entry at the end of the turn.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionEntry cwd field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionEntryCwd:
|
||||
def test_default_cwd_is_none(self):
|
||||
from gateway.session import SessionEntry
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
assert entry.cwd is None
|
||||
|
||||
def test_cwd_roundtrip_to_dict(self):
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd="/home/user/projects/foo",
|
||||
)
|
||||
d = entry.to_dict()
|
||||
assert d["cwd"] == "/home/user/projects/foo"
|
||||
|
||||
def test_cwd_roundtrip_from_dict(self):
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
data = {
|
||||
"session_key": "test",
|
||||
"session_id": "s1",
|
||||
"created_at": now.isoformat(),
|
||||
"updated_at": now.isoformat(),
|
||||
"cwd": "/home/user/projects/bar",
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.cwd == "/home/user/projects/bar"
|
||||
|
||||
def test_cwd_none_in_dict(self):
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
data = {
|
||||
"session_key": "test",
|
||||
"session_id": "s1",
|
||||
"created_at": now.isoformat(),
|
||||
"updated_at": now.isoformat(),
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.cwd is None
|
||||
|
||||
def test_cwd_serialization_to_json(self, tmp_path):
|
||||
"""SessionEntry with cwd should survive JSON roundtrip (sessions.json)."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd="/tmp/workspace",
|
||||
)
|
||||
# Simulate sessions.json write/read
|
||||
sessions_file = tmp_path / "sessions.json"
|
||||
sessions_file.write_text(json.dumps({"test": entry.to_dict()}))
|
||||
loaded = json.loads(sessions_file.read_text())
|
||||
restored = SessionEntry.from_dict(loaded["test"])
|
||||
assert restored.cwd == "/tmp/workspace"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _set_session_env passes cwd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetSessionEnvCwd:
|
||||
def test_set_session_vars_accepts_cwd(self):
|
||||
"""set_session_vars should accept and forward cwd kwarg."""
|
||||
from gateway.session_context import set_session_vars
|
||||
with patch("gateway.session_context._SESSION_PLATFORM") as mock_p, \
|
||||
patch("gateway.session_context._SESSION_CHAT_ID") as mock_c, \
|
||||
patch("gateway.session_context._SESSION_CHAT_NAME") as mock_cn, \
|
||||
patch("gateway.session_context._SESSION_THREAD_ID") as mock_t, \
|
||||
patch("gateway.session_context._SESSION_USER_ID") as mock_u, \
|
||||
patch("gateway.session_context._SESSION_USER_NAME") as mock_un, \
|
||||
patch("gateway.session_context._SESSION_KEY") as mock_sk, \
|
||||
patch("gateway.session_context._SESSION_MESSAGE_ID") as mock_m:
|
||||
for m in (mock_p, mock_c, mock_cn, mock_t, mock_u, mock_un, mock_sk, mock_m):
|
||||
m.set.return_value = "token"
|
||||
with patch("agent.runtime_cwd.set_session_cwd") as mock_sc:
|
||||
set_session_vars(
|
||||
platform="telegram", chat_id="123", chat_name="",
|
||||
thread_id="", user_id="456", user_name="",
|
||||
session_key="test", message_id="", cwd="/home/user/proj",
|
||||
)
|
||||
mock_sc.assert_called_once_with("/home/user/proj")
|
||||
|
||||
def test_set_session_vars_empty_cwd(self):
|
||||
"""set_session_vars with empty cwd should still call set_session_cwd."""
|
||||
from gateway.session_context import set_session_vars
|
||||
with patch("gateway.session_context._SESSION_PLATFORM") as mock_p, \
|
||||
patch("gateway.session_context._SESSION_CHAT_ID") as mock_c, \
|
||||
patch("gateway.session_context._SESSION_CHAT_NAME") as mock_cn, \
|
||||
patch("gateway.session_context._SESSION_THREAD_ID") as mock_t, \
|
||||
patch("gateway.session_context._SESSION_USER_ID") as mock_u, \
|
||||
patch("gateway.session_context._SESSION_USER_NAME") as mock_un, \
|
||||
patch("gateway.session_context._SESSION_KEY") as mock_sk, \
|
||||
patch("gateway.session_context._SESSION_MESSAGE_ID") as mock_m:
|
||||
for m in (mock_p, mock_c, mock_cn, mock_t, mock_u, mock_un, mock_sk, mock_m):
|
||||
m.set.return_value = "token"
|
||||
with patch("agent.runtime_cwd.set_session_cwd") as mock_sc:
|
||||
set_session_vars(
|
||||
platform="telegram", chat_id="123", chat_name="",
|
||||
thread_id="", user_id="456", user_name="",
|
||||
session_key="test", message_id="",
|
||||
)
|
||||
mock_sc.assert_called_once_with("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cwd persistence at end of turn
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCwdPersistence:
|
||||
def test_cwd_saved_from_terminal_env(self):
|
||||
"""When terminal env has a different cwd, it should be saved."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd="/old/path",
|
||||
)
|
||||
|
||||
mock_env = MagicMock()
|
||||
mock_env.cwd = "/new/path"
|
||||
|
||||
with patch("tools.terminal_tool._active_environments", {"s1": mock_env}):
|
||||
from tools.terminal_tool import _active_environments
|
||||
_env = _active_environments.get(entry.session_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != entry.cwd:
|
||||
entry.cwd = _live_cwd
|
||||
|
||||
assert entry.cwd == "/new/path"
|
||||
|
||||
def test_cwd_not_saved_when_unchanged(self):
|
||||
"""When cwd hasn't changed, no update should occur."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd="/same/path",
|
||||
)
|
||||
|
||||
mock_env = MagicMock()
|
||||
mock_env.cwd = "/same/path"
|
||||
|
||||
with patch("tools.terminal_tool._active_environments", {"s1": mock_env}):
|
||||
from tools.terminal_tool import _active_environments
|
||||
_env = _active_environments.get(entry.session_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != entry.cwd:
|
||||
entry.cwd = _live_cwd
|
||||
|
||||
assert entry.cwd == "/same/path"
|
||||
|
||||
def test_cwd_saved_when_previously_none(self):
|
||||
"""When cwd was None (old session), it should be saved."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd=None,
|
||||
)
|
||||
|
||||
mock_env = MagicMock()
|
||||
mock_env.cwd = "/first/path"
|
||||
|
||||
with patch("tools.terminal_tool._active_environments", {"s1": mock_env}):
|
||||
from tools.terminal_tool import _active_environments
|
||||
_env = _active_environments.get(entry.session_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != entry.cwd:
|
||||
entry.cwd = _live_cwd
|
||||
|
||||
assert entry.cwd == "/first/path"
|
||||
|
||||
def test_cwd_fallback_to_task_id(self):
|
||||
"""When session_id doesn't match, fall back to task_id lookup."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd=None,
|
||||
)
|
||||
|
||||
mock_env = MagicMock()
|
||||
mock_env.cwd = "/task/path"
|
||||
|
||||
# No match on session_id, but match on task_id
|
||||
with patch("tools.terminal_tool._active_environments", {"task-123": mock_env}):
|
||||
from tools.terminal_tool import _active_environments
|
||||
_env = _active_environments.get(entry.session_id)
|
||||
if _env is None:
|
||||
_task_id = "task-123"
|
||||
if _task_id:
|
||||
_env = _active_environments.get(_task_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != entry.cwd:
|
||||
entry.cwd = _live_cwd
|
||||
|
||||
assert entry.cwd == "/task/path"
|
||||
|
||||
def test_cwd_no_env_graceful(self):
|
||||
"""When no terminal env exists, cwd should not change."""
|
||||
from gateway.session import SessionEntry
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
cwd="/unchanged",
|
||||
)
|
||||
|
||||
with patch("tools.terminal_tool._active_environments", {}):
|
||||
from tools.terminal_tool import _active_environments
|
||||
_env = _active_environments.get(entry.session_id)
|
||||
if _env is not None:
|
||||
_live_cwd = getattr(_env, "cwd", "")
|
||||
if _live_cwd and _live_cwd != entry.cwd:
|
||||
entry.cwd = _live_cwd
|
||||
|
||||
assert entry.cwd == "/unchanged"
|
||||
Reference in New Issue
Block a user