diff --git a/gateway/run.py b/gateway/run.py index 8458bb9d46..72ec62b409 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1083,6 +1083,8 @@ class GatewayRunner: self.session_store.rewrite_transcript( session_entry.session_id, _compressed ) + # Reset stored token count — transcript was rewritten + session_entry.last_prompt_tokens = 0 history = _compressed _new_count = len(_compressed) _new_tokens = estimate_messages_tokens_rough( @@ -1747,6 +1749,8 @@ class GatewayRunner: # Truncate history to before the last user message and persist truncated = history[:last_user_idx] self.session_store.rewrite_transcript(session_entry.session_id, truncated) + # Reset stored token count — transcript was truncated + session_entry.last_prompt_tokens = 0 # Re-send by creating a fake text event with the old message retry_event = MessageEvent( @@ -1778,6 +1782,8 @@ class GatewayRunner: removed_msg = history[last_user_idx].get("content", "") removed_count = len(history) - last_user_idx self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) + # Reset stored token count — transcript was truncated + session_entry.last_prompt_tokens = 0 preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" @@ -1911,6 +1917,10 @@ class GatewayRunner: ) self.session_store.rewrite_transcript(session_entry.session_id, compressed) + # Reset stored token count — transcript changed, old value is stale + self.session_store.update_session( + session_entry.session_key, last_prompt_tokens=0, + ) new_count = len(compressed) new_tokens = estimate_messages_tokens_rough(compressed) diff --git a/gateway/session.py b/gateway/session.py index e2777fe1a2..b1cdefa5b5 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -556,7 +556,7 @@ class SessionStore: session_key: str, input_tokens: int = 0, output_tokens: int = 0, - last_prompt_tokens: int = 0, + last_prompt_tokens: int = None, ) -> None: """Update a session's metadata after an interaction.""" self._ensure_loaded() @@ -566,7 +566,7 @@ class SessionStore: entry.updated_at = datetime.now() entry.input_tokens += input_tokens entry.output_tokens += output_tokens - if last_prompt_tokens > 0: + if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens entry.total_tokens = entry.input_tokens + entry.output_tokens self._save() diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 562c580973..7a7f4b878c 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -429,3 +429,119 @@ class TestHasAnySessions: store._entries = {"key1": MagicMock()} assert store.has_any_sessions() is False + + +class TestLastPromptTokens: + """Tests for the last_prompt_tokens field — actual API token tracking.""" + + def test_session_entry_default(self): + """New sessions should have last_prompt_tokens=0.""" + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="test", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + ) + assert entry.last_prompt_tokens == 0 + + def test_session_entry_roundtrip(self): + """last_prompt_tokens should survive serialization/deserialization.""" + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="test", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + last_prompt_tokens=42000, + ) + d = entry.to_dict() + assert d["last_prompt_tokens"] == 42000 + restored = SessionEntry.from_dict(d) + assert restored.last_prompt_tokens == 42000 + + def test_session_entry_from_old_data(self): + """Old session data without last_prompt_tokens should default to 0.""" + from gateway.session import SessionEntry + data = { + "session_key": "test", + "session_id": "s1", + "created_at": "2025-01-01T00:00:00", + "updated_at": "2025-01-01T00:00:00", + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + # No last_prompt_tokens — old format + } + entry = SessionEntry.from_dict(data) + assert entry.last_prompt_tokens == 0 + + def test_update_session_sets_last_prompt_tokens(self, tmp_path): + """update_session should store the actual prompt token count.""" + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._loaded = True + store._db = None + store._save = MagicMock() + + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="k1", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + ) + store._entries = {"k1": entry} + + store.update_session("k1", last_prompt_tokens=85000) + assert entry.last_prompt_tokens == 85000 + + def test_update_session_none_does_not_change(self, tmp_path): + """update_session with default (None) should not change last_prompt_tokens.""" + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._loaded = True + store._db = None + store._save = MagicMock() + + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="k1", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + last_prompt_tokens=50000, + ) + store._entries = {"k1": entry} + + store.update_session("k1") # No last_prompt_tokens arg + assert entry.last_prompt_tokens == 50000 # unchanged + + def test_update_session_zero_resets(self, tmp_path): + """update_session with last_prompt_tokens=0 should reset the field.""" + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._loaded = True + store._db = None + store._save = MagicMock() + + from gateway.session import SessionEntry + from datetime import datetime + entry = SessionEntry( + session_key="k1", + session_id="s1", + created_at=datetime.now(), + updated_at=datetime.now(), + last_prompt_tokens=85000, + ) + store._entries = {"k1": entry} + + store.update_session("k1", last_prompt_tokens=0) + assert entry.last_prompt_tokens == 0