Compare commits

...

1 Commits

Author SHA1 Message Date
kshitijk4poor
8e8c31eadf fix(insights): persist token usage for non-CLI sessions 2026-04-02 10:14:41 -07:00
6 changed files with 73 additions and 112 deletions

View File

@@ -2804,20 +2804,12 @@ class GatewayRunner:
skip_db=agent_persisted,
)
# Update session with actual prompt token count and model from the agent
# Token counts and model are now persisted by the agent directly.
# Keep only last_prompt_tokens here for context-window tracking and
# compression decisions.
self.session_store.update_session(
session_entry.session_key,
input_tokens=agent_result.get("input_tokens", 0),
output_tokens=agent_result.get("output_tokens", 0),
cache_read_tokens=agent_result.get("cache_read_tokens", 0),
cache_write_tokens=agent_result.get("cache_write_tokens", 0),
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
model=agent_result.get("model"),
estimated_cost_usd=agent_result.get("estimated_cost_usd"),
cost_status=agent_result.get("cost_status"),
cost_source=agent_result.get("cost_source"),
provider=agent_result.get("provider"),
base_url=agent_result.get("base_url"),
)
# Auto voice reply: send TTS audio before the text response

View File

@@ -778,66 +778,18 @@ class SessionStore:
def update_session(
self,
session_key: str,
input_tokens: int = 0,
output_tokens: int = 0,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
last_prompt_tokens: int = None,
model: str = None,
estimated_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> None:
"""Update a session's metadata after an interaction."""
db_session_id = None
"""Update lightweight session metadata after an interaction."""
with self._lock:
self._ensure_loaded_locked()
if session_key in self._entries:
entry = self._entries[session_key]
entry.updated_at = _now()
# Direct assignment — the gateway receives cumulative totals
# from the cached agent, not per-call deltas.
entry.input_tokens = input_tokens
entry.output_tokens = output_tokens
entry.cache_read_tokens = cache_read_tokens
entry.cache_write_tokens = cache_write_tokens
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd = estimated_cost_usd
if cost_status:
entry.cost_status = cost_status
entry.total_tokens = (
entry.input_tokens
+ entry.output_tokens
+ entry.cache_read_tokens
+ entry.cache_write_tokens
)
self._save()
db_session_id = entry.session_id
if self._db and db_session_id:
try:
self._db.set_token_counts(
db_session_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
estimated_cost_usd=estimated_cost_usd,
cost_status=cost_status,
cost_source=cost_source,
billing_provider=provider,
billing_base_url=base_url,
model=model,
absolute=True,
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
"""Force reset a session, creating a new session ID."""

View File

@@ -7216,11 +7216,13 @@ class AIAgent:
self.session_cost_source = cost_result.source
# Persist token counts to session DB for /insights.
# Gateway sessions persist via session_store.update_session()
# after run_conversation returns, so only persist here for
# CLI (and other non-gateway) platforms to avoid double-counting.
if (self._session_db and self.session_id
and getattr(self, 'platform', None) == 'cli'):
# Do this for every platform with a session_id so non-CLI
# sessions (gateway, cron, delegated runs) cannot lose
# token/accounting data if a higher-level persistence path
# is skipped or fails. Gateway/session-store writes use
# absolute totals, so they safely overwrite these per-call
# deltas instead of double-counting them.
if self._session_db and self.session_id:
try:
self._session_db.update_token_counts(
self.session_id,

View File

@@ -825,43 +825,6 @@ class TestLastPromptTokens:
store.update_session("k1", last_prompt_tokens=0)
assert entry.last_prompt_tokens == 0
def test_update_session_passes_model_to_db(self, tmp_path):
"""Gateway session updates should forward the resolved model to SQLite."""
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._loaded = True
store._save = MagicMock()
store._db = MagicMock()
from gateway.session import SessionEntry
from datetime import datetime
entry = SessionEntry(
session_key="k1",
session_id="s1",
created_at=datetime.now(),
updated_at=datetime.now(),
)
store._entries = {"k1": entry}
store.update_session("k1", model="openai/gpt-5.4")
store._db.set_token_counts.assert_called_once_with(
"s1",
input_tokens=0,
output_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
billing_provider=None,
billing_base_url=None,
model="openai/gpt-5.4",
absolute=True,
)
class TestRewriteTranscriptPreservesReasoning:
"""rewrite_transcript must not drop reasoning fields from SQLite."""

View File

@@ -126,15 +126,5 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
assert result == "ok"
runner.session_store.update_session.assert_called_once_with(
session_entry.session_key,
input_tokens=120,
output_tokens=45,
cache_read_tokens=0,
cache_write_tokens=0,
last_prompt_tokens=80,
model="openai/test-model",
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
provider=None,
base_url=None,
)

View File

@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from run_agent import AIAgent
def _mock_response(*, usage: dict, content: str = "done"):
msg = SimpleNamespace(content=content, tool_calls=None)
choice = SimpleNamespace(message=msg, finish_reason="stop")
return SimpleNamespace(
choices=[choice],
model="test/model",
usage=SimpleNamespace(**usage),
)
def _make_agent(session_db, *, platform: str):
with (
patch("run_agent.get_tool_definitions", return_value=[]),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
):
agent = AIAgent(
api_key="test-key",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
session_db=session_db,
session_id=f"{platform}-session",
platform=platform,
)
agent.client = MagicMock()
agent.client.chat.completions.create.return_value = _mock_response(
usage={
"prompt_tokens": 11,
"completion_tokens": 7,
"total_tokens": 18,
}
)
return agent
def test_run_conversation_persists_tokens_for_telegram_sessions():
session_db = MagicMock()
agent = _make_agent(session_db, platform="telegram")
result = agent.run_conversation("hello")
assert result["final_response"] == "done"
session_db.update_token_counts.assert_called_once()
assert session_db.update_token_counts.call_args.args[0] == "telegram-session"
def test_run_conversation_persists_tokens_for_cron_sessions():
session_db = MagicMock()
agent = _make_agent(session_db, platform="cron")
result = agent.run_conversation("hello")
assert result["final_response"] == "done"
session_db.update_token_counts.assert_called_once()
assert session_db.update_token_counts.call_args.args[0] == "cron-session"