mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-02 08:47:26 +08:00
Compare commits
26 Commits
feat/teleg
...
feat/strea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d41115aa31 | ||
|
|
172a38c344 | ||
|
|
9f086173be | ||
|
|
8bc0d4f77d | ||
|
|
8eabdefa8a | ||
|
|
f658af45c2 | ||
|
|
5212644861 | ||
|
|
1151f84351 | ||
|
|
9abd6bf342 | ||
|
|
d2c7ef6b41 | ||
|
|
57faddd808 | ||
|
|
a34102049b | ||
|
|
ef5d811aba | ||
|
|
af6a92a4c2 | ||
|
|
4d6c90c6d0 | ||
|
|
2d44ed1c5b | ||
|
|
fa2e72ae9c | ||
|
|
5bfc4ed53b | ||
|
|
520aec20e0 | ||
|
|
64bec1d060 | ||
|
|
ac58309dbd | ||
|
|
a5a5d82a21 | ||
|
|
34e8d088c2 | ||
|
|
36214d14db | ||
|
|
15561ec425 | ||
|
|
7d79ce92ac |
@@ -560,12 +560,16 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
forced = _get_auxiliary_provider("vision")
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
# Auto: only multimodal-capable providers
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex):
|
||||
# Auto: try providers known to support multimodal first, then fall
|
||||
# back to the user's custom endpoint. Many local models (Qwen-VL,
|
||||
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
|
||||
# caused silent failures for local-only users.
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex,
|
||||
_try_custom_endpoint):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary vision client: none available (auto only tries OpenRouter/Nous/Codex)")
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
72
cli.py
72
cli.py
@@ -158,6 +158,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"singularity_image": "docker://python:3.11",
|
||||
"modal_image": "python:3.11",
|
||||
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_volumes": [], # host:container volume mounts for Docker backend
|
||||
},
|
||||
"browser": {
|
||||
"inactivity_timeout": 120, # Auto-cleanup inactive browser sessions after 2 min
|
||||
@@ -1186,6 +1187,7 @@ class HermesCLI:
|
||||
# History file for persistent input recall across sessions
|
||||
self._history_file = Path.home() / ".hermes_history"
|
||||
self._last_invalidate: float = 0.0 # throttle UI repaints
|
||||
self._stream_buf = ""
|
||||
|
||||
def _invalidate(self, min_interval: float = 0.25) -> None:
|
||||
"""Throttled UI repaint — prevents terminal blinking on slow/SSH connections."""
|
||||
@@ -1385,6 +1387,7 @@ class HermesCLI:
|
||||
platform="cli",
|
||||
session_db=self._session_db,
|
||||
clarify_callback=self._clarify_callback,
|
||||
stream_delta_callback=self._stream_delta,
|
||||
honcho_session_key=self.session_id,
|
||||
fallback_model=self._fallback_model,
|
||||
)
|
||||
@@ -2904,6 +2907,28 @@ class HermesCLI:
|
||||
except Exception as e:
|
||||
print(f" ❌ MCP reload failed: {e}")
|
||||
|
||||
_stream_started = False
|
||||
|
||||
def _stream_delta(self, text: str):
|
||||
"""Buffer streaming tokens; emit complete lines via _cprint."""
|
||||
if not text:
|
||||
return
|
||||
if not self._stream_started:
|
||||
text = text.lstrip("\n")
|
||||
if not text:
|
||||
return
|
||||
self._stream_started = True
|
||||
self._stream_buf += text
|
||||
while "\n" in self._stream_buf:
|
||||
line, self._stream_buf = self._stream_buf.split("\n", 1)
|
||||
_cprint(line)
|
||||
|
||||
def _flush_stream(self):
|
||||
"""Emit any remaining partial line from the stream buffer."""
|
||||
if self._stream_buf:
|
||||
_cprint(self._stream_buf)
|
||||
self._stream_buf = ""
|
||||
|
||||
def _clarify_callback(self, question, choices):
|
||||
"""
|
||||
Platform callback for the clarify tool. Called from the agent thread.
|
||||
@@ -3075,12 +3100,12 @@ class HermesCLI:
|
||||
message if isinstance(message, str) else "", images
|
||||
)
|
||||
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
|
||||
self._stream_buf = ""
|
||||
self._stream_started = False
|
||||
|
||||
w = shutil.get_terminal_size().columns
|
||||
_cprint(f"{_GOLD}{'─' * w}{_RST}")
|
||||
print(flush=True)
|
||||
_cprint(f"\n{_GOLD}╭─ ⚕ Hermes {'─' * max(w - 15, 0)}╮{_RST}")
|
||||
|
||||
try:
|
||||
# Run the conversation with interrupt monitoring
|
||||
@@ -3126,43 +3151,28 @@ class HermesCLI:
|
||||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
|
||||
# Drain any remaining agent output still in the StdoutProxy
|
||||
# buffer so tool/status lines render ABOVE our response box.
|
||||
# The flush pushes data into the renderer queue; the short
|
||||
# sleep lets the renderer actually paint it before we draw.
|
||||
import time as _time
|
||||
self._flush_stream()
|
||||
sys.stdout.flush()
|
||||
import time as _time
|
||||
_time.sleep(0.15)
|
||||
|
||||
# Update history with full conversation
|
||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||
|
||||
# Get the final response
|
||||
response = result.get("final_response", "") if result else ""
|
||||
|
||||
# Handle failed results (e.g., non-retryable errors like invalid model)
|
||||
|
||||
if result and result.get("failed") and not response:
|
||||
error_detail = result.get("error", "Unknown error")
|
||||
response = f"Error: {error_detail}"
|
||||
|
||||
# Handle interrupt - check if we were interrupted
|
||||
response = f"Error: {result.get('error', 'Unknown error')}"
|
||||
|
||||
pending_message = None
|
||||
if result and result.get("interrupted"):
|
||||
pending_message = result.get("interrupt_message") or interrupt_msg
|
||||
# Add indicator that we were interrupted
|
||||
if response and pending_message:
|
||||
response = response + "\n\n---\n_[Interrupted - processing new message]_"
|
||||
|
||||
if response:
|
||||
w = shutil.get_terminal_size().columns
|
||||
label = " ⚕ Hermes "
|
||||
fill = w - 2 - len(label) # 2 for ╭ and ╮
|
||||
top = f"{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}"
|
||||
bot = f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}"
|
||||
response += "\n\n---\n_[Interrupted - processing new message]_"
|
||||
|
||||
# Render box + response as a single _cprint call so
|
||||
# nothing can interleave between the box borders.
|
||||
_cprint(f"\n{top}\n{response}\n\n{bot}")
|
||||
if response and not (self.agent and self.agent.stream_delta_callback):
|
||||
_cprint(f"\n{response}")
|
||||
|
||||
w = shutil.get_terminal_size().columns
|
||||
_cprint(f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}")
|
||||
|
||||
# Play terminal bell when agent finishes (if enabled).
|
||||
# Works over SSH — the bell propagates to the user's terminal.
|
||||
@@ -3619,7 +3629,7 @@ class HermesCLI:
|
||||
return ""
|
||||
if cli_ref._agent_running:
|
||||
return "type a message + Enter to interrupt, Ctrl+C to cancel"
|
||||
return ""
|
||||
return "Ask Hermes anything... (Alt+Enter for newline)"
|
||||
|
||||
input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder))
|
||||
|
||||
|
||||
46
datagen-config-examples/web_research.yaml
Normal file
46
datagen-config-examples/web_research.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
# datagen-config-examples/web_research.yaml
|
||||
#
|
||||
# Batch data generation config for WebResearchEnv.
|
||||
# Generates tool-calling trajectories for multi-step web research tasks.
|
||||
#
|
||||
# Usage:
|
||||
# python batch_runner.py \
|
||||
# --config datagen-config-examples/web_research.yaml \
|
||||
# --run_name web_research_v1
|
||||
|
||||
environment: web-research
|
||||
|
||||
# Toolsets available to the agent during data generation
|
||||
toolsets:
|
||||
- web
|
||||
- file
|
||||
|
||||
# How many parallel workers to use
|
||||
num_workers: 4
|
||||
|
||||
# Questions per batch
|
||||
batch_size: 20
|
||||
|
||||
# Total trajectories to generate (comment out to run full dataset)
|
||||
max_items: 500
|
||||
|
||||
# Model to use for generation (override with --model flag)
|
||||
model: openrouter/nousresearch/hermes-3-llama-3.1-405b
|
||||
|
||||
# System prompt additions (ephemeral — not saved to trajectories)
|
||||
ephemeral_system_prompt: |
|
||||
You are a highly capable research agent. When asked a factual question,
|
||||
always use web_search to find current, accurate information before answering.
|
||||
Cite at least 2 sources. Be concise and accurate.
|
||||
|
||||
# Output directory
|
||||
output_dir: data/web_research_v1
|
||||
|
||||
# Trajectory compression settings (for fitting into training token budgets)
|
||||
compression:
|
||||
enabled: true
|
||||
target_max_tokens: 16000
|
||||
|
||||
# Eval settings
|
||||
eval_every: 100 # Run eval every N trajectories
|
||||
eval_size: 25 # Number of held-out questions per eval run
|
||||
643
environments/web_research_env.py
Normal file
643
environments/web_research_env.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
||||
============================================================
|
||||
|
||||
Trains models to do accurate, efficient, multi-source web research.
|
||||
|
||||
Reward signals:
|
||||
- Answer correctness (LLM judge, 0.0–1.0)
|
||||
- Source diversity (used ≥2 distinct domains)
|
||||
- Efficiency (penalizes excessive tool calls)
|
||||
- Tool usage (bonus for actually using web tools)
|
||||
|
||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
||||
HuggingFace: google/frames-benchmark
|
||||
Fallback: built-in sample questions (no HF token needed)
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenAI-compatible server)
|
||||
python environments/web_research_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai
|
||||
|
||||
# Process mode (offline data generation)
|
||||
python environments/web_research_env.py process \\
|
||||
--env.data_path_to_save_groups data/web_research.jsonl
|
||||
|
||||
# Standalone eval
|
||||
python environments/web_research_env.py evaluate \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel
|
||||
|
||||
Built by: github.com/jackx707
|
||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
||||
across German grocery stores (firecrawl + hermes-agent)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
# Ensure hermes-agent root is on path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional HuggingFace datasets import
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
||||
# Multi-hop questions requiring real web search to answer.
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_QUESTIONS = [
|
||||
{
|
||||
"question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
|
||||
"answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
|
||||
"answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What programming language was used to write the original version of the web framework used by Instagram?",
|
||||
"answer": "Django, which Instagram was built on, is written in Python.",
|
||||
"difficulty": "easy",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
|
||||
"answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
|
||||
"difficulty": "hard",
|
||||
"hops": 3,
|
||||
},
|
||||
{
|
||||
"question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
|
||||
"answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "How many employees does the parent company of Instagram have?",
|
||||
"answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
|
||||
"answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "Which company acquired the startup founded by the creator of Oculus VR?",
|
||||
"answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
|
||||
"difficulty": "medium",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What is the market cap of the company that owns the most popular search engine in Russia?",
|
||||
"answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
{
|
||||
"question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
|
||||
"answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
|
||||
"difficulty": "hard",
|
||||
"hops": 2,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
||||
"""Configuration for the web research RL environment."""
|
||||
|
||||
# Reward weights
|
||||
correctness_weight: float = Field(
|
||||
default=0.6,
|
||||
description="Weight for answer correctness in reward (LLM judge score).",
|
||||
)
|
||||
tool_usage_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for tool usage signal (did the model actually use web tools?).",
|
||||
)
|
||||
efficiency_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
||||
)
|
||||
diversity_bonus: float = Field(
|
||||
default=0.1,
|
||||
description="Bonus reward for citing ≥2 distinct domains.",
|
||||
)
|
||||
|
||||
# Efficiency thresholds
|
||||
efficient_max_calls: int = Field(
|
||||
default=5,
|
||||
description="Maximum tool calls before efficiency penalty begins.",
|
||||
)
|
||||
heavy_penalty_calls: int = Field(
|
||||
default=10,
|
||||
description="Tool call count where efficiency penalty steepens.",
|
||||
)
|
||||
|
||||
# Eval
|
||||
eval_size: int = Field(
|
||||
default=20,
|
||||
description="Number of held-out items for evaluation.",
|
||||
)
|
||||
eval_split_ratio: float = Field(
|
||||
default=0.1,
|
||||
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
||||
)
|
||||
|
||||
# Dataset
|
||||
dataset_name: str = Field(
|
||||
default="google/frames-benchmark",
|
||||
description="HuggingFace dataset name for research questions.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class WebResearchEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
RL environment for training multi-step web research skills.
|
||||
|
||||
The model is given a factual question requiring 2-3 hops of web research
|
||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
||||
|
||||
Reward is multi-signal:
|
||||
60% — answer correctness (LLM judge)
|
||||
20% — tool usage (did the model actually search the web?)
|
||||
20% — efficiency (penalizes >5 tool calls)
|
||||
|
||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
||||
"""
|
||||
|
||||
name = "web-research"
|
||||
env_config_cls = WebResearchEnvConfig
|
||||
|
||||
# Default toolsets for this environment — web + file for saving notes
|
||||
default_toolsets = ["web", "file"]
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
||||
"""Default configuration for the web research environment."""
|
||||
env_config = WebResearchEnvConfig(
|
||||
enabled_toolsets=["web", "file"],
|
||||
max_agent_turns=15,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a highly capable research agent. When asked a factual question, "
|
||||
"always use web_search to find current, accurate information before answering. "
|
||||
"Cite at least 2 sources. Be concise and accurate."
|
||||
),
|
||||
group_size=4,
|
||||
total_steps=1000,
|
||||
steps_per_eval=100,
|
||||
use_wandb=True,
|
||||
wandb_name="web-research",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.5",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._items: list[dict] = []
|
||||
self._eval_items: list[dict] = []
|
||||
self._index: int = 0
|
||||
|
||||
# Metrics tracking for wandb
|
||||
self._reward_buffer: list[float] = []
|
||||
self._correctness_buffer: list[float] = []
|
||||
self._tool_usage_buffer: list[float] = []
|
||||
self._efficiency_buffer: list[float] = []
|
||||
self._diversity_buffer: list[float] = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Setup — load dataset
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def setup(self) -> None:
|
||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
||||
if HF_AVAILABLE:
|
||||
try:
|
||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
||||
ds = load_dataset(self.config.dataset_name, split="test")
|
||||
self._items = [
|
||||
{
|
||||
"question": row["Prompt"],
|
||||
"answer": row["Answer"],
|
||||
"difficulty": row.get("reasoning_types", "unknown"),
|
||||
"hops": 2,
|
||||
}
|
||||
for row in ds
|
||||
]
|
||||
# Hold out for eval
|
||||
eval_size = max(
|
||||
self.config.eval_size,
|
||||
int(len(self._items) * self.config.eval_split_ratio),
|
||||
)
|
||||
random.shuffle(self._items)
|
||||
self._eval_items = self._items[:eval_size]
|
||||
self._items = self._items[eval_size:]
|
||||
logger.info(
|
||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
||||
f"from FRAMES benchmark."
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
||||
|
||||
# Fallback
|
||||
random.shuffle(SAMPLE_QUESTIONS)
|
||||
split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
|
||||
self._items = SAMPLE_QUESTIONS[:split]
|
||||
self._eval_items = SAMPLE_QUESTIONS[split:]
|
||||
logger.info(
|
||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
||||
f"{len(self._eval_items)} eval items."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. get_next_item — return the next question
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_next_item(self) -> dict:
|
||||
"""Return the next item, cycling through the dataset."""
|
||||
if not self._items:
|
||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
||||
item = self._items[self._index % len(self._items)]
|
||||
self._index += 1
|
||||
return item
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. format_prompt — build the user-facing prompt
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def format_prompt(self, item: dict) -> str:
|
||||
"""Format the research question as a task prompt."""
|
||||
return (
|
||||
f"Research the following question thoroughly using web search. "
|
||||
f"You MUST search the web to find current, accurate information — "
|
||||
f"do not rely solely on your training data.\n\n"
|
||||
f"Question: {item['question']}\n\n"
|
||||
f"Requirements:\n"
|
||||
f"- Use web_search and/or web_extract tools to find information\n"
|
||||
f"- Search at least 2 different sources\n"
|
||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
||||
f"- Cite the sources you used"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. compute_reward — multi-signal scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_reward(
|
||||
self,
|
||||
item: dict,
|
||||
result: AgentResult,
|
||||
ctx: ToolContext,
|
||||
) -> float:
|
||||
"""
|
||||
Multi-signal reward function:
|
||||
|
||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
||||
"""
|
||||
final_response: str = result.final_response or ""
|
||||
tools_used: list[str] = [
|
||||
tc.tool_name for tc in (result.tool_calls or [])
|
||||
] if hasattr(result, "tool_calls") and result.tool_calls else []
|
||||
tool_call_count: int = result.turns_used or len(tools_used)
|
||||
|
||||
cfg = self.config
|
||||
|
||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=final_response,
|
||||
)
|
||||
|
||||
# ---- Signal 2: Web tool usage --------------------------------
|
||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
||||
|
||||
# ---- Signal 3: Efficiency ------------------------------------
|
||||
if tool_call_count <= cfg.efficient_max_calls:
|
||||
efficiency = 1.0
|
||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
||||
else:
|
||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
||||
|
||||
# ---- Bonus: Source diversity ---------------------------------
|
||||
domains = self._extract_domains(final_response)
|
||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
||||
|
||||
# ---- Combine ------------------------------------------------
|
||||
reward = (
|
||||
cfg.correctness_weight * correctness
|
||||
+ cfg.tool_usage_weight * tool_used
|
||||
+ cfg.efficiency_weight * efficiency
|
||||
+ diversity
|
||||
)
|
||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
||||
|
||||
# Track for wandb
|
||||
self._reward_buffer.append(reward)
|
||||
self._correctness_buffer.append(correctness)
|
||||
self._tool_usage_buffer.append(tool_used)
|
||||
self._efficiency_buffer.append(efficiency)
|
||||
self._diversity_buffer.append(diversity)
|
||||
|
||||
logger.debug(
|
||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. evaluate — run on held-out eval split
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""Run evaluation on the held-out split using the agent loop."""
|
||||
import time
|
||||
|
||||
items = self._eval_items
|
||||
if not items:
|
||||
logger.warning("No eval items available.")
|
||||
return
|
||||
|
||||
eval_size = min(self.config.eval_size, len(items))
|
||||
eval_items = items[:eval_size]
|
||||
|
||||
logger.info(f"Running eval on {len(eval_items)} questions...")
|
||||
start_time = time.time()
|
||||
samples = []
|
||||
|
||||
for item in eval_items:
|
||||
try:
|
||||
# Use the base env's agent loop for eval (same as training)
|
||||
prompt = self.format_prompt(item)
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
# Score the response
|
||||
correctness = await self._llm_judge(
|
||||
question=item["question"],
|
||||
expected=item["answer"],
|
||||
model_answer=response_content,
|
||||
)
|
||||
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": response_content,
|
||||
"expected": item["answer"],
|
||||
"correctness": correctness,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Eval error on item: {e}")
|
||||
samples.append({
|
||||
"prompt": item["question"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": item["answer"],
|
||||
"correctness": 0.0,
|
||||
})
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Compute metrics
|
||||
correctness_scores = [s["correctness"] for s in samples]
|
||||
eval_metrics = {
|
||||
"eval/mean_correctness": (
|
||||
sum(correctness_scores) / len(correctness_scores)
|
||||
if correctness_scores else 0.0
|
||||
),
|
||||
"eval/n_items": len(samples),
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. wandb_log — custom metrics
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
||||
"""Log reward breakdown metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self._reward_buffer:
|
||||
n = len(self._reward_buffer)
|
||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
||||
wandb_metrics["train/total_rollouts"] = n
|
||||
|
||||
# Accuracy buckets
|
||||
wandb_metrics["train/correct_rate"] = (
|
||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
||||
)
|
||||
wandb_metrics["train/tool_usage_rate"] = (
|
||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
||||
)
|
||||
|
||||
# Clear buffers
|
||||
self._reward_buffer.clear()
|
||||
self._correctness_buffer.clear()
|
||||
self._tool_usage_buffer.clear()
|
||||
self._efficiency_buffer.clear()
|
||||
self._diversity_buffer.clear()
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _llm_judge(
|
||||
self,
|
||||
question: str,
|
||||
expected: str,
|
||||
model_answer: str,
|
||||
) -> float:
|
||||
"""
|
||||
Use the server's LLM to judge answer correctness.
|
||||
Falls back to keyword heuristic if LLM call fails.
|
||||
"""
|
||||
if not model_answer or not model_answer.strip():
|
||||
return 0.0
|
||||
|
||||
judge_prompt = (
|
||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
||||
f"Question: {question}\n\n"
|
||||
f"Reference answer: {expected}\n\n"
|
||||
f"Model answer: {model_answer}\n\n"
|
||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
||||
" 1.0 = fully correct and complete\n"
|
||||
" 0.7 = mostly correct with minor gaps\n"
|
||||
" 0.4 = partially correct\n"
|
||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
||||
" 0.0 = completely wrong or no answer\n\n"
|
||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
n=1,
|
||||
max_tokens=150,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
text = response.choices[0].message.content if response.choices else ""
|
||||
parsed = self._parse_judge_json(text)
|
||||
if parsed is not None:
|
||||
return float(parsed)
|
||||
except Exception as e:
|
||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
||||
|
||||
return self._heuristic_score(expected, model_answer)
|
||||
|
||||
@staticmethod
|
||||
def _parse_judge_json(text: str) -> Optional[float]:
|
||||
"""Extract the score float from LLM judge JSON response."""
|
||||
try:
|
||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
||||
data = json.loads(clean)
|
||||
score = float(data.get("score", -1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
except Exception:
|
||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
||||
"""Lightweight keyword overlap score as fallback."""
|
||||
stopwords = {
|
||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
||||
}
|
||||
|
||||
def tokenize(text: str) -> set:
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
||||
|
||||
expected_tokens = tokenize(expected)
|
||||
answer_tokens = tokenize(model_answer)
|
||||
|
||||
if not expected_tokens:
|
||||
return 0.5
|
||||
|
||||
overlap = len(expected_tokens & answer_tokens)
|
||||
union = len(expected_tokens | answer_tokens)
|
||||
|
||||
jaccard = overlap / union if union > 0 else 0.0
|
||||
recall = overlap / len(expected_tokens)
|
||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
||||
|
||||
@staticmethod
|
||||
def _extract_domains(text: str) -> set:
|
||||
"""Extract unique domains from URLs cited in the response."""
|
||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
||||
domains = set()
|
||||
for url in urls:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.lower().lstrip("www.")
|
||||
if domain:
|
||||
domains.add(domain)
|
||||
except Exception:
|
||||
pass
|
||||
return domains
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
WebResearchEnv.cli()
|
||||
@@ -319,7 +319,7 @@ class SendResult:
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
# Handler may return str (sent by base) or dict(content=..., already_sent=True).
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||
|
||||
|
||||
@@ -691,11 +691,20 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
handler_result = await self._message_handler(event)
|
||||
|
||||
# Normalise: handler may return str or dict(content, already_sent)
|
||||
already_sent = False
|
||||
if isinstance(handler_result, dict):
|
||||
response = handler_result.get("content") or ""
|
||||
already_sent = handler_result.get("already_sent", False)
|
||||
else:
|
||||
response = handler_result
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if not already_sent:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
@@ -706,7 +715,7 @@ class BasePlatformAdapter(ABC):
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
if text_content and not already_sent:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
|
||||
@@ -10,6 +10,7 @@ Uses slack-bolt (Python) with Socket Mode for:
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
@@ -33,6 +34,8 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
@@ -96,6 +99,13 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
async def handle_message_event(event, say):
|
||||
await self._handle_slack_message(event)
|
||||
|
||||
# Acknowledge app_mention events to prevent Bolt 404 errors.
|
||||
# The "message" handler above already processes @mentions in
|
||||
# channels, so this is intentionally a no-op to avoid duplicates.
|
||||
@self._app.event("app_mention")
|
||||
async def handle_app_mention(event, say):
|
||||
pass
|
||||
|
||||
# Register slash command handler
|
||||
@self._app.command("/hermes")
|
||||
async def handle_hermes_command(ack, command):
|
||||
@@ -266,6 +276,65 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a video file to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=video_path,
|
||||
filename=os.path.basename(video_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send video: {e}")
|
||||
return await super().send_video(chat_id, video_path, caption, reply_to)
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment to Slack."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
|
||||
try:
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=file_path,
|
||||
filename=display_name,
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send document: {e}")
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Slack channel."""
|
||||
if not self._app:
|
||||
@@ -347,6 +416,58 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.VOICE
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache audio: {e}", flush=True)
|
||||
elif url:
|
||||
# Try to handle as a document attachment
|
||||
try:
|
||||
original_filename = f.get("name", "")
|
||||
ext = ""
|
||||
if original_filename:
|
||||
_, ext = os.path.splitext(original_filename)
|
||||
ext = ext.lower()
|
||||
|
||||
# Fallback: reverse-lookup from MIME type
|
||||
if not ext and mimetype:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(mimetype, "")
|
||||
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
continue # Skip unsupported file types silently
|
||||
|
||||
# Check file size (Slack limit: 20 MB for bots)
|
||||
file_size = f.get("size", 0)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not file_size or file_size > MAX_DOC_BYTES:
|
||||
print(f"[Slack] Document too large or unknown size: {file_size}", flush=True)
|
||||
continue
|
||||
|
||||
# Download and cache
|
||||
raw_bytes = await self._download_slack_file_bytes(url)
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, original_filename or f"document{ext}"
|
||||
)
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
msg_type = MessageType.DOCUMENT
|
||||
print(f"[Slack] Cached user document: {cached_path}", flush=True)
|
||||
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if text:
|
||||
text = f"{injection}\n\n{text}"
|
||||
else:
|
||||
text = injection
|
||||
except UnicodeDecodeError:
|
||||
pass # Binary content, skip injection
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Slack] Failed to cache document: {e}", flush=True)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
@@ -427,3 +548,16 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -75,11 +75,16 @@ if _config_path.exists():
|
||||
"container_memory": "TERMINAL_CONTAINER_MEMORY",
|
||||
"container_disk": "TERMINAL_CONTAINER_DISK",
|
||||
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
|
||||
"docker_volumes": "TERMINAL_DOCKER_VOLUMES",
|
||||
"sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
}
|
||||
for _cfg_key, _env_var in _terminal_env_map.items():
|
||||
if _cfg_key in _terminal_cfg:
|
||||
os.environ[_env_var] = str(_terminal_cfg[_cfg_key])
|
||||
_val = _terminal_cfg[_cfg_key]
|
||||
if isinstance(_val, list):
|
||||
os.environ[_env_var] = json.dumps(_val)
|
||||
else:
|
||||
os.environ[_env_var] = str(_val)
|
||||
_compression_cfg = _cfg.get("compression", {})
|
||||
if _compression_cfg and isinstance(_compression_cfg, dict):
|
||||
_compression_env_map = {
|
||||
@@ -1291,7 +1296,9 @@ class GatewayRunner:
|
||||
|
||||
# Update session
|
||||
self.session_store.update_session(session_entry.session_key)
|
||||
|
||||
|
||||
if agent_result.get("already_sent"):
|
||||
return {"content": response, "already_sent": True}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -2450,6 +2457,83 @@ class GatewayRunner:
|
||||
# Queue for progress messages (thread-safe)
|
||||
progress_queue = queue.Queue() if tool_progress_enabled else None
|
||||
last_tool = [None] # Mutable container for tracking in closure
|
||||
|
||||
# Streaming token queue — same pattern as progress_queue but for
|
||||
# assistant text deltas. An async drain task sends/edits a single
|
||||
# platform message with the accumulated text.
|
||||
stream_queue = queue.Queue()
|
||||
stream_sent = [False] # set True once any delta was delivered
|
||||
|
||||
def _stream_delta(text: str):
|
||||
stream_queue.put(text)
|
||||
|
||||
async def send_stream_messages():
|
||||
"""Drain stream_queue, deliver via send/edit_message."""
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if not _adapter:
|
||||
return
|
||||
|
||||
accumulated = []
|
||||
msg_id = None
|
||||
can_edit = True
|
||||
last_edit_ts = 0.0
|
||||
EDIT_INTERVAL = 0.6 # seconds between edits (rate-limit safe)
|
||||
|
||||
while True:
|
||||
try:
|
||||
delta = stream_queue.get_nowait()
|
||||
accumulated.append(delta)
|
||||
stream_sent[0] = True
|
||||
|
||||
now = asyncio.get_event_loop().time()
|
||||
if now - last_edit_ts < EDIT_INTERVAL:
|
||||
# Coalesce — will flush on next poll cycle
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
full_text = "".join(accumulated)
|
||||
if msg_id is None:
|
||||
res = await _adapter.send(
|
||||
chat_id=source.chat_id, content=full_text)
|
||||
if res.success and res.message_id:
|
||||
msg_id = res.message_id
|
||||
elif can_edit:
|
||||
res = await _adapter.edit_message(
|
||||
chat_id=source.chat_id,
|
||||
message_id=msg_id,
|
||||
content=full_text,
|
||||
)
|
||||
if not res.success:
|
||||
can_edit = False
|
||||
last_edit_ts = now
|
||||
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.15)
|
||||
except asyncio.CancelledError:
|
||||
# Final flush
|
||||
while not stream_queue.empty():
|
||||
try:
|
||||
accumulated.append(stream_queue.get_nowait())
|
||||
except Exception:
|
||||
break
|
||||
if accumulated:
|
||||
full_text = "".join(accumulated)
|
||||
if msg_id is None:
|
||||
await _adapter.send(
|
||||
chat_id=source.chat_id, content=full_text)
|
||||
elif can_edit:
|
||||
try:
|
||||
await _adapter.edit_message(
|
||||
chat_id=source.chat_id,
|
||||
message_id=msg_id,
|
||||
content=full_text,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error("Stream message error: %s", e)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
def progress_callback(tool_name: str, preview: str = None, args: dict = None):
|
||||
"""Callback invoked by agent when a tool is called."""
|
||||
@@ -2693,6 +2777,7 @@ class GatewayRunner:
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
stream_delta_callback=_stream_delta,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
session_db=self._session_db,
|
||||
@@ -2815,12 +2900,16 @@ class GatewayRunner:
|
||||
"api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0,
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": len(agent_history),
|
||||
"already_sent": stream_sent[0],
|
||||
}
|
||||
|
||||
# Start progress message sender if enabled
|
||||
progress_task = None
|
||||
if tool_progress_enabled:
|
||||
progress_task = asyncio.create_task(send_progress_messages())
|
||||
|
||||
# Start stream message sender
|
||||
stream_task = asyncio.create_task(send_stream_messages())
|
||||
|
||||
# Track this agent as running for this session (for interrupt support)
|
||||
# We do this in a callback after the agent is created
|
||||
@@ -2896,9 +2985,10 @@ class GatewayRunner:
|
||||
session_key=session_key
|
||||
)
|
||||
finally:
|
||||
# Stop progress sender and interrupt monitor
|
||||
# Stop progress sender, stream sender, and interrupt monitor
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
stream_task.cancel()
|
||||
interrupt_monitor.cancel()
|
||||
|
||||
# Clean up tracking
|
||||
@@ -2907,7 +2997,7 @@ class GatewayRunner:
|
||||
del self._running_agents[session_key]
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||
for task in [progress_task, stream_task, interrupt_monitor, tracking_task]:
|
||||
if task:
|
||||
try:
|
||||
await task
|
||||
|
||||
@@ -47,7 +47,7 @@ def _fetch_models_from_api(access_token: str) -> List[str]:
|
||||
if item.get("supported_in_api") is False:
|
||||
continue
|
||||
visibility = item.get("visibility", "")
|
||||
if isinstance(visibility, str) and visibility.strip().lower() == "hide":
|
||||
if isinstance(visibility, str) and visibility.strip().lower() == "hidden":
|
||||
continue
|
||||
priority = item.get("priority")
|
||||
rank = int(priority) if isinstance(priority, (int, float)) else 10_000
|
||||
|
||||
@@ -77,6 +77,10 @@ DEFAULT_CONFIG = {
|
||||
"container_memory": 5120, # MB (default 5GB)
|
||||
"container_disk": 51200, # MB (default 50GB)
|
||||
"container_persistent": True, # Persist filesystem across sessions
|
||||
# Docker volume mounts — share host directories with the container.
|
||||
# Each entry is "host_path:container_path" (standard Docker -v syntax).
|
||||
# Example: ["/home/user/projects:/workspace/projects", "/data:/data"]
|
||||
"docker_volumes": [],
|
||||
},
|
||||
|
||||
"browser": {
|
||||
@@ -401,14 +405,18 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "messaging",
|
||||
},
|
||||
"SLACK_BOT_TOKEN": {
|
||||
"description": "Slack bot integration",
|
||||
"description": "Slack bot token (xoxb-). Get from OAuth & Permissions after installing your app. "
|
||||
"Required scopes: chat:write, app_mentions:read, channels:history, groups:history, "
|
||||
"im:history, im:read, im:write, users:read, files:write",
|
||||
"prompt": "Slack Bot Token (xoxb-...)",
|
||||
"url": "https://api.slack.com/apps",
|
||||
"password": True,
|
||||
"category": "messaging",
|
||||
},
|
||||
"SLACK_APP_TOKEN": {
|
||||
"description": "Slack Socket Mode connection",
|
||||
"description": "Slack app-level token (xapp-) for Socket Mode. Get from Basic Information → "
|
||||
"App-Level Tokens. Also ensure Event Subscriptions include: message.im, "
|
||||
"message.channels, message.groups, app_mention",
|
||||
"prompt": "Slack App Token (xapp-...)",
|
||||
"url": "https://api.slack.com/apps",
|
||||
"password": True,
|
||||
|
||||
@@ -482,14 +482,19 @@ _PLATFORMS = [
|
||||
"token_var": "SLACK_BOT_TOKEN",
|
||||
"setup_instructions": [
|
||||
"1. Go to https://api.slack.com/apps → Create New App → From Scratch",
|
||||
"2. Enable Socket Mode: App Settings → Socket Mode → Enable",
|
||||
"3. Get Bot Token: OAuth & Permissions → Install to Workspace → copy xoxb-... token",
|
||||
"4. Get App Token: Basic Information → App-Level Tokens → Generate",
|
||||
" Name it anything, add scope: connections:write → copy xapp-... token",
|
||||
"5. Add bot scopes: OAuth & Permissions → Scopes → chat:write, im:history,",
|
||||
" im:read, im:write, channels:history, channels:read",
|
||||
"6. Reinstall the app to your workspace after adding scopes",
|
||||
"2. Enable Socket Mode: Settings → Socket Mode → Enable",
|
||||
" Create an App-Level Token with scope: connections:write → copy xapp-... token",
|
||||
"3. Add Bot Token Scopes: Features → OAuth & Permissions → Scopes",
|
||||
" Required: chat:write, app_mentions:read, channels:history, channels:read,",
|
||||
" groups:history, im:history, im:read, im:write, users:read, files:write",
|
||||
"4. Subscribe to Events: Features → Event Subscriptions → Enable",
|
||||
" Required events: message.im, message.channels, app_mention",
|
||||
" Optional: message.groups (for private channels)",
|
||||
" ⚠ Without message.channels the bot will ONLY work in DMs!",
|
||||
"5. Install to Workspace: Settings → Install App → copy xoxb-... token",
|
||||
"6. Reinstall the app after any scope or event changes",
|
||||
"7. Find your user ID: click your profile → three dots → Copy member ID",
|
||||
"8. Invite the bot to channels: /invite @YourBot",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True,
|
||||
|
||||
@@ -1572,10 +1572,22 @@ def setup_gateway(config: dict):
|
||||
|
||||
if not existing_slack and prompt_yes_no("Set up Slack bot?", False):
|
||||
print_info("Steps to create a Slack app:")
|
||||
print_info(" 1. Go to https://api.slack.com/apps → Create New App")
|
||||
print_info(" 2. Enable Socket Mode: App Settings → Socket Mode → Enable")
|
||||
print_info(" 3. Bot Token: OAuth & Permissions → Install to Workspace")
|
||||
print_info(" 4. App Token: Basic Information → App-Level Tokens → Generate")
|
||||
print_info(" 1. Go to https://api.slack.com/apps → Create New App (from scratch)")
|
||||
print_info(" 2. Enable Socket Mode: Settings → Socket Mode → Enable")
|
||||
print_info(" • Create an App-Level Token with 'connections:write' scope")
|
||||
print_info(" 3. Add Bot Token Scopes: Features → OAuth & Permissions")
|
||||
print_info(" Required scopes: chat:write, app_mentions:read,")
|
||||
print_info(" channels:history, channels:read, groups:history,")
|
||||
print_info(" im:history, im:read, im:write, users:read, files:write")
|
||||
print_info(" 4. Subscribe to Events: Features → Event Subscriptions → Enable")
|
||||
print_info(" Required events: message.im, message.channels,")
|
||||
print_info(" message.groups, app_mention")
|
||||
print_warning(" ⚠ Without message.channels/message.groups events,")
|
||||
print_warning(" the bot will ONLY work in DMs, not channels!")
|
||||
print_info(" 5. Install to Workspace: Settings → Install App")
|
||||
print_info(" 6. After installing, invite the bot to channels: /invite @YourBot")
|
||||
print()
|
||||
print_info(" Full guide: https://hermes-agent.ai/docs/user-guide/messaging/slack")
|
||||
print()
|
||||
bot_token = prompt("Slack Bot Token (xoxb-...)", password=True)
|
||||
if bot_token:
|
||||
@@ -1587,7 +1599,7 @@ def setup_gateway(config: dict):
|
||||
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can use your bot")
|
||||
print_info(" Find Slack user IDs in your profile or via the Slack API")
|
||||
print_info(" To find a Member ID: click a user's name → View full profile → ⋮ → Copy member ID")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
daytona = ["daytona>=0.148.0"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cron = ["croniter"]
|
||||
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
|
||||
158
run_agent.py
158
run_agent.py
@@ -174,6 +174,7 @@ class AIAgent:
|
||||
tool_progress_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
@@ -258,6 +259,7 @@ class AIAgent:
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
@@ -2158,6 +2160,137 @@ class AIAgent:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
def _interruptible_streaming_api_call(self, api_kwargs: dict, on_first_delta=None):
|
||||
"""Streaming variant of _interruptible_api_call for chat_completions.
|
||||
|
||||
Fires self.stream_delta_callback(text) as content tokens arrive and
|
||||
accumulates the full response into a SimpleNamespace matching the shape
|
||||
downstream code expects. Falls back to the non-streaming path when the
|
||||
provider rejects the stream request.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
first_delta_fired = [False]
|
||||
|
||||
def _stream():
|
||||
try:
|
||||
stream_kwargs = {**api_kwargs, "stream": True,
|
||||
"stream_options": {"include_usage": True}}
|
||||
stream = self.client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts = []
|
||||
tool_calls_acc = {}
|
||||
finish_reason = "stop"
|
||||
usage = None
|
||||
reasoning_content = None
|
||||
model = None
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = chunk.usage
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
if model is None and hasattr(chunk, "model"):
|
||||
model = chunk.model
|
||||
|
||||
delta = choice.delta
|
||||
if delta is None:
|
||||
continue
|
||||
|
||||
if delta.content:
|
||||
content_parts.append(delta.content)
|
||||
if not first_delta_fired[0]:
|
||||
first_delta_fired[0] = True
|
||||
if on_first_delta:
|
||||
on_first_delta()
|
||||
if self.stream_delta_callback:
|
||||
try:
|
||||
self.stream_delta_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": tc_delta.type or "function",
|
||||
"function": {
|
||||
"name": getattr(tc_delta.function, "name", None) or "",
|
||||
"arguments": getattr(tc_delta.function, "arguments", None) or "",
|
||||
},
|
||||
}
|
||||
else:
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
fn = tc_delta.function
|
||||
if fn:
|
||||
if fn.name:
|
||||
entry["function"]["name"] = fn.name
|
||||
if fn.arguments:
|
||||
entry["function"]["arguments"] += fn.arguments
|
||||
|
||||
rc = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||
if rc:
|
||||
reasoning_content = (reasoning_content or "") + rc
|
||||
|
||||
tool_calls_list = None
|
||||
if tool_calls_acc:
|
||||
tool_calls_list = [
|
||||
SimpleNamespace(
|
||||
id=tc["id"], call_id=tc["id"], type=tc["type"],
|
||||
function=SimpleNamespace(name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"]),
|
||||
)
|
||||
for idx, tc in sorted(tool_calls_acc.items())
|
||||
]
|
||||
|
||||
message = SimpleNamespace(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls_list,
|
||||
reasoning=reasoning_content,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning_details=None,
|
||||
)
|
||||
result["response"] = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=message, finish_reason=finish_reason)],
|
||||
usage=usage,
|
||||
model=model,
|
||||
)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
|
||||
t = threading.Thread(target=_stream, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
|
||||
if result["error"] is not None:
|
||||
err = result["error"]
|
||||
err_str = str(err).lower()
|
||||
if any(kw in err_str for kw in ("stream", "not support", "unsupported")):
|
||||
logger.debug("Streaming failed (%s), falling back to non-streaming.", err)
|
||||
return self._interruptible_api_call(api_kwargs)
|
||||
raise err
|
||||
return result["response"]
|
||||
|
||||
# ── Provider fallback ──────────────────────────────────────────────────
|
||||
|
||||
# API-key providers: provider → (base_url, [env_var_names])
|
||||
@@ -3353,12 +3486,27 @@ class AIAgent:
|
||||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
if self.stream_delta_callback and self.api_mode != "codex_responses":
|
||||
def _stop_spinner():
|
||||
nonlocal thinking_spinner
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
|
||||
response = self._interruptible_streaming_api_call(
|
||||
api_kwargs, on_first_delta=_stop_spinner)
|
||||
|
||||
# Separate streamed content from tool status lines
|
||||
msg = getattr(response, "choices", [None])[0]
|
||||
if msg and getattr(msg, "message", None):
|
||||
m = msg.message
|
||||
if m.content and m.tool_calls:
|
||||
print(flush=True)
|
||||
else:
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
# Stop thinking spinner silently -- the response box or tool
|
||||
# execution messages that follow are more informative.
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
@@ -4055,8 +4203,8 @@ class AIAgent:
|
||||
turn_content = assistant_message.content or ""
|
||||
if turn_content and self._has_content_after_think_block(turn_content):
|
||||
self._last_content_with_tools = turn_content
|
||||
# Show intermediate commentary so the user can follow along
|
||||
if self.quiet_mode:
|
||||
# Show intermediate commentary — skip when streaming (already in buffer)
|
||||
if self.quiet_mode and not self.stream_delta_callback:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
print(f" ┊ 💬 {clean}")
|
||||
|
||||
@@ -176,14 +176,18 @@ class TestVisionClientFallback:
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.3-codex"
|
||||
|
||||
def test_vision_auto_skips_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is skipped in vision auto mode."""
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
|
||||
Many local models (Qwen-VL, LLaVA, etc.) support vision.
|
||||
When no OpenRouter/Nous/Codex is available, try the custom endpoint.
|
||||
"""
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None):
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is None
|
||||
assert model is None
|
||||
assert client is not None # Custom endpoint picked up as fallback
|
||||
|
||||
def test_vision_uses_openrouter_when_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
532
tests/gateway/test_slack.py
Normal file
532
tests/gateway/test_slack.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""
|
||||
Tests for Slack platform adapter.
|
||||
|
||||
Covers: app_mention handler, send_document, send_video,
|
||||
incoming document handling, message routing.
|
||||
|
||||
Note: slack-bolt may not be installed in the test environment.
|
||||
We mock the slack modules at import time to avoid collection errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the slack-bolt package if it's not installed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack modules so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler", slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
# Patch SLACK_AVAILABLE before importing the adapter
|
||||
import gateway.platforms.slack as _slack_mod
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
a = SlackAdapter(config)
|
||||
# Mock the Slack app client
|
||||
a._app = MagicMock()
|
||||
a._app.client = AsyncMock()
|
||||
a._bot_user_id = "U_BOT"
|
||||
a._running = True
|
||||
# Capture events instead of processing them
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests don't touch ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestAppMentionHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAppMentionHandler:
|
||||
"""Verify that the app_mention event handler is registered."""
|
||||
|
||||
def test_app_mention_registered_on_connect(self):
|
||||
"""connect() should register both 'message' and 'app_mention' handlers."""
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
adapter = SlackAdapter(config)
|
||||
|
||||
# Track which events get registered
|
||||
registered_events = []
|
||||
registered_commands = []
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
def mock_event(event_type):
|
||||
def decorator(fn):
|
||||
registered_events.append(event_type)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
def mock_command(cmd):
|
||||
def decorator(fn):
|
||||
registered_commands.append(cmd)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
mock_app.event = mock_event
|
||||
mock_app.command = mock_command
|
||||
mock_app.client = AsyncMock()
|
||||
mock_app.client.auth_test = AsyncMock(return_value={
|
||||
"user_id": "U_BOT",
|
||||
"user": "testbot",
|
||||
})
|
||||
|
||||
with patch.object(_slack_mod, "AsyncApp", return_value=mock_app), \
|
||||
patch.object(_slack_mod, "AsyncSocketModeHandler", return_value=MagicMock()), \
|
||||
patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}), \
|
||||
patch("asyncio.create_task"):
|
||||
asyncio.get_event_loop().run_until_complete(adapter.connect())
|
||||
|
||||
assert "message" in registered_events
|
||||
assert "app_mention" in registered_events
|
||||
assert "/hermes" in registered_commands
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendDocument
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendDocument:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_success(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
caption="Here's the report",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_called_once()
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["channel"] == "C123"
|
||||
assert call_kwargs["file"] == str(test_file)
|
||||
assert call_kwargs["filename"] == "report.pdf"
|
||||
assert call_kwargs["initial_comment"] == "Here's the report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_custom_name(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "data.csv"
|
||||
test_file.write_bytes(b"a,b,c\n1,2,3")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
file_name="quarterly-report.csv",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "quarterly-report.csv"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_missing_file(self, adapter):
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/nonexistent/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path="/some/file.pdf",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_api_error_falls_back(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "doc.pdf"
|
||||
test_file.write_bytes(b"content")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
)
|
||||
|
||||
# Base class send() is also mocked, so check it was attempted
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_with_thread(self, adapter, tmp_path):
|
||||
test_file = tmp_path / "notes.txt"
|
||||
test_file.write_bytes(b"some notes")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_document(
|
||||
chat_id="C123",
|
||||
file_path=str(test_file),
|
||||
reply_to="1234567890.123456",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["thread_ts"] == "1234567890.123456"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSendVideo
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendVideo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_success(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video data")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
|
||||
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
caption="Check this out",
|
||||
)
|
||||
|
||||
assert result.success
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
|
||||
assert call_kwargs["filename"] == "clip.mp4"
|
||||
assert call_kwargs["initial_comment"] == "Check this out"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_missing_file(self, adapter):
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/nonexistent/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "not found" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path="/some/video.mp4",
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_api_error_falls_back(self, adapter, tmp_path):
|
||||
video = tmp_path / "clip.mp4"
|
||||
video.write_bytes(b"fake video")
|
||||
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(
|
||||
side_effect=RuntimeError("Slack API error")
|
||||
)
|
||||
|
||||
# Should fall back to base class (text message)
|
||||
result = await adapter.send_video(
|
||||
chat_id="C123",
|
||||
video_path=str(video),
|
||||
)
|
||||
|
||||
adapter._app.client.chat_postMessage.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestIncomingDocumentHandling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
def _make_event(self, files=None, text="hello", channel_type="im"):
|
||||
"""Build a mock Slack message event with file attachments."""
|
||||
return {
|
||||
"text": text,
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": channel_type,
|
||||
"ts": "1234567890.000001",
|
||||
"files": files or [],
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, and set as DOCUMENT type."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = pdf_bytes
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": len(pdf_bytes),
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.DOCUMENT
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert os.path.exists(msg_event.media_urls[0])
|
||||
assert msg_event.media_types == ["application/pdf"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_document_injects_content(self, adapter):
|
||||
"""A .txt file under 100KB should have its content injected into event text."""
|
||||
content = b"Hello from a text file"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(
|
||||
text="summarize this",
|
||||
files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "notes.txt",
|
||||
"url_private_download": "https://files.slack.com/notes.txt",
|
||||
"size": len(content),
|
||||
}],
|
||||
)
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "Hello from a text file" in msg_event.text
|
||||
assert "[Content of notes.txt]" in msg_event.text
|
||||
assert "summarize this" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_document_injects_content(self, adapter):
|
||||
"""A .md file under 100KB should have its content injected."""
|
||||
content = b"# Title\nSome markdown content"
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/markdown",
|
||||
"name": "readme.md",
|
||||
"url_private_download": "https://files.slack.com/readme.md",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "# Title" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_not_injected(self, adapter):
|
||||
"""A .txt file over 100KB should be cached but NOT injected."""
|
||||
content = b"x" * (200 * 1024)
|
||||
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = content
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "text/plain",
|
||||
"name": "big.txt",
|
||||
"url_private_download": "https://files.slack.com/big.txt",
|
||||
"size": len(content),
|
||||
}], text="")
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 1
|
||||
assert "[Content of" not in (msg_event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_file_type_skipped(self, adapter):
|
||||
"""A .zip file should be silently skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/zip",
|
||||
"name": "archive.zip",
|
||||
"url_private_download": "https://files.slack.com/archive.zip",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.TEXT
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped."""
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "huge.pdf",
|
||||
"url_private_download": "https://files.slack.com/huge.pdf",
|
||||
"size": 25 * 1024 * 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert len(msg_event.media_urls) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_document_download_error_handled(self, adapter):
|
||||
"""If document download fails, handler should not crash."""
|
||||
with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
|
||||
dl.side_effect = RuntimeError("download failed")
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "application/pdf",
|
||||
"name": "report.pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
# Handler should still be called (the exception is caught)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_still_handled(self, adapter):
|
||||
"""Image attachments should still go through the image path, not document."""
|
||||
with patch.object(adapter, "_download_slack_file", new_callable=AsyncMock) as dl:
|
||||
dl.return_value = "/tmp/cached_image.jpg"
|
||||
event = self._make_event(files=[{
|
||||
"mimetype": "image/jpeg",
|
||||
"name": "photo.jpg",
|
||||
"url_private_download": "https://files.slack.com/photo.jpg",
|
||||
"size": 1024,
|
||||
}])
|
||||
await adapter._handle_slack_message(event)
|
||||
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.message_type == MessageType.PHOTO
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageRouting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMessageRouting:
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_processed_without_mention(self, adapter):
|
||||
"""DM messages should be processed without requiring a bot mention."""
|
||||
event = {
|
||||
"text": "hello",
|
||||
"user": "U_USER",
|
||||
"channel": "D123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_message_requires_mention(self, adapter):
|
||||
"""Channel messages without a bot mention should be ignored."""
|
||||
event = {
|
||||
"text": "just talking",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_mention_strips_bot_id(self, adapter):
|
||||
"""When mentioned in a channel, the bot mention should be stripped."""
|
||||
event = {
|
||||
"text": "<@U_BOT> what's the weather?",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "channel",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "what's the weather?"
|
||||
assert "<@U_BOT>" not in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_messages_ignored(self, adapter):
|
||||
"""Messages from bots should be ignored."""
|
||||
event = {
|
||||
"text": "bot response",
|
||||
"bot_id": "B_OTHER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_edits_ignored(self, adapter):
|
||||
"""Message edits should be ignored."""
|
||||
event = {
|
||||
"text": "edited message",
|
||||
"user": "U_USER",
|
||||
"channel": "C123",
|
||||
"channel_type": "im",
|
||||
"ts": "1234567890.000001",
|
||||
"subtype": "message_changed",
|
||||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
256
tests/test_streaming.py
Normal file
256
tests/test_streaming.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Tests for streaming token output — accumulator shape, callback order, fallback."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_defs(*names):
|
||||
return [
|
||||
{"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}}
|
||||
for n in names
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent():
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
cb = MagicMock()
|
||||
a = AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=cb,
|
||||
)
|
||||
a.client = MagicMock()
|
||||
a._stream_cb = cb
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — fake streaming chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None):
|
||||
delta = SimpleNamespace(content=content, tool_calls=tool_call_delta)
|
||||
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||
c = SimpleNamespace(choices=[choice])
|
||||
if usage is not None:
|
||||
c.usage = SimpleNamespace(**usage)
|
||||
if model:
|
||||
c.model = model
|
||||
return c
|
||||
|
||||
|
||||
def _usage_chunk(**kw):
|
||||
c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw))
|
||||
return c
|
||||
|
||||
|
||||
def _tc_delta(index, id=None, name=None, arguments=None, type=None):
|
||||
fn = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=id, type=type, function=fn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: accumulator shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingAccumulator:
|
||||
def test_text_only_response(self, agent):
|
||||
"""Streaming text-only response produces correct synthetic shape."""
|
||||
chunks = [
|
||||
_chunk(content="Hello", model="test/m"),
|
||||
_chunk(content=" world"),
|
||||
_chunk(finish_reason="stop"),
|
||||
_usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Hello world"
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.usage.prompt_tokens == 10
|
||||
assert resp.model == "test/m"
|
||||
|
||||
def test_tool_call_response(self, agent):
|
||||
"""Streaming tool-call response accumulates function name + arguments."""
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert tc is not None
|
||||
assert len(tc) == 1
|
||||
assert tc[0].id == "call_1"
|
||||
assert tc[0].function.name == "web_search"
|
||||
assert tc[0].function.arguments == '{"query": "hi"}'
|
||||
assert resp.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
def test_mixed_content_and_tool_calls(self, agent):
|
||||
"""Content + tool calls in same stream are both accumulated."""
|
||||
chunks = [
|
||||
_chunk(content="Let me check."),
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "Let me check."
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
||||
|
||||
class TestStreamingCallbacks:
|
||||
def test_deltas_fire_in_order(self, agent):
|
||||
"""stream_delta_callback receives content deltas in order."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == ["a", "b", "c"]
|
||||
|
||||
def test_on_first_delta_fires_once(self, agent):
|
||||
first = MagicMock()
|
||||
chunks = [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first)
|
||||
|
||||
first.assert_called_once()
|
||||
|
||||
def test_tool_only_does_not_fire_callback(self, agent):
|
||||
"""Tool-call-only stream does not invoke stream_delta_callback."""
|
||||
received = []
|
||||
agent.stream_delta_callback = lambda t: received.append(t)
|
||||
chunks = [
|
||||
_chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]),
|
||||
_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert received == []
|
||||
|
||||
|
||||
class TestStreamingFallback:
|
||||
def test_stream_error_falls_back(self, agent):
|
||||
"""When streaming fails with 'not support', falls back to non-streaming."""
|
||||
agent.client.chat.completions.create.side_effect = [
|
||||
Exception("streaming not supported by this provider"),
|
||||
SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(content="ok", tool_calls=None, reasoning=None, reasoning_content=None, reasoning_details=None),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
model="test/m",
|
||||
),
|
||||
]
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
assert resp.choices[0].message.content == "ok"
|
||||
assert agent.client.chat.completions.create.call_count == 2
|
||||
|
||||
def test_non_stream_error_raises(self, agent):
|
||||
"""Non-stream-related errors propagate normally."""
|
||||
agent.client.chat.completions.create.side_effect = ValueError("bad request")
|
||||
|
||||
with pytest.raises(ValueError, match="bad request"):
|
||||
agent._interruptible_streaming_api_call({"model": "test"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: base.py already_sent contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAlreadySentContract:
|
||||
def _make_adapter(self, send_side_effect=None):
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
class FakeAdapter(BasePlatformAdapter):
|
||||
async def connect(self): return True
|
||||
async def disconnect(self): pass
|
||||
async def get_chat_info(self, chat_id): return {"name": "test"}
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
if send_side_effect is not None:
|
||||
send_side_effect(content)
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
cfg = PlatformConfig(enabled=True)
|
||||
adapter = FakeAdapter(cfg, Platform.TELEGRAM)
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_sent_skips_send(self):
|
||||
"""Handler returning already_sent=True prevents base from calling send()."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return {"content": "hello", "already_sent": True}
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert sent == [], "send() should not be called when already_sent=True"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_string_response_sends_normally(self):
|
||||
"""Handler returning a plain string triggers send() as before."""
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
sent = []
|
||||
adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c))
|
||||
|
||||
async def handler(event):
|
||||
return "hello"
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = MessageEvent(
|
||||
text="hi",
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"),
|
||||
)
|
||||
await adapter._process_message_background(event, "s1")
|
||||
|
||||
assert "hello" in sent
|
||||
@@ -505,6 +505,25 @@ class TestExpandPath:
|
||||
assert result == str(Path.home())
|
||||
_assert_clean(result)
|
||||
|
||||
def test_tilde_injection_blocked(self, ops):
|
||||
"""Paths like ~; rm -rf / must NOT execute shell commands."""
|
||||
malicious = "~; echo PWNED > /tmp/_hermes_injection_test"
|
||||
result = ops._expand_path(malicious)
|
||||
# The invalid username (contains ";") should prevent shell expansion.
|
||||
# The path should be returned as-is (no expansion).
|
||||
assert result == malicious
|
||||
# Verify the injected command did NOT execute
|
||||
import os
|
||||
assert not os.path.exists("/tmp/_hermes_injection_test")
|
||||
|
||||
def test_tilde_username_with_subpath(self, ops):
|
||||
"""~root/file.txt should attempt expansion (valid username)."""
|
||||
result = ops._expand_path("~root/file.txt")
|
||||
# On most systems ~root expands to /root
|
||||
if result != "~root/file.txt":
|
||||
assert result.endswith("/file.txt")
|
||||
assert "~" not in result
|
||||
|
||||
|
||||
# ── Terminal output cleanliness ──────────────────────────────────────────
|
||||
|
||||
|
||||
351
tests/tools/test_vision_tools.py
Normal file
351
tests/tools/test_vision_tools.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""Tests for tools/vision_tools.py — URL validation, type hints, error logging."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Awaitable
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.vision_tools import (
|
||||
_validate_image_url,
|
||||
_handle_vision_analyze,
|
||||
_determine_mime_type,
|
||||
_image_to_base64_data_url,
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements,
|
||||
get_debug_session_info,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_image_url — urlparse-based validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateImageUrl:
|
||||
"""Tests for URL validation, including urlparse-based netloc check."""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
assert _validate_image_url("https://example.com/image.jpg") is True
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
|
||||
def test_valid_url_without_extension(self):
|
||||
"""CDN endpoints that redirect to images should still pass."""
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
|
||||
def test_valid_url_with_query_params(self):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
|
||||
def test_valid_url_with_port(self):
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is True
|
||||
|
||||
def test_valid_url_with_path_only(self):
|
||||
assert _validate_image_url("https://example.com/") is True
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
assert _validate_image_url("") is False
|
||||
|
||||
def test_rejects_none(self):
|
||||
assert _validate_image_url(None) is False
|
||||
|
||||
def test_rejects_non_string(self):
|
||||
assert _validate_image_url(12345) is False
|
||||
|
||||
def test_rejects_ftp_scheme(self):
|
||||
assert _validate_image_url("ftp://files.example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_file_scheme(self):
|
||||
assert _validate_image_url("file:///etc/passwd") is False
|
||||
|
||||
def test_rejects_no_scheme(self):
|
||||
assert _validate_image_url("example.com/image.jpg") is False
|
||||
|
||||
def test_rejects_javascript_scheme(self):
|
||||
assert _validate_image_url("javascript:alert(1)") is False
|
||||
|
||||
def test_rejects_http_without_netloc(self):
|
||||
"""http:// alone has no network location — urlparse catches this."""
|
||||
assert _validate_image_url("http://") is False
|
||||
|
||||
def test_rejects_https_without_netloc(self):
|
||||
assert _validate_image_url("https://") is False
|
||||
|
||||
def test_rejects_http_colon_only(self):
|
||||
assert _validate_image_url("http:") is False
|
||||
|
||||
def test_rejects_data_url(self):
|
||||
assert _validate_image_url("data:image/png;base64,iVBOR") is False
|
||||
|
||||
def test_rejects_whitespace_only(self):
|
||||
assert _validate_image_url(" ") is False
|
||||
|
||||
def test_rejects_boolean(self):
|
||||
assert _validate_image_url(True) is False
|
||||
|
||||
def test_rejects_list(self):
|
||||
assert _validate_image_url(["https://example.com"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _determine_mime_type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDetermineMimeType:
|
||||
def test_jpg(self):
|
||||
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
||||
|
||||
def test_jpeg(self):
|
||||
assert _determine_mime_type(Path("photo.jpeg")) == "image/jpeg"
|
||||
|
||||
def test_png(self):
|
||||
assert _determine_mime_type(Path("screenshot.png")) == "image/png"
|
||||
|
||||
def test_gif(self):
|
||||
assert _determine_mime_type(Path("anim.gif")) == "image/gif"
|
||||
|
||||
def test_webp(self):
|
||||
assert _determine_mime_type(Path("modern.webp")) == "image/webp"
|
||||
|
||||
def test_unknown_extension_defaults_to_jpeg(self):
|
||||
assert _determine_mime_type(Path("file.xyz")) == "image/jpeg"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _image_to_base64_data_url
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImageToBase64DataUrl:
|
||||
def test_returns_data_url(self, tmp_path):
|
||||
img = tmp_path / "test.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||
result = _image_to_base64_data_url(img)
|
||||
assert result.startswith("data:image/png;base64,")
|
||||
|
||||
def test_custom_mime_type(self, tmp_path):
|
||||
img = tmp_path / "test.bin"
|
||||
img.write_bytes(b"\x00" * 16)
|
||||
result = _image_to_base64_data_url(img, mime_type="image/webp")
|
||||
assert result.startswith("data:image/webp;base64,")
|
||||
|
||||
def test_file_not_found_raises(self, tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_image_to_base64_data_url(tmp_path / "nonexistent.png")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_vision_analyze — type signature & behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleVisionAnalyze:
|
||||
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
||||
|
||||
def test_returns_awaitable(self):
|
||||
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "What is this?"}
|
||||
)
|
||||
# It should be an Awaitable (coroutine)
|
||||
assert isinstance(result, Awaitable)
|
||||
# Clean up the coroutine to avoid RuntimeWarning
|
||||
result.close()
|
||||
|
||||
def test_prompt_contains_question(self):
|
||||
"""The full prompt should incorporate the user's question."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
|
||||
)
|
||||
# Clean up coroutine
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
full_prompt = call_args[0][1] # second positional arg
|
||||
assert "Describe the cat" in full_prompt
|
||||
assert "Fully describe and explain" in full_prompt
|
||||
|
||||
def test_uses_auxiliary_vision_model_env(self):
|
||||
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2] # third positional arg
|
||||
assert model == "custom/model-v1"
|
||||
|
||||
def test_falls_back_to_default_model(self):
|
||||
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
# Ensure AUXILIARY_VISION_MODEL is not set
|
||||
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
coro = _handle_vision_analyze(
|
||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||
)
|
||||
coro.close()
|
||||
call_args = mock_tool.call_args
|
||||
model = call_args[0][2]
|
||||
# Should be DEFAULT_VISION_MODEL or the hardcoded fallback
|
||||
assert model is not None
|
||||
assert len(model) > 0
|
||||
|
||||
def test_empty_args_graceful(self):
|
||||
"""Missing keys should default to empty strings, not raise."""
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||
result = _handle_vision_analyze({})
|
||||
assert isinstance(result, Awaitable)
|
||||
result.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error logging with exc_info — verify tracebacks are logged
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestErrorLoggingExcInfo:
|
||||
"""Verify that exc_info=True is used in error/warning log calls."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
|
||||
"""After max retries, the download error should include exc_info."""
|
||||
from tools.vision_tools import _download_image
|
||||
|
||||
with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(side_effect=ConnectionError("network down"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
dest = tmp_path / "image.jpg"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
|
||||
pytest.raises(ConnectionError):
|
||||
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
|
||||
|
||||
# Should have logged with exc_info (traceback present)
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert len(error_records) >= 1
|
||||
assert error_records[0].exc_info is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
|
||||
side_effect=Exception("download boom")), \
|
||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
|
||||
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/img.jpg", "describe this", "test/model"
|
||||
)
|
||||
result_data = json.loads(result)
|
||||
# Error response uses "success": False, not an "error" key
|
||||
assert result_data["success"] is False
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert any(r.exc_info is not None for r in error_records)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog):
|
||||
"""Temp file cleanup failure should log warning with exc_info."""
|
||||
# Create a real temp file that will be "downloaded"
|
||||
temp_dir = tmp_path / "temp_vision_images"
|
||||
temp_dir.mkdir()
|
||||
|
||||
async def fake_download(url, dest, max_retries=3):
|
||||
"""Simulate download by writing file to the expected destination."""
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||
return dest
|
||||
|
||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download), \
|
||||
patch("tools.vision_tools._image_to_base64_data_url",
|
||||
return_value="data:image/jpeg;base64,abc"), \
|
||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
|
||||
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
|
||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
|
||||
|
||||
# Mock the vision client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message.content = "A test image description"
|
||||
mock_response.choices = [mock_choice]
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Patch module-level _aux_async_client so the tool doesn't bail early
|
||||
with patch("tools.vision_tools._aux_async_client", mock_client), \
|
||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
|
||||
|
||||
# Make unlink fail to trigger cleanup warning
|
||||
original_unlink = Path.unlink
|
||||
def failing_unlink(self, *args, **kwargs):
|
||||
raise PermissionError("no permission")
|
||||
|
||||
with patch.object(Path, "unlink", failing_unlink):
|
||||
result = await vision_analyze_tool(
|
||||
"https://example.com/tempimg.jpg", "describe", "test/model"
|
||||
)
|
||||
|
||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
|
||||
and "temporary file" in r.getMessage().lower()]
|
||||
assert len(warning_records) >= 1
|
||||
assert warning_records[0].exc_info is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_vision_requirements & get_debug_session_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRequirements:
|
||||
def test_check_requirements_returns_bool(self):
|
||||
result = check_vision_requirements()
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_debug_session_info_returns_dict(self):
|
||||
info = get_debug_session_info()
|
||||
assert isinstance(info, dict)
|
||||
# DebugSession.get_session_info() returns these keys
|
||||
assert "enabled" in info
|
||||
assert "session_id" in info
|
||||
assert "total_calls" in info
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: registry entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVisionRegistration:
|
||||
def test_vision_analyze_registered(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert entry is not None
|
||||
assert entry.toolset == "vision"
|
||||
assert entry.is_async is True
|
||||
|
||||
def test_schema_has_required_fields(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
schema = entry.schema
|
||||
assert schema["name"] == "vision_analyze"
|
||||
params = schema.get("parameters", {})
|
||||
props = params.get("properties", {})
|
||||
assert "image_url" in props
|
||||
assert "question" in props
|
||||
|
||||
def test_handler_is_callable(self):
|
||||
from tools.registry import registry
|
||||
entry = registry._tools.get("vision_analyze")
|
||||
assert callable(entry.handler)
|
||||
@@ -22,10 +22,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Security flags applied to every container.
|
||||
# The container itself is the security boundary (isolated from host).
|
||||
# We drop all capabilities, block privilege escalation, and limit PIDs.
|
||||
# We drop all capabilities then add back the minimum needed:
|
||||
# DAC_OVERRIDE - root can write to bind-mounted dirs owned by host user
|
||||
# CHOWN/FOWNER - package managers (pip, npm, apt) need to set file ownership
|
||||
# Block privilege escalation and limit PIDs.
|
||||
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
|
||||
_SECURITY_ARGS = [
|
||||
"--cap-drop", "ALL",
|
||||
"--cap-add", "DAC_OVERRIDE",
|
||||
"--cap-add", "CHOWN",
|
||||
"--cap-add", "FOWNER",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--pids-limit", "256",
|
||||
"--tmpfs", "/tmp:rw,nosuid,size=512m",
|
||||
|
||||
@@ -400,10 +400,16 @@ class ShellFileOperations(FileOperations):
|
||||
return home
|
||||
elif path.startswith('~/'):
|
||||
return home + path[1:] # Replace ~ with home
|
||||
# ~username format - let shell expand it
|
||||
expand_result = self._exec(f"echo {path}")
|
||||
if expand_result.exit_code == 0:
|
||||
return expand_result.stdout.strip()
|
||||
# ~username format - extract and validate username before
|
||||
# letting shell expand it (prevent shell injection via
|
||||
# paths like "~; rm -rf /").
|
||||
rest = path[1:] # strip leading ~
|
||||
slash_idx = rest.find('/')
|
||||
username = rest[:slash_idx] if slash_idx >= 0 else rest
|
||||
if username and re.fullmatch(r'[a-zA-Z0-9._-]+', username):
|
||||
expand_result = self._exec(f"echo {path}")
|
||||
if expand_result.exit_code == 0 and expand_result.stdout.strip():
|
||||
return expand_result.stdout.strip()
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@@ -27,14 +27,15 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
from agent.auxiliary_client import get_vision_auxiliary_client
|
||||
@@ -73,15 +74,18 @@ def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Check if it's a valid URL format
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
|
||||
# Basic HTTP/HTTPS URL check
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
return False
|
||||
|
||||
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||
|
||||
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||
|
||||
# Parse to ensure we at least have a network location; still allow URLs
|
||||
# without file extensions (e.g. CDN endpoints that redirect to images).
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
return True # Allow all well-formed HTTP/HTTPS URLs for flexibility
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
@@ -131,7 +135,12 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
||||
logger.warning("Retrying in %ss...", wait_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error("Image download failed after %s attempts: %s", max_retries, str(e)[:100])
|
||||
logger.error(
|
||||
"Image download failed after %s attempts: %s",
|
||||
max_retries,
|
||||
str(e)[:100],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
raise last_error
|
||||
|
||||
@@ -188,7 +197,7 @@ def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None)
|
||||
async def vision_analyze_tool(
|
||||
image_url: str,
|
||||
user_prompt: str,
|
||||
model: str = DEFAULT_VISION_MODEL
|
||||
model: str = DEFAULT_VISION_MODEL,
|
||||
) -> str:
|
||||
"""
|
||||
Analyze an image from a URL or local file path using vision AI.
|
||||
@@ -347,7 +356,7 @@ async def vision_analyze_tool(
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
# Prepare error response
|
||||
result = {
|
||||
@@ -368,7 +377,9 @@ async def vision_analyze_tool(
|
||||
temp_image_path.unlink()
|
||||
logger.debug("Cleaned up temporary image file")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning("Could not delete temporary file: %s", cleanup_error)
|
||||
logger.warning(
|
||||
"Could not delete temporary file: %s", cleanup_error, exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
|
||||
@@ -464,10 +475,13 @@ VISION_ANALYZE_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
def _handle_vision_analyze(args, **kw):
|
||||
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
|
||||
image_url = args.get("image_url", "")
|
||||
question = args.get("question", "")
|
||||
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
|
||||
full_prompt = (
|
||||
"Fully describe and explain everything about this image, then answer the "
|
||||
f"following question:\n\n{question}"
|
||||
)
|
||||
model = (os.getenv("AUXILIARY_VISION_MODEL", "").strip()
|
||||
or DEFAULT_VISION_MODEL
|
||||
or "google/gemini-3-flash-preview")
|
||||
|
||||
@@ -393,8 +393,40 @@ terminal:
|
||||
backend: local # or: docker, ssh, singularity, modal, daytona
|
||||
cwd: "." # Working directory ("." = current dir)
|
||||
timeout: 180 # Command timeout in seconds
|
||||
|
||||
# Docker-specific settings
|
||||
docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
docker_volumes: # Share host directories with the container
|
||||
- "/home/user/projects:/workspace/projects"
|
||||
- "/home/user/data:/data:ro" # :ro for read-only
|
||||
|
||||
# Container resource limits (docker, singularity, modal, daytona)
|
||||
container_cpu: 1 # CPU cores
|
||||
container_memory: 5120 # MB (default 5GB)
|
||||
container_disk: 51200 # MB (default 50GB)
|
||||
container_persistent: true # Persist filesystem across sessions
|
||||
```
|
||||
|
||||
### Docker Volume Mounts
|
||||
|
||||
When using the Docker backend, `docker_volumes` lets you share host directories with the container. Each entry uses standard Docker `-v` syntax: `host_path:container_path[:options]`.
|
||||
|
||||
```yaml
|
||||
terminal:
|
||||
backend: docker
|
||||
docker_volumes:
|
||||
- "/home/user/projects:/workspace/projects" # Read-write (default)
|
||||
- "/home/user/datasets:/data:ro" # Read-only
|
||||
- "/home/user/outputs:/outputs" # Agent writes, you read
|
||||
```
|
||||
|
||||
This is useful for:
|
||||
- **Providing files** to the agent (datasets, configs, reference code)
|
||||
- **Receiving files** from the agent (generated code, reports, exports)
|
||||
- **Shared workspaces** where both you and the agent access the same files
|
||||
|
||||
Can also be set via environment variable: `TERMINAL_DOCKER_VOLUMES='["/host:/container"]'` (JSON array).
|
||||
|
||||
See [Code Execution](features/code-execution.md) and the [Terminal section of the README](features/tools.md) for details on each backend.
|
||||
|
||||
## Memory Configuration
|
||||
|
||||
@@ -46,20 +46,26 @@ Navigate to **Features → OAuth & Permissions** in the sidebar. Scroll to **Sco
|
||||
| Scope | Purpose |
|
||||
|-------|---------|
|
||||
| `chat:write` | Send messages as the bot |
|
||||
| `app_mentions:read` | Respond when @mentioned in channels |
|
||||
| `app_mentions:read` | Detect when @mentioned in channels |
|
||||
| `channels:history` | Read messages in public channels the bot is in |
|
||||
| `channels:read` | List and get info about public channels |
|
||||
| `groups:history` | Read messages in private channels the bot is invited to |
|
||||
| `im:history` | Read direct message history |
|
||||
| `im:read` | View basic DM info |
|
||||
| `im:write` | Open and manage DMs |
|
||||
| `users:read` | Look up user information |
|
||||
| `files:write` | Upload files (images, audio, documents) |
|
||||
|
||||
:::caution Missing scopes = missing features
|
||||
Without `channels:history` and `groups:history`, the bot **will not receive messages in channels** —
|
||||
it will only work in DMs. These are the most commonly missed scopes.
|
||||
:::
|
||||
|
||||
**Optional scopes:**
|
||||
|
||||
| Scope | Purpose |
|
||||
|-------|---------|
|
||||
| `groups:history` | Read messages in private channels the bot is invited to |
|
||||
| `files:write` | Upload files (audio, images) |
|
||||
| `groups:read` | List and get info about private channels |
|
||||
|
||||
---
|
||||
|
||||
@@ -83,23 +89,27 @@ You can always find or regenerate app-level tokens under **Settings → Basic In
|
||||
|
||||
## Step 4: Subscribe to Events
|
||||
|
||||
This step is critical — it controls what messages the bot can see.
|
||||
|
||||
1. In the sidebar, go to **Features → Event Subscriptions**
|
||||
2. Toggle **Enable Events** to ON
|
||||
3. Expand **Subscribe to bot events** and add:
|
||||
|
||||
| Event | Purpose |
|
||||
|-------|---------|
|
||||
| `app_mention` | Bot responds when @mentioned in any channel |
|
||||
| `message.im` | Bot responds to direct messages |
|
||||
|
||||
**Optional event:**
|
||||
|
||||
| Event | Purpose |
|
||||
|-------|---------|
|
||||
| `message.channels` | Bot sees all messages in public channels it's added to |
|
||||
| Event | Required? | Purpose |
|
||||
|-------|-----------|---------|
|
||||
| `message.im` | **Yes** | Bot receives direct messages |
|
||||
| `message.channels` | **Yes** | Bot receives messages in **public** channels it's added to |
|
||||
| `message.groups` | **Recommended** | Bot receives messages in **private** channels it's invited to |
|
||||
| `app_mention` | **Yes** | Prevents Bolt SDK errors when bot is @mentioned |
|
||||
|
||||
4. Click **Save Changes** at the bottom of the page
|
||||
|
||||
:::danger Missing event subscriptions is the #1 setup issue
|
||||
If the bot works in DMs but **not in channels**, you almost certainly forgot to add
|
||||
`message.channels` (for public channels) and/or `message.groups` (for private channels).
|
||||
Without these events, Slack simply never delivers channel messages to the bot.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## Step 5: Install App to Workspace
|
||||
@@ -111,8 +121,8 @@ You can always find or regenerate app-level tokens under **Settings → Basic In
|
||||
5. **Copy this token** — this is your `SLACK_BOT_TOKEN`
|
||||
|
||||
:::tip
|
||||
If you change scopes later, you'll need to **reinstall the app** for the new scopes to take effect.
|
||||
The Install App page will show a banner prompting you to do so.
|
||||
If you change scopes or event subscriptions later, you **must reinstall the app** for the changes
|
||||
to take effect. The Install App page will show a banner prompting you to do so.
|
||||
:::
|
||||
|
||||
---
|
||||
@@ -139,7 +149,7 @@ Add the following to your `~/.hermes/.env` file:
|
||||
```bash
|
||||
# Required
|
||||
SLACK_BOT_TOKEN=xoxb-your-bot-token-here
|
||||
SLACK_APP_TOKEN=xapp-your-app-level-token-here
|
||||
SLACK_APP_TOKEN=xapp-your-app-token-here
|
||||
SLACK_ALLOWED_USERS=U01ABC2DEF3 # Comma-separated Member IDs
|
||||
|
||||
# Optional
|
||||
@@ -161,6 +171,35 @@ hermes gateway install # Install as a system service
|
||||
|
||||
---
|
||||
|
||||
## Step 8: Invite the Bot to Channels
|
||||
|
||||
After starting the gateway, you need to **invite the bot** to any channel where you want it to respond:
|
||||
|
||||
```
|
||||
/invite @Hermes Agent
|
||||
```
|
||||
|
||||
The bot will **not** automatically join channels. You must invite it to each channel individually.
|
||||
|
||||
---
|
||||
|
||||
## How the Bot Responds
|
||||
|
||||
Understanding how Hermes behaves in different contexts:
|
||||
|
||||
| Context | Behavior |
|
||||
|---------|----------|
|
||||
| **DMs** | Bot responds to every message — no @mention needed |
|
||||
| **Channels** | Bot **only responds when @mentioned** (e.g., `@Hermes Agent what time is it?`) |
|
||||
| **Threads** | Bot replies in threads when the triggering message is in a thread |
|
||||
|
||||
:::tip
|
||||
In channels, always @mention the bot. Simply typing a message without mentioning it will be ignored.
|
||||
This is intentional — it prevents the bot from responding to every message in busy channels.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## Home Channel
|
||||
|
||||
Set `SLACK_HOME_CHANNEL` to a channel ID where Hermes will deliver scheduled messages,
|
||||
@@ -192,11 +231,27 @@ Hermes supports voice on Slack:
|
||||
| Problem | Solution |
|
||||
|---------|----------|
|
||||
| Bot doesn't respond to DMs | Verify `message.im` is in your event subscriptions and the app is reinstalled |
|
||||
| Bot doesn't respond to @mentions | Verify `app_mention` is in your event subscriptions |
|
||||
| Bot works in DMs but not in channels | **Most common issue.** Add `message.channels` and `message.groups` to event subscriptions, reinstall the app, and invite the bot to the channel with `/invite @Hermes Agent` |
|
||||
| Bot doesn't respond to @mentions in channels | 1) Check `message.channels` event is subscribed. 2) Bot must be invited to the channel. 3) Ensure `channels:history` scope is added. 4) Reinstall the app after scope/event changes |
|
||||
| Bot ignores messages in private channels | Add both the `message.groups` event subscription and `groups:history` scope, then reinstall the app and `/invite` the bot |
|
||||
| "not_authed" or "invalid_auth" errors | Regenerate your Bot Token and App Token, update `.env` |
|
||||
| Bot responds but can't post in a channel | Invite the bot to the channel with `/invite @Hermes Agent` |
|
||||
| "missing_scope" error | Add the required scope in OAuth & Permissions, then **reinstall** the app |
|
||||
| Socket disconnects frequently | Check your network; Bolt auto-reconnects but unstable connections cause lag |
|
||||
| Changed scopes/events but nothing changed | You **must reinstall** the app to your workspace after any scope or event subscription change |
|
||||
|
||||
### Quick Checklist
|
||||
|
||||
If the bot isn't working in channels, verify **all** of the following:
|
||||
|
||||
1. ✅ `message.channels` event is subscribed (for public channels)
|
||||
2. ✅ `message.groups` event is subscribed (for private channels)
|
||||
3. ✅ `app_mention` event is subscribed
|
||||
4. ✅ `channels:history` scope is added (for public channels)
|
||||
5. ✅ `groups:history` scope is added (for private channels)
|
||||
6. ✅ App was **reinstalled** after adding scopes/events
|
||||
7. ✅ Bot was **invited** to the channel (`/invite @Hermes Agent`)
|
||||
8. ✅ You are **@mentioning** the bot in your message
|
||||
|
||||
---
|
||||
|
||||
|
||||
Reference in New Issue
Block a user