mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 07:21:37 +08:00
Compare commits
1 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8e8c31eadf |
@@ -2804,20 +2804,12 @@ class GatewayRunner:
|
||||
skip_db=agent_persisted,
|
||||
)
|
||||
|
||||
# Update session with actual prompt token count and model from the agent
|
||||
# Token counts and model are now persisted by the agent directly.
|
||||
# Keep only last_prompt_tokens here for context-window tracking and
|
||||
# compression decisions.
|
||||
self.session_store.update_session(
|
||||
session_entry.session_key,
|
||||
input_tokens=agent_result.get("input_tokens", 0),
|
||||
output_tokens=agent_result.get("output_tokens", 0),
|
||||
cache_read_tokens=agent_result.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=agent_result.get("cache_write_tokens", 0),
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
model=agent_result.get("model"),
|
||||
estimated_cost_usd=agent_result.get("estimated_cost_usd"),
|
||||
cost_status=agent_result.get("cost_status"),
|
||||
cost_source=agent_result.get("cost_source"),
|
||||
provider=agent_result.get("provider"),
|
||||
base_url=agent_result.get("base_url"),
|
||||
)
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
|
||||
@@ -778,66 +778,18 @@ class SessionStore:
|
||||
def update_session(
|
||||
self,
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
last_prompt_tokens: int = None,
|
||||
model: str = None,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
db_session_id = None
|
||||
|
||||
"""Update lightweight session metadata after an interaction."""
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = _now()
|
||||
# Direct assignment — the gateway receives cumulative totals
|
||||
# from the cached agent, not per-call deltas.
|
||||
entry.input_tokens = input_tokens
|
||||
entry.output_tokens = output_tokens
|
||||
entry.cache_read_tokens = cache_read_tokens
|
||||
entry.cache_write_tokens = cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
entry.input_tokens
|
||||
+ entry.output_tokens
|
||||
+ entry.cache_read_tokens
|
||||
+ entry.cache_write_tokens
|
||||
)
|
||||
self._save()
|
||||
db_session_id = entry.session_id
|
||||
|
||||
if self._db and db_session_id:
|
||||
try:
|
||||
self._db.set_token_counts(
|
||||
db_session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
estimated_cost_usd=estimated_cost_usd,
|
||||
cost_status=cost_status,
|
||||
cost_source=cost_source,
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
absolute=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
|
||||
12
run_agent.py
12
run_agent.py
@@ -7216,11 +7216,13 @@ class AIAgent:
|
||||
self.session_cost_source = cost_result.source
|
||||
|
||||
# Persist token counts to session DB for /insights.
|
||||
# Gateway sessions persist via session_store.update_session()
|
||||
# after run_conversation returns, so only persist here for
|
||||
# CLI (and other non-gateway) platforms to avoid double-counting.
|
||||
if (self._session_db and self.session_id
|
||||
and getattr(self, 'platform', None) == 'cli'):
|
||||
# Do this for every platform with a session_id so non-CLI
|
||||
# sessions (gateway, cron, delegated runs) cannot lose
|
||||
# token/accounting data if a higher-level persistence path
|
||||
# is skipped or fails. Gateway/session-store writes use
|
||||
# absolute totals, so they safely overwrite these per-call
|
||||
# deltas instead of double-counting them.
|
||||
if self._session_db and self.session_id:
|
||||
try:
|
||||
self._session_db.update_token_counts(
|
||||
self.session_id,
|
||||
|
||||
@@ -825,43 +825,6 @@ class TestLastPromptTokens:
|
||||
store.update_session("k1", last_prompt_tokens=0)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_update_session_passes_model_to_db(self, tmp_path):
|
||||
"""Gateway session updates should forward the resolved model to SQLite."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._save = MagicMock()
|
||||
store._db = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.set_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
absolute=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRewriteTranscriptPreservesReasoning:
|
||||
"""rewrite_transcript must not drop reasoning fields from SQLite."""
|
||||
|
||||
|
||||
@@ -126,15 +126,5 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
assert result == "ok"
|
||||
runner.session_store.update_session.assert_called_once_with(
|
||||
session_entry.session_key,
|
||||
input_tokens=120,
|
||||
output_tokens=45,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
last_prompt_tokens=80,
|
||||
model="openai/test-model",
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
provider=None,
|
||||
base_url=None,
|
||||
)
|
||||
|
||||
62
tests/test_token_persistence_non_cli.py
Normal file
62
tests/test_token_persistence_non_cli.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _mock_response(*, usage: dict, content: str = "done"):
|
||||
msg = SimpleNamespace(content=content, tool_calls=None)
|
||||
choice = SimpleNamespace(message=msg, finish_reason="stop")
|
||||
return SimpleNamespace(
|
||||
choices=[choice],
|
||||
model="test/model",
|
||||
usage=SimpleNamespace(**usage),
|
||||
)
|
||||
|
||||
|
||||
def _make_agent(session_db, *, platform: str):
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
session_db=session_db,
|
||||
session_id=f"{platform}-session",
|
||||
platform=platform,
|
||||
)
|
||||
agent.client = MagicMock()
|
||||
agent.client.chat.completions.create.return_value = _mock_response(
|
||||
usage={
|
||||
"prompt_tokens": 11,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 18,
|
||||
}
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def test_run_conversation_persists_tokens_for_telegram_sessions():
|
||||
session_db = MagicMock()
|
||||
agent = _make_agent(session_db, platform="telegram")
|
||||
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
session_db.update_token_counts.assert_called_once()
|
||||
assert session_db.update_token_counts.call_args.args[0] == "telegram-session"
|
||||
|
||||
|
||||
def test_run_conversation_persists_tokens_for_cron_sessions():
|
||||
session_db = MagicMock()
|
||||
agent = _make_agent(session_db, platform="cron")
|
||||
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
session_db.update_token_counts.assert_called_once()
|
||||
assert session_db.update_token_counts.call_args.args[0] == "cron-session"
|
||||
Reference in New Issue
Block a user