diff --git a/tests/tools/test_session_search.py b/tests/tools/test_session_search.py index c90023affd..6cb44341c4 100644 --- a/tests/tools/test_session_search.py +++ b/tests/tools/test_session_search.py @@ -10,6 +10,7 @@ from tools.session_search_tool import ( _format_conversation, _truncate_around_matches, _get_session_search_max_concurrency, + _list_recent_sessions, _HIDDEN_SESSION_SOURCES, MAX_SESSION_CHARS, SESSION_SEARCH_SCHEMA, @@ -240,6 +241,54 @@ class TestSessionSearchConcurrency: assert max_seen["value"] == 1 +class TestRecentSessionListing: + def test_current_child_session_excludes_root_lineage_even_when_child_id_is_longer(self): + from unittest.mock import MagicMock + + mock_db = MagicMock() + mock_db.list_sessions_rich.return_value = [ + { + "id": "root", + "title": "Current conversation", + "source": "cli", + "started_at": 1709500000, + "last_active": 1709500100, + "message_count": 4, + "preview": "current root", + "parent_session_id": None, + }, + { + "id": "other_session", + "title": "Other conversation", + "source": "cli", + "started_at": 1709400000, + "last_active": 1709400100, + "message_count": 3, + "preview": "other root", + "parent_session_id": None, + }, + ] + + def _get_session(session_id): + if session_id == "child_session_id_that_is_definitely_longer": + return {"parent_session_id": "root"} + if session_id == "root": + return {"parent_session_id": None} + return None + + mock_db.get_session.side_effect = _get_session + + result = json.loads(_list_recent_sessions( + mock_db, + limit=5, + current_session_id="child_session_id_that_is_definitely_longer", + )) + + assert result["success"] is True + assert [item["session_id"] for item in result["results"]] == ["other_session"] + assert all(item["session_id"] != "root" for item in result["results"]) + + # ========================================================================= # session_search (dispatcher) # ========================================================================= diff --git a/tools/session_search_tool.py b/tools/session_search_tool.py index 16aaea109f..ff3153afaf 100644 --- a/tools/session_search_tool.py +++ b/tools/session_search_tool.py @@ -274,12 +274,13 @@ def _list_recent_sessions(db, limit: int, current_session_id: str = None) -> str try: sid = current_session_id visited = set() + current_root = current_session_id while sid and sid not in visited: visited.add(sid) + current_root = sid s = db.get_session(sid) parent = s.get("parent_session_id") if s else None sid = parent if parent else None - current_root = max(visited, key=len) if visited else current_session_id except Exception: current_root = current_session_id