Compare commits

...

1 Commit

Author SHA1 Message Date
teknium1    b2aff451ed    fix(gateway): make /status report live state and tokens    2026-03-15 19:17:26 -07:00
2 changed files with 147 additions and 1 deletion

View File

@@ -1114,6 +1114,9 @@ class GatewayRunner:
             # let the adapter-level batching/queueing logic absorb them.
             _quick_key = build_session_key(source)
             if _quick_key in self._running_agents:
+                if event.get_command() == "status":
+                    return await self._handle_status_command(event)
                 if event.message_type == MessageType.PHOTO:
                     logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
                     adapter = self.adapters.get(source.platform)
@@ -1822,6 +1825,8 @@ class GatewayRunner:
             # Update session with actual prompt token count and model from the agent
             self.session_store.update_session(
                 session_entry.session_key,
+                input_tokens=agent_result.get("input_tokens", 0),
+                output_tokens=agent_result.get("output_tokens", 0),
                 last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
                 model=agent_result.get("model"),
             )
@@ -4171,11 +4176,15 @@ class GatewayRunner:
         # Return final response, or a message if something went wrong
         final_response = result.get("final_response")
-        # Extract last actual prompt token count from the agent's compressor
+        # Extract actual token counts from the agent instance used for this run
         _last_prompt_toks = 0
+        _input_toks = 0
+        _output_toks = 0
         _agent = agent_holder[0]
         if _agent and hasattr(_agent, "context_compressor"):
             _last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
+            _input_toks = getattr(_agent, "session_prompt_tokens", 0)
+            _output_toks = getattr(_agent, "session_completion_tokens", 0)
         _resolved_model = getattr(_agent, "model", None) if _agent else None
         if not final_response:
@@ -4187,6 +4196,8 @@ class GatewayRunner:
             "tools": tools_holder[0] or [],
             "history_offset": len(agent_history),
             "last_prompt_tokens": _last_prompt_toks,
+            "input_tokens": _input_toks,
+            "output_tokens": _output_toks,
             "model": _resolved_model,
         }
@@ -4250,6 +4261,8 @@ class GatewayRunner:
             "tools": tools_holder[0] or [],
             "history_offset": len(agent_history),
             "last_prompt_tokens": _last_prompt_toks,
+            "input_tokens": _input_toks,
+            "output_tokens": _output_toks,
             "model": _resolved_model,
             "session_id": effective_session_id,
         }

View File

@@ -0,0 +1,133 @@
"""Tests for gateway /status behavior and token persistence."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
user_id="u1",
chat_id="c1",
user_name="tester",
chat_type="dm",
)
def _make_event(text: str) -> MessageEvent:
return MessageEvent(
text=text,
source=_make_source(),
message_id="m1",
)
def _make_runner(session_entry: SessionEntry):
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
)
adapter = MagicMock()
adapter.send = AsyncMock()
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner.session_store = MagicMock()
runner.session_store.get_or_create_session.return_value = session_entry
runner.session_store.load_transcript.return_value = []
runner.session_store.has_any_sessions.return_value = True
runner.session_store.append_to_transcript = MagicMock()
runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock()
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
runner._show_reasoning = False
runner._is_user_authorized = lambda _source: True
runner._set_session_env = lambda _context: None
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
runner._send_voice_reply = AsyncMock()
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
runner._emit_gateway_run_progress = AsyncMock()
return runner
@pytest.mark.asyncio
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
total_tokens=321,
)
runner = _make_runner(session_entry)
running_agent = MagicMock()
runner._running_agents[build_session_key(_make_source())] = running_agent
result = await runner._handle_message(_make_event("/status"))
assert "**Tokens:** 321" in result
assert "**Agent Running:** Yes ⚡" in result
running_agent.interrupt.assert_not_called()
assert runner._pending_messages == {}
@pytest.mark.asyncio
async def test_handle_message_persists_agent_token_counts(monkeypatch):
import gateway.run as gateway_run
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner = _make_runner(session_entry)
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
runner._run_agent = AsyncMock(
return_value={
"final_response": "ok",
"messages": [],
"tools": [],
"history_offset": 0,
"last_prompt_tokens": 80,
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
}
)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
monkeypatch.setattr(
"agent.model_metadata.get_model_context_length",
lambda *_args, **_kwargs: 100000,
)
result = await runner._handle_message(_make_event("hello"))
assert result == "ok"
runner.session_store.update_session.assert_called_once_with(
session_entry.session_key,
input_tokens=120,
output_tokens=45,
last_prompt_tokens=80,
model="openai/test-model",
)