mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
Two bugs fixed:
1. search_messages() crashed with OperationalError when a user query
contained FTS5 special characters (+, ", (, {, dangling AND/OR, etc.).
Added _sanitize_fts5_query() to strip dangerous operators, plus a
try/except fallback for edge cases.
2. _append_to_sqlite() in mirror.py created a new SessionDB per call
but never closed it, leaking SQLite connections. Added a finally block
to ensure db.close() is always called.
425 lines
16 KiB
Python
425 lines
16 KiB
Python
"""Tests for hermes_state.py — SessionDB SQLite CRUD, FTS5 search, export."""
|
|
|
|
import time
|
|
import pytest
|
|
from pathlib import Path
|
|
|
|
from hermes_state import SessionDB
|
|
|
|
|
|
@pytest.fixture()
def db(tmp_path):
    """Yield a SessionDB stored in a file under pytest's ``tmp_path``.

    The database is closed after the requesting test has finished.
    """
    store = SessionDB(db_path=tmp_path / "test_state.db")
    yield store
    store.close()
|
|
|
|
|
|
# =========================================================================
|
|
# Session lifecycle
|
|
# =========================================================================
|
|
|
|
class TestSessionLifecycle:
    """Session rows: creation, lookup, ending, and field updates."""

    def test_create_and_get_session(self, db):
        """create_session returns the id and get_session yields the stored row."""
        created_id = db.create_session(
            session_id="s1", source="cli", model="test-model"
        )
        assert created_id == "s1"

        row = db.get_session("s1")
        assert row is not None
        assert row["source"] == "cli"
        assert row["model"] == "test-model"
        # A freshly created session has not been ended yet.
        assert row["ended_at"] is None

    def test_get_nonexistent_session(self, db):
        """Looking up an unknown id gives None."""
        missing = db.get_session("nonexistent")
        assert missing is None

    def test_end_session(self, db):
        """Ending a session records a float end time and the given reason."""
        db.create_session(session_id="s1", source="cli")
        db.end_session("s1", end_reason="user_exit")

        row = db.get_session("s1")
        assert isinstance(row["ended_at"], float)
        assert row["end_reason"] == "user_exit"

    def test_update_system_prompt(self, db):
        """The system prompt written for a session is read back unchanged."""
        db.create_session(session_id="s1", source="cli")
        db.update_system_prompt("s1", "You are a helpful assistant.")

        row = db.get_session("s1")
        assert row["system_prompt"] == "You are a helpful assistant."

    def test_update_token_counts(self, db):
        """Two token-count updates are expected to add up, not overwrite."""
        db.create_session(session_id="s1", source="cli")
        db.update_token_counts("s1", input_tokens=100, output_tokens=50)
        db.update_token_counts("s1", input_tokens=200, output_tokens=100)

        row = db.get_session("s1")
        assert row["input_tokens"] == 300
        assert row["output_tokens"] == 150

    def test_parent_session(self, db):
        """A child session keeps the id of the parent it was created with."""
        db.create_session(session_id="parent", source="cli")
        db.create_session(
            session_id="child", source="cli", parent_session_id="parent"
        )

        child_row = db.get_session("child")
        assert child_row["parent_session_id"] == "parent"
|
|
|
|
|
|
# =========================================================================
|
|
# Message storage
|
|
# =========================================================================
|
|
|
|
class TestMessageStorage:
    """Appending messages to a session and reading them back."""

    def test_append_and_get_messages(self, db):
        """Messages come back in insertion order with role and content."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Hello")
        db.append_message("s1", role="assistant", content="Hi there!")

        stored = db.get_messages("s1")
        assert len(stored) == 2
        first, second = stored[0], stored[1]
        assert first["role"] == "user"
        assert first["content"] == "Hello"
        assert second["role"] == "assistant"

    def test_message_increments_session_count(self, db):
        """Each appended message bumps the session's message_count."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Hello")
        db.append_message("s1", role="assistant", content="Hi")

        row = db.get_session("s1")
        assert row["message_count"] == 2

    def test_tool_message_increments_tool_count(self, db):
        """A tool-role message is counted in the session's tool_call_count."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            "s1", role="tool", content="result", tool_name="web_search"
        )

        row = db.get_session("s1")
        assert row["tool_call_count"] == 1

    def test_tool_calls_serialization(self, db):
        """tool_calls given to append_message are returned equal by get_messages."""
        db.create_session(session_id="s1", source="cli")
        calls = [
            {"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}
        ]
        db.append_message("s1", role="assistant", tool_calls=calls)

        stored = db.get_messages("s1")
        assert stored[0]["tool_calls"] == calls

    def test_get_messages_as_conversation(self, db):
        """The conversation view holds plain role/content dicts, in order."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Hello")
        db.append_message("s1", role="assistant", content="Hi!")

        conversation = db.get_messages_as_conversation("s1")
        assert len(conversation) == 2
        assert conversation[0] == {"role": "user", "content": "Hello"}
        assert conversation[1] == {"role": "assistant", "content": "Hi!"}

    def test_finish_reason_stored(self, db):
        """The finish_reason supplied on append is persisted with the message."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            "s1", role="assistant", content="Done", finish_reason="stop"
        )

        stored = db.get_messages("s1")
        assert stored[0]["finish_reason"] == "stop"
|
|
|
|
|
|
# =========================================================================
|
|
# FTS5 search
|
|
# =========================================================================
|
|
|
|
class TestFTS5Search:
    """Full-text search over stored messages via ``SessionDB.search_messages``."""

    def test_search_finds_content(self, db):
        """Both messages mentioning the term are returned with a matching snippet."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="How do I deploy with Docker?")
        db.append_message("s1", role="assistant", content="Use docker compose up.")

        results = db.search_messages("docker")
        assert len(results) == 2
        # At least one result should mention docker (compared case-insensitively).
        snippets = [r.get("snippet", "") for r in results]
        assert any("docker" in s.lower() for s in snippets)

    def test_search_empty_query(self, db):
        """Empty and whitespace-only queries return an empty list."""
        assert db.search_messages("") == []
        assert db.search_messages(" ") == []

    def test_search_with_source_filter(self, db):
        """source_filter restricts hits to sessions from the listed sources."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="CLI question about Python")

        db.create_session(session_id="s2", source="telegram")
        db.append_message("s2", role="user", content="Telegram question about Python")

        results = db.search_messages("Python", source_filter=["telegram"])
        # all() is True for an empty sequence, so make sure something matched
        # before checking that every hit comes from telegram.
        assert results, "source filter returned no results"
        sources = [r["source"] for r in results]
        assert all(s == "telegram" for s in sources)

    def test_search_with_role_filter(self, db):
        """role_filter restricts hits to messages with the listed roles."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="What is FastAPI?")
        db.append_message("s1", role="assistant", content="FastAPI is a web framework.")

        results = db.search_messages("FastAPI", role_filter=["assistant"])
        # Same vacuous-truth guard as for the source filter.
        assert results, "role filter returned no results"
        roles = [r["role"] for r in results]
        assert all(r == "assistant" for r in roles)

    def test_search_returns_context(self, db):
        """Each hit carries a non-empty ``context`` list."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Tell me about Kubernetes")
        db.append_message("s1", role="assistant", content="Kubernetes is an orchestrator.")

        results = db.search_messages("Kubernetes")
        assert len(results) == 2
        assert "context" in results[0]
        assert isinstance(results[0]["context"], list)
        assert len(results[0]["context"]) > 0

    def test_search_special_chars_do_not_crash(self, db):
        """FTS5 special characters in queries must not raise OperationalError."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="How do I use C++ templates?")

        # Each of these previously caused sqlite3.OperationalError
        dangerous_queries = [
            'C++',            # + is the FTS5 phrase-concatenation operator
            '"unterminated',  # unbalanced double-quote
            '(problem',       # unbalanced parenthesis
            'hello AND',      # dangling boolean operator
            '***',            # repeated bare prefix marker
            '{test}',         # curly braces (column-filter list syntax)
            'OR hello',       # leading boolean operator
            'a AND OR b',     # adjacent operators
        ]
        for query in dangerous_queries:
            # Must not raise — should return list (possibly empty)
            results = db.search_messages(query)
            assert isinstance(results, list), f"Query {query!r} did not return a list"

    def test_search_sanitized_query_still_finds_content(self, db):
        """Sanitization must not break normal keyword search."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Learning C++ templates today")

        # "C++" is presumably sanitized to "C", which should still match "C++".
        results = db.search_messages("C++")
        # NOTE(review): only the return type is checked here; whether a hit is
        # actually produced depends on the sanitizer and tokenizer — confirm
        # before tightening this to a result-count assertion.
        assert isinstance(results, list)

    def test_sanitize_fts5_query_strips_dangerous_chars(self):
        """Unit test for _sanitize_fts5_query static method."""
        # SessionDB is already imported at module level; no local import needed.
        s = SessionDB._sanitize_fts5_query
        assert s('hello world') == 'hello world'
        assert '+' not in s('C++')
        assert '"' not in s('"unterminated')
        assert '(' not in s('(problem')
        assert '{' not in s('{test}')
        # Dangling operators removed
        assert s('hello AND') == 'hello'
        assert s('OR world') == 'world'
        # Leading bare * removed
        assert s('***') == ''
        # Valid prefix kept
        assert s('deploy*') == 'deploy*'
|
|
|
|
|
|
# =========================================================================
|
|
# Session search and listing
|
|
# =========================================================================
|
|
|
|
class TestSearchSessions:
    """Listing sessions through ``search_sessions``: all, filtered, paged."""

    def test_list_all_sessions(self, db):
        """With no arguments every created session is listed."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")

        listed = db.search_sessions()
        assert len(listed) == 2

    def test_filter_by_source(self, db):
        """Passing ``source`` keeps only sessions from that source."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")

        cli_only = db.search_sessions(source="cli")
        assert len(cli_only) == 1
        assert cli_only[0]["source"] == "cli"

    def test_pagination(self, db):
        """limit/offset yield full pages that start at different sessions."""
        for i in range(5):
            db.create_session(session_id=f"s{i}", source="cli")

        first_page = db.search_sessions(limit=2)
        second_page = db.search_sessions(limit=2, offset=2)
        assert len(first_page) == 2
        assert len(second_page) == 2
        assert first_page[0]["id"] != second_page[0]["id"]
|
|
|
|
|
|
# =========================================================================
|
|
# Counts
|
|
# =========================================================================
|
|
|
|
class TestCounts:
    """Totals reported by ``session_count`` and ``message_count``."""

    def test_session_count(self, db):
        """The count starts at zero and reflects every created session."""
        assert db.session_count() == 0

        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")
        assert db.session_count() == 2

    def test_session_count_by_source(self, db):
        """Counting with ``source`` only includes sessions from that source."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")
        db.create_session(session_id="s3", source="cli")

        assert db.session_count(source="cli") == 2
        assert db.session_count(source="telegram") == 1

    def test_message_count_total(self, db):
        """Without arguments message_count covers all stored messages."""
        assert db.message_count() == 0

        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Hello")
        db.append_message("s1", role="assistant", content="Hi")
        assert db.message_count() == 2

    def test_message_count_per_session(self, db):
        """Counting with ``session_id`` only includes that session's messages."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")

        db.append_message("s1", role="user", content="A")
        db.append_message("s2", role="user", content="B")
        db.append_message("s2", role="user", content="C")

        assert db.message_count(session_id="s1") == 1
        assert db.message_count(session_id="s2") == 2
|
|
|
|
|
|
# =========================================================================
|
|
# Delete and export
|
|
# =========================================================================
|
|
|
|
class TestDeleteAndExport:
    """Removing sessions and exporting them, singly or in bulk."""

    def test_delete_session(self, db):
        """Deleting returns True and removes the session and its messages."""
        db.create_session(session_id="s1", source="cli")
        db.append_message("s1", role="user", content="Hello")

        deleted = db.delete_session("s1")
        assert deleted is True
        assert db.get_session("s1") is None
        assert db.message_count(session_id="s1") == 0

    def test_delete_nonexistent(self, db):
        """Deleting an unknown id returns False."""
        deleted = db.delete_session("nope")
        assert deleted is False

    def test_export_session(self, db):
        """An export is a dict with the session fields and its messages."""
        db.create_session(session_id="s1", source="cli", model="test")
        db.append_message("s1", role="user", content="Hello")
        db.append_message("s1", role="assistant", content="Hi")

        exported = db.export_session("s1")
        assert isinstance(exported, dict)
        assert exported["source"] == "cli"
        assert len(exported["messages"]) == 2

    def test_export_nonexistent(self, db):
        """Exporting an unknown id gives None."""
        exported = db.export_session("nope")
        assert exported is None

    def test_export_all(self, db):
        """export_all includes every session, even one without messages."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")
        db.append_message("s1", role="user", content="A")

        everything = db.export_all()
        assert len(everything) == 2

    def test_export_all_with_source(self, db):
        """export_all with ``source`` only exports sessions from that source."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")

        cli_exports = db.export_all(source="cli")
        assert len(cli_exports) == 1
        assert cli_exports[0]["source"] == "cli"
|
|
|
|
|
|
# =========================================================================
|
|
# Prune
|
|
# =========================================================================
|
|
|
|
class TestPruneSessions:
    """Behaviour of ``prune_sessions`` for old, active and filtered sessions."""

    @staticmethod
    def _backdate(db, session_id, days):
        """Rewrite a session's started_at to ``days`` days ago (no commit)."""
        db._conn.execute(
            "UPDATE sessions SET started_at = ? WHERE id = ?",
            (time.time() - days * 86400, session_id),
        )

    def test_prune_old_ended_sessions(self, db):
        """An ended session older than the cutoff is pruned; a new one stays."""
        # Create and end an "old" session, then backdate it by hand.
        db.create_session(session_id="old", source="cli")
        db.end_session("old", end_reason="done")
        self._backdate(db, "old", 100)
        db._conn.commit()

        # A session created just now must survive the prune.
        db.create_session(session_id="new", source="cli")

        removed = db.prune_sessions(older_than_days=90)
        assert removed == 1
        assert db.get_session("old") is None

        survivor = db.get_session("new")
        assert survivor is not None
        assert survivor["id"] == "new"

    def test_prune_skips_active_sessions(self, db):
        """A backdated session that was never ended is not pruned."""
        db.create_session(session_id="active", source="cli")
        # Backdate but don't end
        self._backdate(db, "active", 200)
        db._conn.commit()

        removed = db.prune_sessions(older_than_days=90)
        assert removed == 0
        assert db.get_session("active") is not None

    def test_prune_with_source_filter(self, db):
        """With ``source`` given, only old sessions of that source are pruned."""
        old_sessions = [("old_cli", "cli"), ("old_tg", "telegram")]
        for sid, src in old_sessions:
            db.create_session(session_id=sid, source=src)
            db.end_session(sid, end_reason="done")
            self._backdate(db, sid, 200)
        db._conn.commit()

        removed = db.prune_sessions(older_than_days=90, source="cli")
        assert removed == 1
        assert db.get_session("old_cli") is None
        assert db.get_session("old_tg") is not None
|
|
|
|
|
|
# =========================================================================
|
|
# Schema and WAL mode
|
|
# =========================================================================
|
|
|
|
class TestSchemaInit:
    """Checks on the connection pragmas and schema of a fresh SessionDB."""

    def test_wal_mode(self, db):
        """The connection reports ``wal`` as its journal mode."""
        journal_mode = db._conn.execute("PRAGMA journal_mode").fetchone()[0]
        assert journal_mode == "wal"

    def test_foreign_keys_enabled(self, db):
        """The foreign_keys pragma reads back as 1."""
        foreign_keys = db._conn.execute("PRAGMA foreign_keys").fetchone()[0]
        assert foreign_keys == 1

    def test_tables_exist(self, db):
        """sessions, messages and schema_version tables are all present."""
        rows = db._conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
        table_names = {row[0] for row in rows}
        for required in ("sessions", "messages", "schema_version"):
            assert required in table_names

    def test_schema_version(self, db):
        """The first schema_version row holds version 2."""
        row = db._conn.execute("SELECT version FROM schema_version").fetchone()
        assert row[0] == 2
|