mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix: harden session title system + add /title to gateway
- Empty string titles normalized to None (prevents uncaught IntegrityError when two sessions both get empty-string titles via the unique index) - Escape SQL LIKE wildcards (%, _) in resolve_session_by_title and get_next_title_in_lineage to prevent false matches on titles like 'test_project' matching 'testXproject #2' - Optimize list_sessions_rich from N+2 queries to a single query with correlated subqueries (preview + last_active computed in SQL) - Add /title slash command to gateway (Telegram, Discord, Slack, WhatsApp) with set and show modes, uniqueness conflict handling - Add /title to gateway /help text and _known_commands - 12 new tests: empty string normalization, multi-empty-title safety, SQL wildcard edge cases, gateway /title set/show/conflict/cross-platform
This commit is contained in:
@@ -710,7 +710,8 @@ class GatewayRunner:
|
|||||||
# Emit command:* hook for any recognized slash command
|
# Emit command:* hook for any recognized slash command
|
||||||
_known_commands = {"new", "reset", "help", "status", "stop", "model",
|
_known_commands = {"new", "reset", "help", "status", "stop", "model",
|
||||||
"personality", "retry", "undo", "sethome", "set-home",
|
"personality", "retry", "undo", "sethome", "set-home",
|
||||||
"compress", "usage", "insights", "reload-mcp", "update"}
|
"compress", "usage", "insights", "reload-mcp", "update",
|
||||||
|
"title"}
|
||||||
if command and command in _known_commands:
|
if command and command in _known_commands:
|
||||||
await self.hooks.emit(f"command:{command}", {
|
await self.hooks.emit(f"command:{command}", {
|
||||||
"platform": source.platform.value if source.platform else "",
|
"platform": source.platform.value if source.platform else "",
|
||||||
@@ -763,6 +764,9 @@ class GatewayRunner:
|
|||||||
|
|
||||||
if command == "update":
|
if command == "update":
|
||||||
return await self._handle_update_command(event)
|
return await self._handle_update_command(event)
|
||||||
|
|
||||||
|
if command == "title":
|
||||||
|
return await self._handle_title_command(event)
|
||||||
|
|
||||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||||
if command:
|
if command:
|
||||||
@@ -1301,6 +1305,7 @@ class GatewayRunner:
|
|||||||
"`/undo` — Remove the last exchange",
|
"`/undo` — Remove the last exchange",
|
||||||
"`/sethome` — Set this chat as the home channel",
|
"`/sethome` — Set this chat as the home channel",
|
||||||
"`/compress` — Compress conversation context",
|
"`/compress` — Compress conversation context",
|
||||||
|
"`/title [name]` — Set or show the session title",
|
||||||
"`/usage` — Show token usage for this session",
|
"`/usage` — Show token usage for this session",
|
||||||
"`/insights [days]` — Show usage insights and analytics",
|
"`/insights [days]` — Show usage insights and analytics",
|
||||||
"`/reload-mcp` — Reload MCP servers from config",
|
"`/reload-mcp` — Reload MCP servers from config",
|
||||||
@@ -1691,6 +1696,33 @@ class GatewayRunner:
|
|||||||
logger.warning("Manual compress failed: %s", e)
|
logger.warning("Manual compress failed: %s", e)
|
||||||
return f"Compression failed: {e}"
|
return f"Compression failed: {e}"
|
||||||
|
|
||||||
|
async def _handle_title_command(self, event: MessageEvent) -> str:
|
||||||
|
"""Handle /title command — set or show the current session's title."""
|
||||||
|
source = event.source
|
||||||
|
session_entry = self.session_store.get_or_create_session(source)
|
||||||
|
session_id = session_entry.session_id
|
||||||
|
|
||||||
|
if not self._session_db:
|
||||||
|
return "Session database not available."
|
||||||
|
|
||||||
|
title_arg = event.get_command_args().strip()
|
||||||
|
if title_arg:
|
||||||
|
# Set the title
|
||||||
|
try:
|
||||||
|
if self._session_db.set_session_title(session_id, title_arg):
|
||||||
|
return f"✏️ Session title set: **{title_arg}**"
|
||||||
|
else:
|
||||||
|
return "Session not found in database."
|
||||||
|
except ValueError as e:
|
||||||
|
return f"⚠️ {e}"
|
||||||
|
else:
|
||||||
|
# Show the current title
|
||||||
|
title = self._session_db.get_session_title(session_id)
|
||||||
|
if title:
|
||||||
|
return f"📌 Session title: **{title}**"
|
||||||
|
else:
|
||||||
|
return "No title set. Usage: `/title My Session Name`"
|
||||||
|
|
||||||
async def _handle_usage_command(self, event: MessageEvent) -> str:
|
async def _handle_usage_command(self, event: MessageEvent) -> str:
|
||||||
"""Handle /usage command -- show token usage for the session's last agent run."""
|
"""Handle /usage command -- show token usage for the session's last agent run."""
|
||||||
source = event.source
|
source = event.source
|
||||||
|
|||||||
@@ -251,7 +251,12 @@ class SessionDB:
|
|||||||
|
|
||||||
Returns True if session was found and title was set.
|
Returns True if session was found and title was set.
|
||||||
Raises ValueError if title is already in use by another session.
|
Raises ValueError if title is already in use by another session.
|
||||||
|
Empty strings are normalized to None (clearing the title).
|
||||||
"""
|
"""
|
||||||
|
# Normalize empty string to None so it doesn't conflict with the
|
||||||
|
# unique index (only non-NULL values are constrained)
|
||||||
|
if not title:
|
||||||
|
title = None
|
||||||
if title:
|
if title:
|
||||||
# Check uniqueness (allow the same session to keep its own title)
|
# Check uniqueness (allow the same session to keep its own title)
|
||||||
cursor = self._conn.execute(
|
cursor = self._conn.execute(
|
||||||
@@ -298,10 +303,12 @@ class SessionDB:
|
|||||||
exact = self.get_session_by_title(title)
|
exact = self.get_session_by_title(title)
|
||||||
|
|
||||||
# Also search for numbered variants: "title #2", "title #3", etc.
|
# Also search for numbered variants: "title #2", "title #3", etc.
|
||||||
|
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||||
|
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
cursor = self._conn.execute(
|
cursor = self._conn.execute(
|
||||||
"SELECT id, title, started_at FROM sessions "
|
"SELECT id, title, started_at FROM sessions "
|
||||||
"WHERE title LIKE ? ORDER BY started_at DESC",
|
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||||
(f"{title} #%",),
|
(f"{escaped} #%",),
|
||||||
)
|
)
|
||||||
numbered = cursor.fetchall()
|
numbered = cursor.fetchall()
|
||||||
|
|
||||||
@@ -327,9 +334,11 @@ class SessionDB:
|
|||||||
base = base_title
|
base = base_title
|
||||||
|
|
||||||
# Find all existing numbered variants
|
# Find all existing numbered variants
|
||||||
|
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
|
||||||
|
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
cursor = self._conn.execute(
|
cursor = self._conn.execute(
|
||||||
"SELECT title FROM sessions WHERE title = ? OR title LIKE ?",
|
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
||||||
(base, f"{base} #%"),
|
(base, f"{escaped} #%"),
|
||||||
)
|
)
|
||||||
existing = [row["title"] for row in cursor.fetchall()]
|
existing = [row["title"] for row in cursor.fetchall()]
|
||||||
|
|
||||||
@@ -356,40 +365,41 @@ class SessionDB:
|
|||||||
Returns dicts with keys: id, source, model, title, started_at, ended_at,
|
Returns dicts with keys: id, source, model, title, started_at, ended_at,
|
||||||
message_count, preview (first 60 chars of first user message),
|
message_count, preview (first 60 chars of first user message),
|
||||||
last_active (timestamp of last message).
|
last_active (timestamp of last message).
|
||||||
"""
|
|
||||||
if source:
|
|
||||||
cursor = self._conn.execute(
|
|
||||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
|
||||||
(source, limit, offset),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cursor = self._conn.execute(
|
|
||||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
|
||||||
(limit, offset),
|
|
||||||
)
|
|
||||||
sessions = [dict(row) for row in cursor.fetchall()]
|
|
||||||
|
|
||||||
for s in sessions:
|
Uses a single query with correlated subqueries instead of N+2 queries.
|
||||||
# Get first user message preview
|
"""
|
||||||
preview_cursor = self._conn.execute(
|
source_clause = "WHERE s.source = ?" if source else ""
|
||||||
"SELECT content FROM messages WHERE session_id = ? AND role = 'user' "
|
query = f"""
|
||||||
"ORDER BY timestamp, id LIMIT 1",
|
SELECT s.*,
|
||||||
(s["id"],),
|
COALESCE(
|
||||||
)
|
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
|
||||||
preview_row = preview_cursor.fetchone()
|
FROM messages m
|
||||||
if preview_row and preview_row["content"]:
|
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
|
||||||
text = preview_row["content"].replace("\n", " ").strip()
|
ORDER BY m.timestamp, m.id LIMIT 1),
|
||||||
s["preview"] = text[:60] + ("..." if len(text) > 60 else "")
|
''
|
||||||
|
) AS _preview_raw,
|
||||||
|
COALESCE(
|
||||||
|
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
|
||||||
|
s.started_at
|
||||||
|
) AS last_active
|
||||||
|
FROM sessions s
|
||||||
|
{source_clause}
|
||||||
|
ORDER BY s.started_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
params = (source, limit, offset) if source else (limit, offset)
|
||||||
|
cursor = self._conn.execute(query, params)
|
||||||
|
sessions = []
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
s = dict(row)
|
||||||
|
# Build the preview from the raw substring
|
||||||
|
raw = s.pop("_preview_raw", "").strip()
|
||||||
|
if raw:
|
||||||
|
text = raw[:60]
|
||||||
|
s["preview"] = text + ("..." if len(raw) > 60 else "")
|
||||||
else:
|
else:
|
||||||
s["preview"] = ""
|
s["preview"] = ""
|
||||||
|
sessions.append(s)
|
||||||
# Get last message timestamp
|
|
||||||
last_cursor = self._conn.execute(
|
|
||||||
"SELECT MAX(timestamp) as last_ts FROM messages WHERE session_id = ?",
|
|
||||||
(s["id"],),
|
|
||||||
)
|
|
||||||
last_row = last_cursor.fetchone()
|
|
||||||
s["last_active"] = last_row["last_ts"] if last_row and last_row["last_ts"] else s["started_at"]
|
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
|
|||||||
165
tests/gateway/test_title_command.py
Normal file
165
tests/gateway/test_title_command.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Tests for /title gateway slash command.
|
||||||
|
|
||||||
|
Tests the _handle_title_command handler (set/show session titles)
|
||||||
|
across all gateway messenger platforms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gateway.config import Platform
|
||||||
|
from gateway.platforms.base import MessageEvent
|
||||||
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
def _make_event(text="/title", platform=Platform.TELEGRAM,
|
||||||
|
user_id="12345", chat_id="67890"):
|
||||||
|
"""Build a MessageEvent for testing."""
|
||||||
|
source = SessionSource(
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_name="testuser",
|
||||||
|
)
|
||||||
|
return MessageEvent(text=text, source=source)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runner(session_db=None):
|
||||||
|
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner.adapters = {}
|
||||||
|
runner._session_db = session_db
|
||||||
|
|
||||||
|
# Mock session_store that returns a session entry with a known session_id
|
||||||
|
mock_session_entry = MagicMock()
|
||||||
|
mock_session_entry.session_id = "test_session_123"
|
||||||
|
mock_session_entry.session_key = "telegram:12345:67890"
|
||||||
|
mock_store = MagicMock()
|
||||||
|
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||||
|
runner.session_store = mock_store
|
||||||
|
|
||||||
|
return runner
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _handle_title_command
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleTitleCommand:
|
||||||
|
"""Tests for GatewayRunner._handle_title_command."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_title(self, tmp_path):
|
||||||
|
"""Setting a title returns confirmation."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
event = _make_event(text="/title My Research Project")
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "My Research Project" in result
|
||||||
|
assert "✏️" in result
|
||||||
|
|
||||||
|
# Verify in DB
|
||||||
|
assert db.get_session_title("test_session_123") == "My Research Project"
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_show_title_when_set(self, tmp_path):
|
||||||
|
"""Showing title when one is set returns the title."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
db.set_session_title("test_session_123", "Existing Title")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
event = _make_event(text="/title")
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "Existing Title" in result
|
||||||
|
assert "📌" in result
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_show_title_when_not_set(self, tmp_path):
|
||||||
|
"""Showing title when none is set returns usage hint."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
event = _make_event(text="/title")
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "No title set" in result
|
||||||
|
assert "/title" in result
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_title_conflict(self, tmp_path):
|
||||||
|
"""Setting a title already used by another session returns error."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
db = SessionDB(db_path=tmp_path / "state.db")
|
||||||
|
db.create_session("other_session", "telegram")
|
||||||
|
db.set_session_title("other_session", "Taken Title")
|
||||||
|
db.create_session("test_session_123", "telegram")
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
event = _make_event(text="/title Taken Title")
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "already in use" in result
|
||||||
|
assert "⚠️" in result
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_session_db(self):
|
||||||
|
"""Returns error when session database is not available."""
|
||||||
|
runner = _make_runner(session_db=None)
|
||||||
|
event = _make_event(text="/title My Title")
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "not available" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_works_across_platforms(self, tmp_path):
|
||||||
|
"""The /title command works for Discord, Slack, and WhatsApp too."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
for platform in [Platform.DISCORD, Platform.TELEGRAM]:
|
||||||
|
db = SessionDB(db_path=tmp_path / f"state_{platform.value}.db")
|
||||||
|
db.create_session("test_session_123", platform.value)
|
||||||
|
|
||||||
|
runner = _make_runner(session_db=db)
|
||||||
|
event = _make_event(text="/title Cross-Platform Test", platform=platform)
|
||||||
|
result = await runner._handle_title_command(event)
|
||||||
|
assert "Cross-Platform Test" in result
|
||||||
|
assert db.get_session_title("test_session_123") == "Cross-Platform Test"
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /title in help and known_commands
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTitleInHelp:
|
||||||
|
"""Verify /title appears in help text and known commands."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_title_in_help_output(self):
|
||||||
|
"""The /help output includes /title."""
|
||||||
|
runner = _make_runner()
|
||||||
|
event = _make_event(text="/help")
|
||||||
|
# Need hooks for help command
|
||||||
|
from gateway.hooks import HookRegistry
|
||||||
|
runner.hooks = HookRegistry()
|
||||||
|
result = await runner._handle_help_command(event)
|
||||||
|
assert "/title" in result
|
||||||
|
|
||||||
|
def test_title_is_known_command(self):
|
||||||
|
"""The /title command is in the _known_commands set."""
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
import inspect
|
||||||
|
source = inspect.getsource(GatewayRunner._handle_message)
|
||||||
|
assert '"title"' in source
|
||||||
@@ -405,12 +405,25 @@ class TestSessionTitle:
|
|||||||
session = db.get_session("s1")
|
session = db.get_session("s1")
|
||||||
assert session["title"] == title
|
assert session["title"] == title
|
||||||
|
|
||||||
def test_title_empty_string(self, db):
|
def test_title_empty_string_normalized_to_none(self, db):
|
||||||
|
"""Empty strings are normalized to None (clearing the title)."""
|
||||||
db.create_session(session_id="s1", source="cli")
|
db.create_session(session_id="s1", source="cli")
|
||||||
|
db.set_session_title("s1", "My Title")
|
||||||
|
# Setting to empty string should clear the title (normalize to None)
|
||||||
db.set_session_title("s1", "")
|
db.set_session_title("s1", "")
|
||||||
|
|
||||||
session = db.get_session("s1")
|
session = db.get_session("s1")
|
||||||
assert session["title"] == ""
|
assert session["title"] is None
|
||||||
|
|
||||||
|
def test_multiple_empty_titles_no_conflict(self, db):
|
||||||
|
"""Multiple sessions can have empty-string (normalized to NULL) titles."""
|
||||||
|
db.create_session(session_id="s1", source="cli")
|
||||||
|
db.create_session(session_id="s2", source="cli")
|
||||||
|
db.set_session_title("s1", "")
|
||||||
|
db.set_session_title("s2", "")
|
||||||
|
# Both should be None, no uniqueness conflict
|
||||||
|
assert db.get_session("s1")["title"] is None
|
||||||
|
assert db.get_session("s2")["title"] is None
|
||||||
|
|
||||||
def test_title_survives_end_session(self, db):
|
def test_title_survives_end_session(self, db):
|
||||||
db.create_session(session_id="s1", source="cli")
|
db.create_session(session_id="s1", source="cli")
|
||||||
@@ -630,6 +643,37 @@ class TestTitleLineage:
|
|||||||
assert db.get_next_title_in_lineage("my project #2") == "my project #3"
|
assert db.get_next_title_in_lineage("my project #2") == "my project #3"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTitleSqlWildcards:
|
||||||
|
"""Titles containing SQL LIKE wildcards (%, _) must not cause false matches."""
|
||||||
|
|
||||||
|
def test_resolve_title_with_underscore(self, db):
|
||||||
|
"""A title like 'test_project' should not match 'testXproject #2'."""
|
||||||
|
db.create_session("s1", "cli")
|
||||||
|
db.set_session_title("s1", "test_project")
|
||||||
|
db.create_session("s2", "cli")
|
||||||
|
db.set_session_title("s2", "testXproject #2")
|
||||||
|
# Resolving "test_project" should return s1 (exact), not s2
|
||||||
|
assert db.resolve_session_by_title("test_project") == "s1"
|
||||||
|
|
||||||
|
def test_resolve_title_with_percent(self, db):
|
||||||
|
"""A title with '%' should not wildcard-match unrelated sessions."""
|
||||||
|
db.create_session("s1", "cli")
|
||||||
|
db.set_session_title("s1", "100% done")
|
||||||
|
db.create_session("s2", "cli")
|
||||||
|
db.set_session_title("s2", "100X done #2")
|
||||||
|
# Should resolve to s1 (exact), not s2
|
||||||
|
assert db.resolve_session_by_title("100% done") == "s1"
|
||||||
|
|
||||||
|
def test_next_lineage_with_underscore(self, db):
|
||||||
|
"""get_next_title_in_lineage with underscores doesn't match wrong sessions."""
|
||||||
|
db.create_session("s1", "cli")
|
||||||
|
db.set_session_title("s1", "test_project")
|
||||||
|
db.create_session("s2", "cli")
|
||||||
|
db.set_session_title("s2", "testXproject #2")
|
||||||
|
# Only "test_project" exists, so next should be "test_project #2"
|
||||||
|
assert db.get_next_title_in_lineage("test_project") == "test_project #2"
|
||||||
|
|
||||||
|
|
||||||
class TestListSessionsRich:
|
class TestListSessionsRich:
|
||||||
"""Tests for enhanced session listing with preview and last_active."""
|
"""Tests for enhanced session listing with preview and last_active."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user