mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
fix: integration hardening for gateway token tracking
Follow-up to 58dbd81 — ensures smooth transition for existing users:
- Backward compat: old session files without last_prompt_tokens
default to 0 via data.get('last_prompt_tokens', 0)
- /compress, /undo, /retry: reset last_prompt_tokens to 0 after
rewriting transcripts (stale token counts would under-report)
- Auto-compression hygiene: reset last_prompt_tokens after rewriting
- update_session: use None sentinel (not 0) as default so callers
can explicitly reset to 0 while normal calls don't clobber
- 6 new tests covering: default value, serialization roundtrip,
old-format migration, set/reset/no-change semantics
- /reset: new SessionEntry naturally gets last_prompt_tokens=0
2942 tests pass.
This commit is contained in:
@@ -1083,6 +1083,8 @@ class GatewayRunner:
|
||||
self.session_store.rewrite_transcript(
|
||||
session_entry.session_id, _compressed
|
||||
)
|
||||
# Reset stored token count — transcript was rewritten
|
||||
session_entry.last_prompt_tokens = 0
|
||||
history = _compressed
|
||||
_new_count = len(_compressed)
|
||||
_new_tokens = estimate_messages_tokens_rough(
|
||||
@@ -1747,6 +1749,8 @@ class GatewayRunner:
|
||||
# Truncate history to before the last user message and persist
|
||||
truncated = history[:last_user_idx]
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, truncated)
|
||||
# Reset stored token count — transcript was truncated
|
||||
session_entry.last_prompt_tokens = 0
|
||||
|
||||
# Re-send by creating a fake text event with the old message
|
||||
retry_event = MessageEvent(
|
||||
@@ -1778,6 +1782,8 @@ class GatewayRunner:
|
||||
removed_msg = history[last_user_idx].get("content", "")
|
||||
removed_count = len(history) - last_user_idx
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
||||
# Reset stored token count — transcript was truncated
|
||||
session_entry.last_prompt_tokens = 0
|
||||
|
||||
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
||||
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
||||
@@ -1911,6 +1917,10 @@ class GatewayRunner:
|
||||
)
|
||||
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, compressed)
|
||||
# Reset stored token count — transcript changed, old value is stale
|
||||
self.session_store.update_session(
|
||||
session_entry.session_key, last_prompt_tokens=0,
|
||||
)
|
||||
new_count = len(compressed)
|
||||
new_tokens = estimate_messages_tokens_rough(compressed)
|
||||
|
||||
|
||||
@@ -556,7 +556,7 @@ class SessionStore:
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
last_prompt_tokens: int = 0,
|
||||
last_prompt_tokens: int = None,
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
@@ -566,7 +566,7 @@ class SessionStore:
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
if last_prompt_tokens > 0:
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
@@ -429,3 +429,119 @@ class TestHasAnySessions:
|
||||
|
||||
store._entries = {"key1": MagicMock()}
|
||||
assert store.has_any_sessions() is False
|
||||
|
||||
|
||||
class TestLastPromptTokens:
|
||||
"""Tests for the last_prompt_tokens field — actual API token tracking."""
|
||||
|
||||
def test_session_entry_default(self):
|
||||
"""New sessions should have last_prompt_tokens=0."""
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_session_entry_roundtrip(self):
|
||||
"""last_prompt_tokens should survive serialization/deserialization."""
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=42000,
|
||||
)
|
||||
d = entry.to_dict()
|
||||
assert d["last_prompt_tokens"] == 42000
|
||||
restored = SessionEntry.from_dict(d)
|
||||
assert restored.last_prompt_tokens == 42000
|
||||
|
||||
def test_session_entry_from_old_data(self):
|
||||
"""Old session data without last_prompt_tokens should default to 0."""
|
||||
from gateway.session import SessionEntry
|
||||
data = {
|
||||
"session_key": "test",
|
||||
"session_id": "s1",
|
||||
"created_at": "2025-01-01T00:00:00",
|
||||
"updated_at": "2025-01-01T00:00:00",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
# No last_prompt_tokens — old format
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
def test_update_session_sets_last_prompt_tokens(self, tmp_path):
|
||||
"""update_session should store the actual prompt token count."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", last_prompt_tokens=85000)
|
||||
assert entry.last_prompt_tokens == 85000
|
||||
|
||||
def test_update_session_none_does_not_change(self, tmp_path):
|
||||
"""update_session with default (None) should not change last_prompt_tokens."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=50000,
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1") # No last_prompt_tokens arg
|
||||
assert entry.last_prompt_tokens == 50000 # unchanged
|
||||
|
||||
def test_update_session_zero_resets(self, tmp_path):
|
||||
"""update_session with last_prompt_tokens=0 should reset the field."""
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._loaded = True
|
||||
store._db = None
|
||||
store._save = MagicMock()
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
from datetime import datetime
|
||||
entry = SessionEntry(
|
||||
session_key="k1",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
last_prompt_tokens=85000,
|
||||
)
|
||||
store._entries = {"k1": entry}
|
||||
|
||||
store.update_session("k1", last_prompt_tokens=0)
|
||||
assert entry.last_prompt_tokens == 0
|
||||
|
||||
Reference in New Issue
Block a user