Compare commits

...

1 Commits

Author SHA1 Message Date
etcircle
dd6b5ffa74 fix(gateway): preserve queued voice events for STT 2026-04-11 14:26:18 -07:00
3 changed files with 269 additions and 196 deletions

View File

@@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str:
return "\n".join(parts) return "\n".join(parts)
def _dequeue_pending_text(adapter, session_key: str) -> str | None: def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
"""Consume and return the text of a pending queued message. """Consume and return the full pending event for a session.
Preserves media context for captionless photo/document events by Queued follow-ups must preserve their media metadata so they can re-enter
building a placeholder so the message isn't silently dropped. the normal image/STT/document preprocessing path instead of being reduced
to a placeholder string.
""" """
event = adapter.get_pending_message(session_key) return adapter.get_pending_message(session_key)
if not event:
return None
text = event.text
if not text and getattr(event, "media_urls", None):
text = _build_media_placeholder(event)
return text
def _check_unavailable_skill(command_name: str) -> str | None: def _check_unavailable_skill(command_name: str) -> str | None:
@@ -2775,6 +2770,162 @@ class GatewayRunner:
del self._running_agents[_quick_key] del self._running_agents[_quick_key]
self._running_agents_ts.pop(_quick_key, None) self._running_agents_ts.pop(_quick_key, None)
async def _prepare_inbound_message_text(
    self,
    *,
    event: MessageEvent,
    source: SessionSource,
    history: List[Dict[str, Any]],
) -> Optional[str]:
    """Prepare inbound event text for the agent.

    Keep the normal inbound path and the queued follow-up path on the same
    preprocessing pipeline so sender attribution, image enrichment, STT,
    document notes, reply context, and @ references all behave the same.

    Args:
        event: Inbound message. Reads ``text``, ``media_urls``,
            ``media_types``, ``message_type`` and, when present,
            ``reply_to_text`` / ``reply_to_message_id``.
        source: Where the message came from; ``platform``, ``chat_id``,
            ``chat_type``, ``thread_id`` and ``user_name`` are read.
        history: Prior conversation messages. Only used to decide whether
            a quoted reply is already known; a falsy value means empty.

    Returns:
        The enriched message text, or ``None`` when @ context expansion
        was blocked (the warnings are sent to the chat if an adapter for
        the platform is registered).
    """
    history = history or []
    message_text = event.text or ""

    # Notices sent from here should land in the thread the message came from.
    _thread_meta = {"thread_id": source.thread_id} if source.thread_id else None

    # Sender attribution for shared thread sessions: prefix the sender name
    # so the agent can tell participants apart. Skipped for DMs and when
    # per-user thread isolation is enabled.
    _is_shared_thread = (
        source.chat_type != "dm"
        and source.thread_id
        and not getattr(self.config, "thread_sessions_per_user", False)
    )
    if _is_shared_thread and source.user_name:
        message_text = f"[{source.user_name}] {message_text}"

    if event.media_urls:
        image_paths = []
        audio_paths = []
        for i, path in enumerate(event.media_urls):
            # Check media_types if available; otherwise infer from message type.
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            if mtype.startswith("image/") or event.message_type == MessageType.PHOTO:
                image_paths.append(path)
            if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO):
                audio_paths.append(path)

        if image_paths:
            message_text = await self._enrich_message_with_vision(
                message_text,
                image_paths,
            )

        if audio_paths:
            message_text = await self._enrich_message_with_transcription(
                message_text,
                audio_paths,
            )
            # If STT failed, send a direct message to the user so they know
            # voice isn't configured -- don't rely on the agent to relay the
            # error clearly.
            _stt_fail_markers = (
                "No STT provider",
                "STT is disabled",
                "can't listen",
                "VOICE_TOOLS_OPENAI_KEY",
            )
            if any(marker in message_text for marker in _stt_fail_markers):
                _stt_adapter = self.adapters.get(source.platform)
                if _stt_adapter:
                    try:
                        _stt_msg = (
                            "🎤 I received your voice message but can't transcribe it — "
                            "no speech-to-text provider is configured.\n\n"
                            "To enable voice: install faster-whisper "
                            "(`pip install faster-whisper` in the Hermes venv) "
                            "and set `stt.enabled: true` in config.yaml, "
                            "then /restart the gateway."
                        )
                        # Point to the setup skill if it's installed.
                        if self._has_setup_skill():
                            _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
                        await _stt_adapter.send(
                            source.chat_id,
                            _stt_msg,
                            metadata=_thread_meta,
                        )
                    except Exception as exc:
                        # Best effort: the notice is a courtesy, so a failed
                        # send must not abort message preparation -- but it
                        # should leave a trace.
                        logger.debug("Failed to send STT setup notice: %s", exc)

    # Enrich document messages with context notes for the agent.
    if event.media_urls and event.message_type == MessageType.DOCUMENT:
        import mimetypes as _mimetypes
        import re as _re

        _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
        for i, path in enumerate(event.media_urls):
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            # Fall back to extension-based detection when MIME type is unreliable.
            if mtype in ("", "application/octet-stream"):
                _ext = os.path.splitext(path)[1].lower()
                if _ext in _TEXT_EXTENSIONS:
                    mtype = "text/plain"
                else:
                    guessed, _ = _mimetypes.guess_type(path)
                    if guessed:
                        mtype = guessed
            if not mtype.startswith(("application/", "text/")):
                continue

            # Extract display filename by stripping the doc_<12hex>_ prefix
            # (format: doc_<12hex>_<original_filename>).
            basename = os.path.basename(path)
            parts = basename.split("_", 2)
            display_name = parts[2] if len(parts) >= 3 else basename
            # Sanitize to prevent prompt injection via filenames.
            display_name = _re.sub(r'[^\w.\- ]', '_', display_name)

            if mtype.startswith("text/"):
                context_note = (
                    f"[The user sent a text document: '{display_name}'. "
                    f"Its content has been included below. "
                    f"The file is also saved at: {path}]"
                )
            else:
                context_note = (
                    f"[The user sent a document: '{display_name}'. "
                    f"The file is saved at: {path}. "
                    f"Ask the user what they'd like you to do with it.]"
                )
            message_text = f"{context_note}\n\n{message_text}"

    # Inject reply context when the user replies to a message that is not in
    # history (e.g. from a previous session, cron delivery, or background
    # task), so the agent knows what is being referenced.
    if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
        reply_snippet = event.reply_to_text[:500]
        found_in_history = any(
            reply_snippet[:200] in (msg.get("content") or "")
            for msg in history
            if msg.get("role") in ("assistant", "user", "tool")
        )
        if not found_in_history:
            message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'

    # Expand @ context references (@file:, @folder:, @diff, etc.)
    if "@" in message_text:
        try:
            from agent.context_references import preprocess_context_references_async
            from agent.model_metadata import get_model_context_length

            _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
            _msg_ctx_len = get_model_context_length(
                self._model,
                base_url=self._base_url or "",
            )
            _ctx_result = await preprocess_context_references_async(
                message_text,
                cwd=_msg_cwd,
                context_length=_msg_ctx_len,
                allowed_root=_msg_cwd,
            )
            if _ctx_result.blocked:
                _adapter = self.adapters.get(source.platform)
                if _adapter:
                    try:
                        await _adapter.send(
                            source.chat_id,
                            "\n".join(_ctx_result.warnings) or "Context injection refused.",
                            metadata=_thread_meta,
                        )
                    except Exception as exc:
                        # A failed warning must not fall through to the outer
                        # handler, which would hand the refused message to
                        # the agent anyway.
                        logger.debug("Failed to send context-injection warning: %s", exc)
                return None
            if _ctx_result.expanded:
                message_text = _ctx_result.message
        except Exception as exc:
            logger.debug("@ context reference expansion failed: %s", exc)

    return message_text
async def _handle_message_with_agent(self, event, source, _quick_key: str): async def _handle_message_with_agent(self, event, source, _quick_key: str):
"""Inner handler that runs under the _running_agents sentinel guard.""" """Inner handler that runs under the _running_agents sentinel guard."""
_msg_start_time = time.time() _msg_start_time = time.time()
@@ -3215,149 +3366,13 @@ class GatewayRunner:
# attachments (documents, audio, etc.) are not sent to the vision # attachments (documents, audio, etc.) are not sent to the vision
# tool even when they appear in the same message. # tool even when they appear in the same message.
# ----------------------------------------------------------------- # -----------------------------------------------------------------
message_text = event.text or "" message_text = await self._prepare_inbound_message_text(
event=event,
# ----------------------------------------------------------------- source=source,
# Sender attribution for shared thread sessions. history=history,
#
# When multiple users share a single thread session (the default for
# threads), prefix each message with [sender name] so the agent can
# tell participants apart. Skip for DMs (single-user by nature) and
# when per-user thread isolation is explicitly enabled.
# -----------------------------------------------------------------
_is_shared_thread = (
source.chat_type != "dm"
and source.thread_id
and not getattr(self.config, "thread_sessions_per_user", False)
) )
if _is_shared_thread and source.user_name: if message_text is None:
message_text = f"[{source.user_name}] {message_text}" return
if event.media_urls:
image_paths = []
for i, path in enumerate(event.media_urls):
# Check media_types if available; otherwise infer from message type
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_image = (
mtype.startswith("image/")
or event.message_type == MessageType.PHOTO
)
if is_image:
image_paths.append(path)
if image_paths:
message_text = await self._enrich_message_with_vision(
message_text, image_paths
)
# -----------------------------------------------------------------
# Auto-transcribe voice/audio messages sent by the user
# -----------------------------------------------------------------
if event.media_urls:
audio_paths = []
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_audio = (
mtype.startswith("audio/")
or event.message_type in (MessageType.VOICE, MessageType.AUDIO)
)
if is_audio:
audio_paths.append(path)
if audio_paths:
message_text = await self._enrich_message_with_transcription(
message_text, audio_paths
)
# If STT failed, send a direct message to the user so they
# know voice isn't configured — don't rely on the agent to
# relay the error clearly.
_stt_fail_markers = (
"No STT provider",
"STT is disabled",
"can't listen",
"VOICE_TOOLS_OPENAI_KEY",
)
if any(m in message_text for m in _stt_fail_markers):
_stt_adapter = self.adapters.get(source.platform)
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
if _stt_adapter:
try:
_stt_msg = (
"🎤 I received your voice message but can't transcribe it — "
"no speech-to-text provider is configured.\n\n"
"To enable voice: install faster-whisper "
"(`pip install faster-whisper` in the Hermes venv) "
"and set `stt.enabled: true` in config.yaml, "
"then /restart the gateway."
)
# Point to setup skill if it's installed
if self._has_setup_skill():
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
await _stt_adapter.send(
source.chat_id, _stt_msg,
metadata=_stt_meta,
)
except Exception:
pass
# -----------------------------------------------------------------
# Enrich document messages with context notes for the agent
# -----------------------------------------------------------------
if event.media_urls and event.message_type == MessageType.DOCUMENT:
import mimetypes as _mimetypes
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
# Fall back to extension-based detection when MIME type is unreliable.
if mtype in ("", "application/octet-stream"):
import os as _os2
_ext = _os2.path.splitext(path)[1].lower()
if _ext in _TEXT_EXTENSIONS:
mtype = "text/plain"
else:
guessed, _ = _mimetypes.guess_type(path)
if guessed:
mtype = guessed
if not mtype.startswith(("application/", "text/")):
continue
# Extract display filename by stripping the doc_{uuid12}_ prefix
import os as _os
basename = _os.path.basename(path)
# Format: doc_<12hex>_<original_filename>
parts = basename.split("_", 2)
display_name = parts[2] if len(parts) >= 3 else basename
# Sanitize to prevent prompt injection via filenames
import re as _re
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
if mtype.startswith("text/"):
context_note = (
f"[The user sent a text document: '{display_name}'. "
f"Its content has been included below. "
f"The file is also saved at: {path}]"
)
else:
context_note = (
f"[The user sent a document: '{display_name}'. "
f"The file is saved at: {path}. "
f"Ask the user what they'd like you to do with it.]"
)
message_text = f"{context_note}\n\n{message_text}"
# -----------------------------------------------------------------
# Inject reply context when user replies to a message not in history.
# Telegram (and other platforms) let users reply to specific messages,
# but if the quoted message is from a previous session, cron delivery,
# or background task, the agent has no context about what's being
# referenced. Prepend the quoted text so the agent understands. (#1594)
# -----------------------------------------------------------------
if getattr(event, 'reply_to_text', None) and event.reply_to_message_id:
reply_snippet = event.reply_to_text[:500]
found_in_history = any(
reply_snippet[:200] in (msg.get("content") or "")
for msg in history
if msg.get("role") in ("assistant", "user", "tool")
)
if not found_in_history:
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
try: try:
# Emit agent:start hook # Emit agent:start hook
@@ -3369,30 +3384,6 @@ class GatewayRunner:
} }
await self.hooks.emit("agent:start", hook_ctx) await self.hooks.emit("agent:start", hook_ctx)
# Expand @ context references (@file:, @folder:, @diff, etc.)
if "@" in message_text:
try:
from agent.context_references import preprocess_context_references_async
from agent.model_metadata import get_model_context_length
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
_msg_ctx_len = get_model_context_length(
self._model, base_url=self._base_url or "")
_ctx_result = await preprocess_context_references_async(
message_text, cwd=_msg_cwd,
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
if _ctx_result.blocked:
_adapter = self.adapters.get(source.platform)
if _adapter:
await _adapter.send(
source.chat_id,
"\n".join(_ctx_result.warnings) or "Context injection refused.",
)
return
if _ctx_result.expanded:
message_text = _ctx_result.message
except Exception as exc:
logger.debug("@ context reference expansion failed: %s", exc)
# Run the agent # Run the agent
agent_result = await self._run_agent( agent_result = await self._run_agent(
message=message_text, message=message_text,
@@ -8057,16 +8048,15 @@ class GatewayRunner:
# Get pending message from adapter. # Get pending message from adapter.
# Use session_key (not source.chat_id) to match adapter's storage keys. # Use session_key (not source.chat_id) to match adapter's storage keys.
pending_event = None
pending = None pending = None
if result and adapter and session_key: if result and adapter and session_key:
if result.get("interrupted"): pending_event = _dequeue_pending_event(adapter, session_key)
pending = _dequeue_pending_text(adapter, session_key) if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
if not pending and result.get("interrupt_message"): pending = result.get("interrupt_message")
pending = result.get("interrupt_message") elif pending_event:
else: pending = pending_event.text or _build_media_placeholder(pending_event)
pending = _dequeue_pending_text(adapter, session_key) logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
if pending:
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
# Safety net: if the pending text is a slash command (e.g. "/stop", # Safety net: if the pending text is a slash command (e.g. "/stop",
# "/new"), discard it — commands should never be passed to the agent # "/new"), discard it — commands should never be passed to the agent
@@ -8085,19 +8075,21 @@ class GatewayRunner:
"commands must not be passed as agent input", "commands must not be passed as agent input",
_pending_cmd_word, _pending_cmd_word,
) )
pending_event = None
pending = None pending = None
except Exception: except Exception:
pass pass
if self._draining and pending: if self._draining and (pending_event or pending):
logger.info( logger.info(
"Discarding pending follow-up for session %s during gateway %s", "Discarding pending follow-up for session %s during gateway %s",
session_key[:20] if session_key else "?", session_key[:20] if session_key else "?",
self._status_action_label(), self._status_action_label(),
) )
pending_event = None
pending = None pending = None
if pending: if pending_event or pending:
logger.debug("Processing pending message: '%s...'", pending[:40]) logger.debug("Processing pending message: '%s...'", pending[:40])
# Clear the adapter's interrupt event so the next _run_agent call # Clear the adapter's interrupt event so the next _run_agent call
@@ -8114,9 +8106,10 @@ class GatewayRunner:
"queueing message instead of recursing.", "queueing message instead of recursing.",
_interrupt_depth, session_key, _interrupt_depth, session_key,
) )
# Queue the pending message for normal processing on next turn
adapter = self.adapters.get(source.platform) adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, 'queue_message'): if adapter and pending_event:
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
elif adapter and hasattr(adapter, 'queue_message'):
adapter.queue_message(session_key, pending) adapter.queue_message(session_key, pending)
return result_holder[0] or {"final_response": response, "messages": history} return result_holder[0] or {"final_response": response, "messages": history}
@@ -8138,16 +8131,30 @@ class GatewayRunner:
# interrupted." is just noise; the user already knows they sent a # interrupted." is just noise; the user already knows they sent a
# new message). # new message).
# Process the pending message with updated history
updated_history = result.get("messages", history) updated_history = result.get("messages", history)
next_source = source
next_message = pending
next_message_id = None
if pending_event is not None:
next_source = getattr(pending_event, "source", None) or source
next_message = await self._prepare_inbound_message_text(
event=pending_event,
source=next_source,
history=updated_history,
)
if next_message is None:
return result
next_message_id = getattr(pending_event, "message_id", None)
return await self._run_agent( return await self._run_agent(
message=pending, message=next_message,
context_prompt=context_prompt, context_prompt=context_prompt,
history=updated_history, history=updated_history,
source=source, source=next_source,
session_id=session_id, session_id=session_id,
session_key=session_key, session_key=session_key,
_interrupt_depth=_interrupt_depth + 1, _interrupt_depth=_interrupt_depth + 1,
event_message_id=next_message_id,
) )
finally: finally:
# Stop progress sender, interrupt monitor, and notification task # Stop progress sender, interrupt monitor, and notification task

View File

@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from gateway.run import _dequeue_pending_event
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@@ -79,6 +80,26 @@ class TestQueueMessageStorage:
# Should be consumed (cleared) # Should be consumed (cleared)
assert adapter.get_pending_message(session_key) is None assert adapter.get_pending_message(session_key) is None
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
    """Dequeuing a queued voice event returns the event itself, media intact."""
    key = "telegram:user:voice"
    expected_urls = ["/tmp/voice.ogg"]
    expected_types = ["audio/ogg"]

    queued = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
        message_id="voice-q1",
        media_urls=["/tmp/voice.ogg"],
        media_types=["audio/ogg"],
    )
    stub = _StubAdapter()
    stub._pending_messages[key] = queued

    dequeued = _dequeue_pending_event(stub, key)

    # The same object must come back, not a text-only reduction of it.
    assert dequeued is queued
    assert dequeued.media_urls == expected_urls
    assert dequeued.media_types == expected_types
    # Dequeuing consumes the entry, so a second lookup finds nothing.
    assert stub.get_pending_message(key) is None
def test_queue_does_not_set_interrupt_event(self): def test_queue_does_not_set_interrupt_event(self):
"""The whole point of /queue — no interrupt signal.""" """The whole point of /queue — no interrupt signal."""
adapter = _StubAdapter() adapter = _StubAdapter()

View File

@@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
import yaml import yaml
from gateway.config import GatewayConfig, load_gateway_config from gateway.config import GatewayConfig, Platform, load_gateway_config
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
def test_gateway_config_stt_disabled_from_dict_nested(): def test_gateway_config_stt_disabled_from_dict_nested():
@@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
assert "No STT provider is configured" not in result assert "No STT provider is configured" not in result
assert "trouble transcribing" in result assert "trouble transcribing" in result
assert "caption" in result assert "caption" in result
@pytest.mark.asyncio
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
    """A voice event run through the shared preparation step gets transcribed."""
    from gateway.run import GatewayRunner

    # Skip __init__ and set only the attributes assigned below.
    gateway = GatewayRunner.__new__(GatewayRunner)
    gateway.config = GatewayConfig(stt_enabled=True)
    gateway.adapters = {}
    gateway._model = "test-model"
    gateway._base_url = ""
    gateway._has_setup_skill = lambda: False

    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="123",
        chat_type="dm",
    )
    voice_event = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=dm_source,
        media_urls=["/tmp/queued-voice.ogg"],
        media_types=["audio/ogg"],
    )
    stub_transcription = {
        "success": True,
        "transcript": "queued voice transcript",
        "provider": "local_command",
    }

    with patch(
        "tools.transcription_tools.transcribe_audio",
        return_value=stub_transcription,
    ):
        prepared = await gateway._prepare_inbound_message_text(
            event=voice_event,
            source=dm_source,
            history=[],
        )

    assert prepared is not None
    assert "queued voice transcript" in prepared
    assert "voice message" in prepared.lower()