Compare commits

...

1 Commits

Author SHA1 Message Date
Brooklyn Nicholson
f4d7e6a29e feat: devex help, add Makefile, ruff, pre-commit, and modernize CI 2026-03-09 20:36:51 -05:00
111 changed files with 11655 additions and 10200 deletions

18
.editorconfig Normal file
View File

@@ -0,0 +1,18 @@
root = true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.{yml,yaml,json,toml}]
indent_size = 2
[*.md]
trim_trailing_whitespace = false
[Makefile]
indent_style = tab

View File

@@ -46,7 +46,7 @@ Fixes #
- [ ] My commit messages follow [Conventional Commits](https://www.conventionalcommits.org/) (`fix(scope):`, `feat(scope):`, etc.)
- [ ] I searched for [existing PRs](https://github.com/NousResearch/hermes-agent/pulls) to make sure this isn't a duplicate
- [ ] My PR contains **only** changes related to this fix/feature (no unrelated commits)
- [ ] I've run `pytest tests/ -q` and all tests pass
- [ ] I've run `make check` (lint + test) and all checks pass
- [ ] I've added tests for my changes (required for bug fixes, strongly encouraged for features)
- [ ] I've tested on my platform: <!-- e.g. Ubuntu 24.04, macOS 15.2, Windows 11 -->

View File

@@ -1,4 +1,4 @@
name: Tests
name: CI
on:
push:
@@ -6,37 +6,42 @@ on:
pull_request:
branches: [main]
# Cancel in-progress runs for the same PR/branch
concurrency:
group: tests-${{ github.ref }}
group: ci-${{ github.ref }}
cancel-in-progress: true
env:
SRC: >-
run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py
tools/ hermes_cli/ gateway/ agent/ cron/
jobs:
lint:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- run: uvx ruff check $SRC
- run: uvx ruff format --check $SRC
test:
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Set up Python 3.11
run: uv python install 3.11
- name: Install dependencies
run: |
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
- run: uv python install 3.11
- run: |
uv venv .venv --python 3.11
source .venv/bin/activate
uv pip install -e ".[all,dev]"
- name: Run tests
run: |
- run: |
source .venv/bin/activate
python -m pytest tests/ -q --ignore=tests/integration --tb=short
env:
# Ensure tests don't accidentally call real APIs
OPENROUTER_API_KEY: ""
OPENAI_API_KEY: ""
NOUS_API_KEY: ""

80
.gitignore vendored
View File

@@ -1,51 +1,53 @@
/venv/
/_pycache/
*.pyc*
# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
# Environments
.venv/
venv/
# Tools
.ruff_cache/
.mypy_cache/
.pytest_cache/
# Editors
.vscode/
.idea/
# Secrets & config
.env
.env.local
.env.development.local
.env.test.local
.env.production.local
.env.development
.env.test
export*
__pycache__/model_tools.cpython-310.pyc
__pycache__/web_tools.cpython-310.pyc
.env.*.local
*.pem
*.ppk
# Node
node_modules/
# Project-specific
logs/
data/
.pytest_cache/
tmp/
temp_vision_images/
hermes-*/*
examples/
tests/quick_test_dataset.jsonl
tests/sample_dataset.jsonl
run_datagen_kimik2-thinking.sh
run_datagen_megascience_glm4-6.sh
run_datagen_sonnet.sh
source-data/*
run_datagen_megascience_glm4-6.sh
data/*
node_modules/
wandb/
images/
browser-use/
agent-browser/
# Private keys
*.ppk
*.pem
privvy*
images/
__pycache__/
hermes_agent.egg-info/
wandb/
testlogs
# CLI config (may contain sensitive SSH paths)
source-data/
testlogs/
ignored/
.worktrees/
temp_vision_images/
cli-config.yaml
# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
skills/.hub/
ignored/
.worktrees/
hermes-*/*
examples/
export*
privvy*
run_datagen_*.sh
tests/quick_test_dataset.jsonl
tests/sample_dataset.jsonl

18
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,18 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.5
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-merge-conflict
- id: check-yaml
args: [--allow-multiple-documents]
- id: check-added-large-files
args: [--maxkb=500]

View File

@@ -5,7 +5,8 @@ Instructions for AI coding assistants and developers working on the hermes-agent
## Development Environment
```bash
source .venv/bin/activate # ALWAYS activate before running Python
make setup # First time: creates .venv, installs deps, sets up pre-commit
source .venv/bin/activate
```
## Project Structure
@@ -228,15 +229,27 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER
---
## Testing
## Development Commands
```bash
make setup # First time: .venv + deps + pre-commit hooks
make check # Lint + test (mirrors CI — run before pushing)
make lint # Ruff check
make fmt # Ruff format + auto-fix
make test # Full test suite (~2500 tests, ~2 min)
make test-fast # Tests with fail-fast (-x)
make test-watch # Rerun tests on file changes
make dev-cli # Auto-restart CLI on file changes
make dev-gateway # Auto-restart gateway on file changes
```
For targeted testing, use `pytest` directly:
```bash
source .venv/bin/activate
python -m pytest tests/ -q # Full suite (~2500 tests, ~2 min)
python -m pytest tests/test_model_tools.py -q # Toolset resolution
python -m pytest tests/test_cli_init.py -q # CLI config loading
python -m pytest tests/gateway/ -q # Gateway tests
python -m pytest tests/tools/ -q # Tool-level tests
```
Always run the full suite before pushing changes.
Formatting is enforced by **ruff** (config in `pyproject.toml`). Pre-commit hooks run on every commit.

View File

@@ -65,18 +65,7 @@ If your skill is specialized, community-contributed, or niche, it's better suite
```bash
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
cd hermes-agent
# Create venv with Python 3.11
uv venv venv --python 3.11
export VIRTUAL_ENV="$(pwd)/venv"
# Install with all extras (messaging, cron, CLI menus, dev tools)
uv pip install -e ".[all,dev]"
uv pip install -e "./mini-swe-agent"
uv pip install -e "./tinker-atropos"
# Optional: browser tools
npm install
make setup # creates .venv, installs all deps
```
### Configure for development
@@ -90,22 +79,16 @@ touch ~/.hermes/.env
echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env
```
### Run
### Common commands
```bash
# Symlink for global access
mkdir -p ~/.local/bin
ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes
# Verify
hermes doctor
hermes chat -q "Hello"
```
### Run tests
```bash
pytest tests/ -v
make test # run unit tests
make lint # ruff check
make fmt # ruff format + fix
make check # lint + test (same as CI)
make dev-cli # auto-restart hermes CLI on file changes
make dev-gateway # auto-restart gateway on file changes
make test-watch # rerun tests on file changes
```
---
@@ -227,7 +210,7 @@ User message → AIAgent._run_agent_loop()
## Code Style
- **PEP 8** with practical exceptions (we don't enforce strict line length)
- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically.
- **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks. Don't narrate what the code does — `# increment counter` adds nothing
- **Error handling**: Catch specific exceptions. Log with `logger.warning()`/`logger.error()` — use `exc_info=True` for unexpected errors so stack traces appear in logs
- **Cross-platform**: Never assume Unix. See [Cross-Platform Compatibility](#cross-platform-compatibility)
@@ -457,7 +440,7 @@ refactor/description # Code restructuring
### Before submitting
1. **Run tests**: `pytest tests/ -v`
1. **Run checks**: `make check` (lint + test — same as CI)
2. **Test manually**: Run `hermes` and exercise the code path you changed
3. **Check cross-platform impact**: If you touch file I/O, process management, or terminal handling, consider Windows and macOS
4. **Keep PRs focused**: One logical change per PR. Don't mix a bug fix with a refactor with a new feature.

69
Makefile Normal file
View File

@@ -0,0 +1,69 @@
.DEFAULT_GOAL := help
SHELL := /bin/bash
VENV := .venv
UV := uv
SRC := run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py \
tools/ hermes_cli/ gateway/ agent/ cron/
# ─── Setup ──────────────────────────────────────────────────────────────────────
.PHONY: setup sync clean
setup: ## Full dev setup (venv + deps + pre-commit)
$(UV) venv $(VENV) --python 3.11
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
. $(VENV)/bin/activate && $(UV) pip install -e "./mini-swe-agent"
. $(VENV)/bin/activate && pre-commit install
@echo "\n✅ Setup complete. Run: source $(VENV)/bin/activate"
sync: ## Reinstall deps into existing venv
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
clean: ## Remove build artifacts and caches
rm -rf .ruff_cache .mypy_cache .pytest_cache dist build *.egg-info
find . -type d -name __pycache__ -not -path "./.venv/*" -exec rm -rf {} +
# ─── Quality ────────────────────────────────────────────────────────────────────
.PHONY: lint fmt check
lint: ## Check lint + formatting (no changes)
. $(VENV)/bin/activate && ruff check $(SRC)
. $(VENV)/bin/activate && ruff format --check $(SRC)
fmt: ## Auto-fix lint + format
. $(VENV)/bin/activate && ruff format $(SRC)
. $(VENV)/bin/activate && ruff check --fix $(SRC)
check: lint test ## Lint + test (mirrors CI)
# ─── Test ───────────────────────────────────────────────────────────────────────
.PHONY: test test-fast test-watch
test: ## Run full test suite
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short
test-fast: ## Run tests with fail-fast
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short -x
test-watch: ## Rerun tests on file changes
. $(VENV)/bin/activate && python -m watchfiles "python -m pytest tests/ -q --ignore=tests/integration --tb=short -x" $(SRC) tests/
# ─── Dev Servers ────────────────────────────────────────────────────────────────
.PHONY: dev-cli dev-gateway
dev-cli: ## Auto-restart CLI on file changes
. $(VENV)/bin/activate && python -m watchfiles "python -m hermes_cli.main" $(SRC)
dev-gateway: ## Auto-restart gateway on file changes
. $(VENV)/bin/activate && python -m watchfiles "python -m gateway.run" $(SRC)
# ─── Misc ───────────────────────────────────────────────────────────────────────
.PHONY: help
help: ## Show this help
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'

View File

@@ -95,12 +95,8 @@ Quick start for contributors:
```bash
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
cd hermes-agent
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv .venv --python 3.11
source .venv/bin/activate
uv pip install -e ".[all,dev]"
uv pip install -e "./mini-swe-agent"
python -m pytest tests/ -q
make setup # creates .venv, installs everything
make check # lint + test (same as CI)
```
---

View File

@@ -34,7 +34,7 @@ import logging
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
from openai import OpenAI
@@ -43,7 +43,7 @@ from hermes_constants import OPENROUTER_BASE_URL
logger = logging.getLogger(__name__)
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
_API_KEY_PROVIDER_AUX_MODELS: dict[str, str] = {
"zai": "glm-4.5-flash",
"kimi-coding": "kimi-k2-turbo-preview",
"minimax": "MiniMax-M2.5-highspeed",
@@ -102,7 +102,7 @@ def _convert_content_for_responses(content: Any) -> Any:
if not isinstance(content, list):
return str(content) if content else ""
converted: List[Dict[str, Any]] = []
converted: list[dict[str, Any]] = []
for part in content:
if not isinstance(part, dict):
continue
@@ -113,7 +113,7 @@ def _convert_content_for_responses(content: Any) -> Any:
# chat.completions nests the URL: {"image_url": {"url": "..."}}
image_data = part.get("image_url", {})
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
entry: Dict[str, Any] = {"type": "input_image", "image_url": url}
entry: dict[str, Any] = {"type": "input_image", "image_url": url}
# Preserve detail if specified
detail = image_data.get("detail") if isinstance(image_data, dict) else None
if detail:
@@ -148,19 +148,21 @@ class _CodexCompletionsAdapter:
# Convert chat.completions multimodal content blocks to Responses
# API format (input_text / input_image instead of text / image_url).
instructions = "You are a helpful assistant."
input_msgs: List[Dict[str, Any]] = []
input_msgs: list[dict[str, Any]] = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content") or ""
if role == "system":
instructions = content if isinstance(content, str) else str(content)
else:
input_msgs.append({
"role": role,
"content": _convert_content_for_responses(content),
})
input_msgs.append(
{
"role": role,
"content": _convert_content_for_responses(content),
}
)
resp_kwargs: Dict[str, Any] = {
resp_kwargs: dict[str, Any] = {
"model": model,
"instructions": instructions,
"input": input_msgs or [{"role": "user", "content": ""}],
@@ -179,18 +181,20 @@ class _CodexCompletionsAdapter:
name = fn.get("name")
if not name:
continue
converted.append({
"type": "function",
"name": name,
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
})
converted.append(
{
"type": "function",
"name": name,
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
}
)
if converted:
resp_kwargs["tools"] = converted
# Stream and collect the response
text_parts: List[str] = []
tool_calls_raw: List[Any] = []
text_parts: list[str] = []
tool_calls_raw: list[Any] = []
usage = None
try:
@@ -208,14 +212,16 @@ class _CodexCompletionsAdapter:
if ptype in ("output_text", "text"):
text_parts.append(getattr(part, "text", ""))
elif item_type == "function_call":
tool_calls_raw.append(SimpleNamespace(
id=getattr(item, "call_id", ""),
type="function",
function=SimpleNamespace(
name=getattr(item, "name", ""),
arguments=getattr(item, "arguments", "{}"),
),
))
tool_calls_raw.append(
SimpleNamespace(
id=getattr(item, "call_id", ""),
type="function",
function=SimpleNamespace(
name=getattr(item, "name", ""),
arguments=getattr(item, "arguments", "{}"),
),
)
)
resp_usage = getattr(final, "usage", None)
if resp_usage:
@@ -285,6 +291,7 @@ class _AsyncCodexCompletionsAdapter:
async def create(self, **kwargs) -> Any:
import asyncio
return await asyncio.to_thread(self._sync.create, **kwargs)
@@ -304,7 +311,7 @@ class AsyncCodexAuxiliaryClient:
self.base_url = sync_wrapper.base_url
def _read_nous_auth() -> Optional[dict]:
def _read_nous_auth() -> dict | None:
"""Read and validate ~/.hermes/auth.json for an active Nous provider.
Returns the provider state dict if Nous is active with tokens,
@@ -336,10 +343,11 @@ def _nous_base_url() -> str:
return os.getenv("NOUS_INFERENCE_BASE_URL", _NOUS_DEFAULT_BASE_URL)
def _read_codex_access_token() -> Optional[str]:
def _read_codex_access_token() -> str | None:
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
try:
from hermes_cli.auth import _read_codex_tokens
data = _read_codex_tokens()
tokens = data.get("tokens", {})
access_token = tokens.get("access_token")
@@ -351,7 +359,7 @@ def _read_codex_access_token() -> Optional[str]:
return None
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]:
"""Try each API-key provider in PROVIDER_REGISTRY order.
Returns (client, model) for the first provider whose env var is set,
@@ -398,6 +406,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
# ── Provider resolution helpers ─────────────────────────────────────────────
def _get_auxiliary_provider(task: str = "") -> str:
"""Read the provider override for a specific auxiliary task.
@@ -413,16 +422,15 @@ def _get_auxiliary_provider(task: str = "") -> str:
return "auto"
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
def _try_openrouter() -> tuple[OpenAI | None, str | None]:
or_key = os.getenv("OPENROUTER_API_KEY")
if not or_key:
return None, None
logger.debug("Auxiliary client: OpenRouter")
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, default_headers=_OR_HEADERS), _OPENROUTER_MODEL
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
def _try_nous() -> tuple[OpenAI | None, str | None]:
nous = _read_nous_auth()
if not nous:
return None, None
@@ -435,7 +443,7 @@ def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
)
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]:
custom_base = os.getenv("OPENAI_BASE_URL")
custom_key = os.getenv("OPENAI_API_KEY")
if not custom_base or not custom_key:
@@ -445,7 +453,7 @@ def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
return OpenAI(api_key=custom_key, base_url=custom_base), model
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
def _try_codex() -> tuple[Any | None, str | None]:
codex_token = _read_codex_access_token()
if not codex_token:
return None, None
@@ -454,7 +462,7 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]:
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
if forced == "openrouter":
client, model = _try_openrouter()
@@ -488,10 +496,9 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
return None, None
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
def _resolve_auto() -> tuple[OpenAI | None, str | None]:
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
_try_codex, _resolve_api_key_provider):
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider):
client, model = try_fn()
if client is not None:
return client, model
@@ -501,7 +508,8 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
# ── Public API ──────────────────────────────────────────────────────────────
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
def get_text_auxiliary_client(task: str = "") -> tuple[OpenAI | None, str | None]:
"""Return (client, default_model_slug) for text-only auxiliary tasks.
Args:
@@ -544,7 +552,7 @@ def get_async_text_auxiliary_client(task: str = ""):
return AsyncOpenAI(**async_kwargs), model
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]:
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
@@ -564,8 +572,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
# back to the user's custom endpoint. Many local models (Qwen-VL,
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
# caused silent failures for local-only users.
for try_fn in (_try_openrouter, _try_nous, _try_codex,
_try_custom_endpoint):
for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint):
client, model = try_fn()
if client is not None:
return client, model
@@ -575,7 +582,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
def get_auxiliary_extra_body() -> dict:
"""Return extra_body kwargs for auxiliary API calls.
Includes Nous Portal product tags when the auxiliary client is backed
by Nous Portal. Returns empty dict otherwise.
"""
@@ -584,7 +591,7 @@ def get_auxiliary_extra_body() -> dict:
def auxiliary_max_tokens_param(value: int) -> dict:
"""Return the correct max tokens kwarg for the auxiliary client's provider.
OpenRouter and local models use 'max_tokens'. Direct OpenAI with newer
models (gpt-4o, o-series, gpt-5+) requires 'max_completion_tokens'.
The Codex adapter translates max_tokens internally, so we use max_tokens
@@ -593,8 +600,6 @@ def auxiliary_max_tokens_param(value: int) -> dict:
custom_base = os.getenv("OPENAI_BASE_URL", "")
or_key = os.getenv("OPENROUTER_API_KEY")
# Only use max_completion_tokens for direct OpenAI custom endpoints
if (not or_key
and _read_nous_auth() is None
and "api.openai.com" in custom_base.lower()):
if not or_key and _read_nous_auth() is None and "api.openai.com" in custom_base.lower():
return {"max_completion_tokens": value}
return {"max_tokens": value}

View File

@@ -7,12 +7,12 @@ protecting head and tail context.
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any
from agent.auxiliary_client import get_text_auxiliary_client
from agent.model_metadata import (
get_model_context_length,
estimate_messages_tokens_rough,
get_model_context_length,
)
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class ContextCompressor:
self.client, default_model = get_text_auxiliary_client("compression")
self.summary_model = summary_model_override or default_model
def update_from_response(self, usage: Dict[str, Any]):
def update_from_response(self, usage: dict[str, Any]):
"""Update tracked token usage from API response."""
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
self.last_completion_tokens = usage.get("completion_tokens", 0)
@@ -67,12 +67,12 @@ class ContextCompressor:
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
return tokens >= self.threshold_tokens
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
def should_compress_preflight(self, messages: list[dict[str, Any]]) -> bool:
"""Quick pre-flight check using rough estimate (before API call)."""
rough_estimate = estimate_messages_tokens_rough(messages)
return rough_estimate >= self.threshold_tokens
def get_status(self) -> Dict[str, Any]:
def get_status(self) -> dict[str, Any]:
"""Get current compression status for display/logging."""
return {
"last_prompt_tokens": self.last_prompt_tokens,
@@ -82,7 +82,7 @@ class ContextCompressor:
"compression_count": self.compression_count,
}
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
def _generate_summary(self, turns_to_summarize: list[dict[str, Any]]) -> str | None:
"""Generate a concise summary of conversation turns.
Tries the auxiliary model first, then falls back to the user's main
@@ -140,7 +140,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
logging.warning(f"Main model summary also failed: {fallback_err}")
# 3. All models failed — return None so the caller drops turns without a summary
logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
logging.warning(
"Context compression: no model available for summary. Middle turns will be dropped without summary."
)
return None
def _call_summary_model(self, client, model: str, prompt: str) -> str:
@@ -186,12 +188,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
# Don't fallback to the same provider that just failed
from hermes_constants import OPENROUTER_BASE_URL
if custom_base.rstrip("/") == OPENROUTER_BASE_URL.rstrip("/"):
return None, None
model = os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or self.model
try:
from openai import OpenAI as _OpenAI
client = _OpenAI(api_key=custom_key, base_url=custom_base)
logger.debug("Built fallback auxiliary client: %s via %s", model, custom_base)
return client, model
@@ -210,7 +214,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
return tc.get("id", "")
return getattr(tc, "id", "") or ""
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _sanitize_tool_pairs(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs after compression.
Two failure modes:
@@ -243,8 +247,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
m for m in messages if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
if not self.quiet_mode:
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
@@ -252,25 +255,27 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
# 2. Add stub results for assistant tool_calls whose results were dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
patched: list[dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result from earlier conversation — see context summary above]",
"tool_call_id": cid,
})
patched.append(
{
"role": "tool",
"content": "[Result from earlier conversation — see context summary above]",
"tool_call_id": cid,
}
)
messages = patched
if not self.quiet_mode:
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
return messages
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
def _align_boundary_forward(self, messages: list[dict[str, Any]], idx: int) -> int:
"""Push a compress-start boundary forward past any orphan tool results.
If ``messages[idx]`` is a tool result, slide forward until we hit a
@@ -280,7 +285,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
idx += 1
return idx
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
def _align_boundary_backward(self, messages: list[dict[str, Any]], idx: int) -> int:
"""Pull a compress-end boundary backward to avoid splitting a
tool_call / result group.
@@ -298,7 +303,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
idx -= 1
return idx
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
def compress(self, messages: list[dict[str, Any]], current_tokens: int = None) -> list[dict[str, Any]]:
"""Compress conversation messages by summarizing middle turns.
Keeps first N + last N turns, summarizes everything in between.
@@ -308,7 +313,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
n_messages = len(messages)
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
if not self.quiet_mode:
print(f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})")
print(
f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})"
)
return messages
compress_start = self.protect_first_n
@@ -323,14 +330,20 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
return messages
turns_to_summarize = messages[compress_start:compress_end]
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
display_tokens = (
current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
)
if not self.quiet_mode:
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
print(
f"\n📦 Context compression triggered ({display_tokens:,} tokens {self.threshold_tokens:,} threshold)"
)
print(
f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent * 100:.0f}% = {self.threshold_tokens:,})"
)
if not self.quiet_mode:
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
print(f" 🗜️ Summarizing turns {compress_start + 1}-{compress_end} ({len(turns_to_summarize)} turns)")
summary = self._generate_summary(turns_to_summarize)
@@ -338,7 +351,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
for i in range(compress_start):
msg = messages[i].copy()
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
msg["content"] = (
msg.get("content") or ""
) + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
compressed.append(msg)
if summary:

View File

@@ -6,7 +6,6 @@ Used by AIAgent._execute_tool_calls for CLI feedback.
import json
import os
import random
import sys
import threading
import time
@@ -20,19 +19,31 @@ _RESET = "\033[0m"
# Tool preview (one-line summary of a tool call's primary argument)
# =========================================================================
def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
"""Build a short preview of a tool call's primary argument for display."""
primary_args = {
"terminal": "command", "web_search": "query", "web_extract": "urls",
"read_file": "path", "write_file": "path", "patch": "path",
"search_files": "pattern", "browser_navigate": "url",
"browser_click": "ref", "browser_type": "text",
"image_generate": "prompt", "text_to_speech": "text",
"vision_analyze": "question", "mixture_of_agents": "user_prompt",
"skill_view": "name", "skills_list": "category",
"terminal": "command",
"web_search": "query",
"web_extract": "urls",
"read_file": "path",
"write_file": "path",
"patch": "path",
"search_files": "pattern",
"browser_navigate": "url",
"browser_click": "ref",
"browser_type": "text",
"image_generate": "prompt",
"text_to_speech": "text",
"vision_analyze": "question",
"mixture_of_agents": "user_prompt",
"skill_view": "name",
"skills_list": "category",
"schedule_cronjob": "name",
"execute_code": "code", "delegate_task": "goal",
"clarify": "question", "skill_manage": "name",
"execute_code": "code",
"delegate_task": "goal",
"clarify": "question",
"skill_manage": "name",
}
if tool_name == "process":
@@ -61,18 +72,18 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
if tool_name == "session_search":
query = args.get("query", "")
return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\""
return f'recall: "{query[:25]}{"..." if len(query) > 25 else ""}"'
if tool_name == "memory":
action = args.get("action", "")
target = args.get("target", "")
if action == "add":
content = args.get("content", "")
return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\""
return f'+{target}: "{content[:25]}{"..." if len(content) > 25 else ""}"'
elif action == "replace":
return f"~{target}: \"{args.get('old_text', '')[:20]}\""
return f'~{target}: "{args.get("old_text", "")[:20]}"'
elif action == "remove":
return f"-{target}: \"{args.get('old_text', '')[:20]}\""
return f'-{target}: "{args.get("old_text", "")[:20]}"'
return action
if tool_name == "send_message":
@@ -80,7 +91,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
msg = args.get("message", "")
if len(msg) > 20:
msg = msg[:17] + "..."
return f"to {target}: \"{msg}\""
return f'to {target}: "{msg}"'
if tool_name.startswith("rl_"):
rl_previews = {
@@ -115,7 +126,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
if not preview:
return None
if len(preview) > max_len:
preview = preview[:max_len - 3] + "..."
preview = preview[: max_len - 3] + "..."
return preview
@@ -123,41 +134,74 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
# KawaiiSpinner
# =========================================================================
class KawaiiSpinner:
"""Animated spinner with kawaii faces for CLI feedback during tool execution."""
SPINNERS = {
'dots': ['', '', '', '', '', '', '', '', '', ''],
'bounce': ['', '', '', '', '', '', '', ''],
'grow': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
'arrows': ['', '', '', '', '', '', '', ''],
'star': ['', '', '', '', '', '', '', ''],
'moon': ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'],
'pulse': ['', '', '', '', '', ''],
'brain': ['🧠', '💭', '💡', '', '💫', '🌟', '💡', '💭'],
'sparkle': ['', '˚', '*', '', '', '', '*', '˚'],
"dots": ["", "", "", "", "", "", "", "", "", ""],
"bounce": ["", "", "", "", "", "", "", ""],
"grow": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
"arrows": ["", "", "", "", "", "", "", ""],
"star": ["", "", "", "", "", "", "", ""],
"moon": ["🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"],
"pulse": ["", "", "", "", "", ""],
"brain": ["🧠", "💭", "💡", "", "💫", "🌟", "💡", "💭"],
"sparkle": ["", "˚", "*", "", "", "", "*", "˚"],
}
KAWAII_WAITING = [
"(。◕‿◕。)", "(◕‿◕✿)", "٩(◕‿◕。)۶", "(✿◠‿◠)", "( ˘▽˘)っ",
"♪(´ε` )", "(◕◕✿)", "ヾ(^∇^)", "(≧◡≦)", "(★ω★)",
"(。◕‿◕。)",
"(◕◕✿)",
"٩(◕‿◕。)۶",
"(✿◠‿◠)",
"( ˘▽˘)っ",
"♪(´ε` )",
"(◕ᴗ◕✿)",
"ヾ(^∇^)",
"(≧◡≦)",
"(★ω★)",
]
KAWAII_THINKING = [
"(。•́︿•̀。)", "(◔_◔)", "(¬‿¬)", "( •_•)>⌐■-■", "(⌐■_■)",
"(´・_・`)", "◉_◉", "(°ロ°)", "( ˘⌣˘)♡", "ヽ(>∀<☆)☆",
"٩(๑❛ᴗ❛๑)۶", "(⊙_⊙)", "_¬)", "( ͡° ͜ʖ ͡°)", "ಠ_ಠ",
"(。•́︿•̀。)",
"(◔_◔)",
"¬)",
"( •_•)>⌐■-■",
"(⌐■_■)",
"(´・_・`)",
"◉_◉",
"(°ロ°)",
"( ˘⌣˘)♡",
"ヽ(>∀<☆)☆",
"٩(๑❛ᴗ❛๑)۶",
"(⊙_⊙)",
"(¬_¬)",
"( ͡° ͜ʖ ͡°)",
"ಠ_ಠ",
]
THINKING_VERBS = [
"pondering", "contemplating", "musing", "cogitating", "ruminating",
"deliberating", "mulling", "reflecting", "processing", "reasoning",
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
"pondering",
"contemplating",
"musing",
"cogitating",
"ruminating",
"deliberating",
"mulling",
"reflecting",
"processing",
"reasoning",
"analyzing",
"computing",
"synthesizing",
"formulating",
"brainstorming",
]
def __init__(self, message: str = "", spinner_type: str = 'dots'):
def __init__(self, message: str = "", spinner_type: str = "dots"):
self.message = message
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS["dots"])
self.running = False
self.thread = None
self.frame_idx = 0
@@ -167,7 +211,7 @@ class KawaiiSpinner:
# child agents can replace sys.stdout with a black hole.
self._out = sys.stdout
def _write(self, text: str, end: str = '\n', flush: bool = False):
def _write(self, text: str, end: str = "\n", flush: bool = False):
"""Write to the stdout captured at spinner creation time."""
try:
self._out.write(text + end)
@@ -185,7 +229,7 @@ class KawaiiSpinner:
elapsed = time.time() - self.start_time
line = f" {frame} {self.message} ({elapsed:.1f}s)"
pad = max(self.last_line_len - len(line), 0)
self._write(f"\r{line}{' ' * pad}", end='', flush=True)
self._write(f"\r{line}{' ' * pad}", end="", flush=True)
self.last_line_len = len(line)
self.frame_idx += 1
time.sleep(0.12)
@@ -216,7 +260,7 @@ class KawaiiSpinner:
# Clear spinner line with spaces (not \033[K) to avoid garbled escape
# codes when prompt_toolkit's patch_stdout is active — same approach
# as stop(). Then print text; spinner redraws on next tick.
blanks = ' ' * max(self.last_line_len + 5, 40)
blanks = " " * max(self.last_line_len + 5, 40)
self._write(f"\r{blanks}\r {text}", flush=True)
def stop(self, final_message: str = None):
@@ -225,8 +269,8 @@ class KawaiiSpinner:
self.thread.join(timeout=0.5)
# Clear the spinner line with spaces instead of \033[K to avoid
# garbled escape codes when prompt_toolkit's patch_stdout is active.
blanks = ' ' * max(self.last_line_len + 5, 40)
self._write(f"\r{blanks}\r", end='', flush=True)
blanks = " " * max(self.last_line_len + 5, 40)
self._write(f"\r{blanks}\r", end="", flush=True)
if final_message:
self._write(f" {final_message}", flush=True)
@@ -244,38 +288,110 @@ class KawaiiSpinner:
# =========================================================================
KAWAII_SEARCH = [
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)*:・゚✧", "(◎o◎)",
"♪(´ε` )",
"(◕‿◕。)",
"ヾ(^∇^)",
"(◕ᴗ◕✿)",
"( ˘▽˘)っ",
"٩(◕‿◕。)۶",
"(✿◠‿◠)",
"♪~(´ε` )",
"(ノ´ヮ`)*:・゚✧",
"(◎o◎)",
]
KAWAII_READ = [
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
"ヾ(@⌒ー⌒@)", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )",
"φ(゜▽゜*)♪",
"( ˘▽˘)っ",
"(⌐■_■)",
"٩(。•́‿•̀。)۶",
"(◕‿◕✿)",
"ヾ(@⌒ー⌒@)",
"(✧ω✧)",
"♪(๑ᴖ◡ᴖ๑)♪",
"(≧◡≦)",
"( ´ ▽ ` )",
]
KAWAII_TERMINAL = [
"ヽ(>∀<☆)", "(ノ°∀°)", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
"┗(0)┓", "(`・ω・´)", "( ̄▽ ̄)", "(ง •̀_•́)ง", "ヽ(´▽`)/",
"ヽ(>∀<☆)",
"(ノ°∀°)",
"٩(^ᴗ^)۶",
"ヾ(⌐■_■)ノ♪",
"(•̀ᴗ•́)و",
"┗(0)┓",
"(`・ω・´)",
"( ̄▽ ̄)",
"(ง •̀_•́)ง",
"ヽ(´▽`)/",
]
KAWAII_BROWSER = [
"(ノ°∀°)", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)",
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "(◎o◎)",
"(ノ°∀°)",
"(☞゚ヮ゚)☞",
"( ͡° ͜ʖ ͡°)",
"┌( ಠ_ಠ)┘",
"(⊙_⊙)",
"ヾ(•ω•`)o",
"( ̄ω ̄)",
"( ˇωˇ )",
"(ᵔᴥᵔ)",
"(◎o◎)",
]
KAWAII_CREATE = [
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)", "٩(♡ε♡)۶", "(◕‿◕)♡",
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(-)", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
"✧*。٩(ˊᗜˋ*)و✧",
"(ノ◕ヮ◕)ノ*:・゚✧",
"ヽ(>∀<☆)",
"٩(♡ε♡)۶",
"(◕‿◕)♡",
"✿◕ ‿ ◕✿",
"(*≧▽≦)",
"ヾ(-)",
"(☆▽☆)",
"°˖✧◝(⁰▿⁰)◜✧˖°",
]
KAWAII_SKILL = [
"ヾ(@⌒ー⌒@)", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)",
"(ノ´ヮ`)*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "(◎o◎)",
"(✧ω✧)", "ヽ(>∀<☆)", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
"ヾ(@⌒ー⌒@)",
"(๑˃ᴗ˂)ﻭ",
"٩(◕‿◕。)۶",
"(✿╹◡╹)",
"ヽ(・∀・)",
"(ノ´ヮ`)*:・゚✧",
"♪(๑ᴖ◡ᴖ๑)♪",
"(◠‿◠)",
"٩(ˊᗜˋ*)و",
"(^▽^)",
"ヾ(^∇^)",
"(★ω★)/",
"٩(。•́‿•̀。)۶",
"(◕ᴗ◕✿)",
"(◎o◎)",
"(✧ω✧)",
"ヽ(>∀<☆)",
"( ˘▽˘)っ",
"(≧◡≦) ♡",
"ヾ( ̄▽ ̄)",
]
KAWAII_THINK = [
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )", "(一_一)",
"(っ°Д°;)っ",
"(;′⌒`)",
"(・_・ヾ",
"( ´_ゝ`)",
"( ̄ヘ ̄)",
"(。-`ω´-)",
"( ˘︹˘ )",
"(¬_¬)",
"ヽ(ー_ー )",
"(一_一)",
]
KAWAII_GENERIC = [
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
"(ノ´ヮ`)*:・゚✧", "ヽ(>∀<☆)", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
"♪(´ε` )",
"(◕‿◕✿)",
"ヾ(^∇^)",
"٩(◕‿◕。)۶",
"(✿◠‿◠)",
"(ノ´ヮ`)*:・゚✧",
"ヽ(>∀<☆)",
"(☆▽☆)",
"( ˘▽˘)っ",
"(≧◡≦)",
]
@@ -283,6 +399,7 @@ KAWAII_GENERIC = [
# Cute tool message (completion line that replaces the spinner)
# =========================================================================
def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
"""Inspect a tool result string for signs of failure.
@@ -321,7 +438,10 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
def get_cute_tool_message(
tool_name: str, args: dict, duration: float, result: str | None = None,
tool_name: str,
args: dict,
duration: float,
result: str | None = None,
) -> str:
"""Generate a formatted tool completion line for CLI quiet mode.
@@ -335,11 +455,11 @@ def get_cute_tool_message(
def _trunc(s, n=40):
s = str(s)
return (s[:n-3] + "...") if len(s) > n else s
return (s[: n - 3] + "...") if len(s) > n else s
def _path(p, n=35):
p = str(p)
return ("..." + p[-(n-3):]) if len(p) > n else p
return ("..." + p[-(n - 3) :]) if len(p) > n else p
def _wrap(line: str) -> str:
"""Append failure suffix when the tool failed."""
@@ -354,7 +474,7 @@ def get_cute_tool_message(
if urls:
url = urls[0] if isinstance(urls, list) else str(urls)
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
extra = f" +{len(urls)-1}" if len(urls) > 1 else ""
extra = f" +{len(urls) - 1}" if len(urls) > 1 else ""
return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}")
return _wrap(f"┊ 📄 fetch pages {dur}")
if tool_name == "web_crawl":
@@ -366,8 +486,15 @@ def get_cute_tool_message(
if tool_name == "process":
action = args.get("action", "?")
sid = args.get("session_id", "")[:12]
labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}",
"wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"}
labels = {
"list": "ls processes",
"poll": f"poll {sid}",
"log": f"log {sid}",
"wait": f"wait {sid}",
"kill": f"kill {sid}",
"write": f"write {sid}",
"submit": f"submit {sid}",
}
return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}")
if tool_name == "read_file":
return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}")
@@ -390,7 +517,7 @@ def get_cute_tool_message(
if tool_name == "browser_click":
return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}")
if tool_name == "browser_type":
return _wrap(f"┊ ⌨️ type \"{_trunc(args.get('text', ''), 30)}\" {dur}")
return _wrap(f'┊ ⌨️ type "{_trunc(args.get("text", ""), 30)}" {dur}')
if tool_name == "browser_scroll":
d = args.get("direction", "down")
arrow = {"down": "", "up": "", "right": "", "left": ""}.get(d, "")
@@ -415,16 +542,16 @@ def get_cute_tool_message(
else:
return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}")
if tool_name == "session_search":
return _wrap(f"┊ 🔍 recall \"{_trunc(args.get('query', ''), 35)}\" {dur}")
return _wrap(f'┊ 🔍 recall "{_trunc(args.get("query", ""), 35)}" {dur}')
if tool_name == "memory":
action = args.get("action", "?")
target = args.get("target", "")
if action == "add":
return _wrap(f"┊ 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}")
return _wrap(f'┊ 🧠 memory +{target}: "{_trunc(args.get("content", ""), 30)}" {dur}')
elif action == "replace":
return _wrap(f"┊ 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
return _wrap(f'┊ 🧠 memory ~{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
elif action == "remove":
return _wrap(f"┊ 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
return _wrap(f'┊ 🧠 memory -{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
return _wrap(f"┊ 🧠 memory {action} {dur}")
if tool_name == "skills_list":
return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}")
@@ -439,7 +566,7 @@ def get_cute_tool_message(
if tool_name == "mixture_of_agents":
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
if tool_name == "send_message":
return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}")
return _wrap(f'┊ 📨 send {args.get("target", "?")}: "{_trunc(args.get("message", ""), 25)}" {dur}')
if tool_name == "schedule_cronjob":
return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}")
if tool_name == "list_cronjobs":
@@ -448,11 +575,16 @@ def get_cute_tool_message(
return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}")
if tool_name.startswith("rl_"):
rl = {
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
"rl_list_environments": "list envs",
"rl_select_environment": f"select {args.get('name', '')}",
"rl_get_current_config": "get config",
"rl_edit_config": f"set {args.get('field', '?')}",
"rl_start_training": "start training",
"rl_check_status": f"status {args.get('run_id', '?')[:12]}",
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}",
"rl_get_results": f"results {args.get('run_id', '?')[:12]}",
"rl_list_runs": "list runs",
"rl_test_inference": "test inference",
}
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
if tool_name == "execute_code":

View File

@@ -20,7 +20,7 @@ import json
import time
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
# =========================================================================
# Model pricing (USD per million tokens) — approximate as of early 2026
@@ -81,7 +81,7 @@ def _has_known_pricing(model_name: str) -> bool:
return _get_pricing(model_name) is not _DEFAULT_PRICING
def _get_pricing(model_name: str) -> Dict[str, float]:
def _get_pricing(model_name: str) -> dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
@@ -150,7 +150,7 @@ def _format_duration(seconds: float) -> str:
return f"{days:.1f}d"
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
def _bar_chart(values: list[int], max_width: int = 20) -> list[str]:
"""Create simple horizontal bar chart strings from values."""
peak = max(values) if values else 1
if peak == 0:
@@ -176,7 +176,7 @@ class InsightsEngine:
self.db = db
self._conn = db._conn
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
def generate(self, days: int = 30, source: str = None) -> dict[str, Any]:
"""
Generate a complete insights report.
@@ -233,10 +233,11 @@ class InsightsEngine:
# =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
_SESSION_COLS = (
"id, source, model, started_at, ended_at, message_count, tool_call_count, input_tokens, output_tokens"
)
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
def _get_sessions(self, cutoff: float, source: str = None) -> list[dict]:
"""Fetch sessions within the time window."""
if source:
cursor = self._conn.execute(
@@ -254,7 +255,7 @@ class InsightsEngine:
)
return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
def _get_tool_usage(self, cutoff: float, source: str = None) -> list[dict]:
"""Get tool call counts from messages.
Uses two sources:
@@ -341,12 +342,9 @@ class InsightsEngine:
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
return [{"tool_name": name, "count": count} for name, count in tool_counts.most_common()]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
def _get_message_stats(self, cutoff: float, source: str = None) -> dict:
"""Get aggregate message statistics."""
if source:
cursor = self._conn.execute(
@@ -373,16 +371,22 @@ class InsightsEngine:
(cutoff,),
)
row = cursor.fetchone()
return dict(row) if row else {
"total_messages": 0, "user_messages": 0,
"assistant_messages": 0, "tool_messages": 0,
}
return (
dict(row)
if row
else {
"total_messages": 0,
"user_messages": 0,
"assistant_messages": 0,
"tool_messages": 0,
}
)
# =========================================================================
# Computation
# =========================================================================
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
def _compute_overview(self, sessions: list[dict], message_stats: dict) -> dict:
"""Compute high-level overview statistics."""
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
@@ -442,12 +446,18 @@ class InsightsEngine:
"models_without_pricing": sorted(models_without_pricing),
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
def _compute_model_breakdown(self, sessions: list[dict]) -> list[dict]:
"""Break down usage by model."""
model_data = defaultdict(lambda: {
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
})
model_data = defaultdict(
lambda: {
"sessions": 0,
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
"tool_calls": 0,
"cost": 0.0,
}
)
for s in sessions:
model = s.get("model") or "unknown"
@@ -464,20 +474,23 @@ class InsightsEngine:
d["cost"] += _estimate_cost(model, inp, out)
d["has_pricing"] = _has_known_pricing(model)
result = [
{"model": model, **data}
for model, data in model_data.items()
]
result = [{"model": model, **data} for model, data in model_data.items()]
# Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
def _compute_platform_breakdown(self, sessions: list[dict]) -> list[dict]:
"""Break down usage by platform/source."""
platform_data = defaultdict(lambda: {
"sessions": 0, "messages": 0, "input_tokens": 0,
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
})
platform_data = defaultdict(
lambda: {
"sessions": 0,
"messages": 0,
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
"tool_calls": 0,
}
)
for s in sessions:
source = s.get("source") or "unknown"
@@ -491,27 +504,26 @@ class InsightsEngine:
d["total_tokens"] += inp + out
d["tool_calls"] += s.get("tool_call_count") or 0
result = [
{"platform": platform, **data}
for platform, data in platform_data.items()
]
result = [{"platform": platform, **data} for platform, data in platform_data.items()]
result.sort(key=lambda x: x["sessions"], reverse=True)
return result
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
def _compute_tool_breakdown(self, tool_usage: list[dict]) -> list[dict]:
"""Process tool usage data into a ranked list with percentages."""
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
result = []
for t in tool_usage:
pct = (t["count"] / total_calls * 100) if total_calls else 0
result.append({
"tool": t["tool_name"],
"count": t["count"],
"percentage": pct,
})
result.append(
{
"tool": t["tool_name"],
"count": t["count"],
"percentage": pct,
}
)
return result
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
def _compute_activity_patterns(self, sessions: list[dict]) -> dict:
"""Analyze activity patterns by day of week and hour."""
day_counts = Counter() # 0=Monday ... 6=Sunday
hour_counts = Counter()
@@ -527,15 +539,9 @@ class InsightsEngine:
daily_counts[dt.strftime("%Y-%m-%d")] += 1
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
day_breakdown = [
{"day": day_names[i], "count": day_counts.get(i, 0)}
for i in range(7)
]
day_breakdown = [{"day": day_names[i], "count": day_counts.get(i, 0)} for i in range(7)]
hour_breakdown = [
{"hour": i, "count": hour_counts.get(i, 0)}
for i in range(24)
]
hour_breakdown = [{"hour": i, "count": hour_counts.get(i, 0)} for i in range(24)]
# Busiest day and hour
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
@@ -569,37 +575,40 @@ class InsightsEngine:
"max_streak": max_streak,
}
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
def _compute_top_sessions(self, sessions: list[dict]) -> list[dict]:
"""Find notable sessions (longest, most messages, most tokens)."""
top = []
# Longest by duration
sessions_with_duration = [
s for s in sessions
if s.get("started_at") and s.get("ended_at")
]
sessions_with_duration = [s for s in sessions if s.get("started_at") and s.get("ended_at")]
if sessions_with_duration:
longest = max(
sessions_with_duration,
key=lambda s: (s["ended_at"] - s["started_at"]),
key=lambda s: s["ended_at"] - s["started_at"],
)
dur = longest["ended_at"] - longest["started_at"]
top.append({
"label": "Longest session",
"session_id": longest["id"][:16],
"value": _format_duration(dur),
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
})
top.append(
{
"label": "Longest session",
"session_id": longest["id"][:16],
"value": _format_duration(dur),
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
}
)
# Most messages
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
if (most_msgs.get("message_count") or 0) > 0:
top.append({
"label": "Most messages",
"session_id": most_msgs["id"][:16],
"value": f"{most_msgs['message_count']} msgs",
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
})
top.append(
{
"label": "Most messages",
"session_id": most_msgs["id"][:16],
"value": f"{most_msgs['message_count']} msgs",
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d")
if most_msgs.get("started_at")
else "?",
}
)
# Most tokens
most_tokens = max(
@@ -608,22 +617,30 @@ class InsightsEngine:
)
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
if token_total > 0:
top.append({
"label": "Most tokens",
"session_id": most_tokens["id"][:16],
"value": f"{token_total:,} tokens",
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
})
top.append(
{
"label": "Most tokens",
"session_id": most_tokens["id"][:16],
"value": f"{token_total:,} tokens",
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d")
if most_tokens.get("started_at")
else "?",
}
)
# Most tool calls
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
if (most_tools.get("tool_call_count") or 0) > 0:
top.append({
"label": "Most tool calls",
"session_id": most_tools["id"][:16],
"value": f"{most_tools['tool_call_count']} calls",
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
})
top.append(
{
"label": "Most tool calls",
"session_id": most_tools["id"][:16],
"value": f"{most_tools['tool_call_count']} calls",
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d")
if most_tools.get("started_at")
else "?",
}
)
return top
@@ -631,7 +648,7 @@ class InsightsEngine:
# Formatting
# =========================================================================
def format_terminal(self, report: Dict) -> str:
def format_terminal(self, report: dict) -> str:
"""Format the insights report for terminal display (CLI)."""
if report.get("empty"):
days = report.get("days", 30)
@@ -669,13 +686,17 @@ class InsightsEngine:
lines.append(" " + "" * 56)
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
lines.append(
f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}"
)
cost_str = f"${o['estimated_cost']:.2f}"
if o.get("models_without_pricing"):
cost_str += " *"
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
if o["total_hours"] > 0:
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
lines.append(
f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}"
)
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
lines.append("")
@@ -692,7 +713,7 @@ class InsightsEngine:
cost_cell = " N/A"
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
if o.get("models_without_pricing"):
lines.append(f" * Cost N/A for custom/self-hosted models")
lines.append(" * Cost N/A for custom/self-hosted models")
lines.append("")
# Platform breakdown
@@ -758,7 +779,7 @@ class InsightsEngine:
return "\n".join(lines)
def format_gateway(self, report: Dict) -> str:
def format_gateway(self, report: dict) -> str:
"""Format the insights report for gateway/messaging (shorter)."""
if report.get("empty"):
days = report.get("days", 30)
@@ -771,14 +792,20 @@ class InsightsEngine:
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
# Overview
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
lines.append(
f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}"
)
lines.append(
f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})"
)
cost_note = ""
if o.get("models_without_pricing"):
cost_note = " _(excludes custom/self-hosted models)_"
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
if o["total_hours"] > 0:
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
lines.append(
f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}"
)
lines.append("")
# Models (top 5)
@@ -786,7 +813,9 @@ class InsightsEngine:
lines.append("**🤖 Models:**")
for m in report["models"][:5]:
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
lines.append(f" {m['model'][:25]}{m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
lines.append(
f" {m['model'][:25]}{m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}"
)
lines.append("")
# Platforms (if multi-platform)
@@ -809,9 +838,13 @@ class InsightsEngine:
hr = act["busiest_hour"]["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
lines.append(
f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)"
)
if act.get("active_days"):
lines.append(f"**Active days:** {act['active_days']}", )
lines.append(
f"**Active days:** {act['active_days']}",
)
if act.get("max_streak", 0) > 1:
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")

View File

@@ -9,7 +9,7 @@ import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
import requests
import yaml
@@ -18,7 +18,7 @@ from hermes_constants import OPENROUTER_MODELS_URL
logger = logging.getLogger(__name__)
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache: dict[str, dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
@@ -63,7 +63,7 @@ DEFAULT_CONTEXT_LENGTHS = {
}
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
def fetch_model_metadata(force_refresh: bool = False) -> dict[str, dict[str, Any]]:
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
global _model_metadata_cache, _model_metadata_cache_time
@@ -104,7 +104,7 @@ def _get_context_cache_path() -> Path:
return hermes_home / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
def _load_context_cache() -> dict[str, int]:
"""Load the model+provider → context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
@@ -139,14 +139,14 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
logger.debug("Failed to save context length cache: %s", e)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
def get_cached_context_length(model: str, base_url: str) -> int | None:
"""Look up a previously discovered context length for model+provider."""
key = f"{model}@{base_url}"
cache = _load_context_cache()
return cache.get(key)
def get_next_probe_tier(current_length: int) -> Optional[int]:
def get_next_probe_tier(current_length: int) -> int | None:
"""Return the next lower probe tier, or None if already at minimum."""
for tier in CONTEXT_PROBE_TIERS:
if tier < current_length:
@@ -154,7 +154,7 @@ def get_next_probe_tier(current_length: int) -> Optional[int]:
return None
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
def parse_context_limit_from_error(error_msg: str) -> int | None:
"""Try to extract the actual context limit from an API error message.
Many providers include the limit in their error text, e.g.:
@@ -166,11 +166,11 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
error_lower = error_msg.lower()
# Pattern: look for numbers near context-related keywords
patterns = [
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})",
r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})",
r"(\d{4,})\s*(?:token)?\s*(?:context|limit)",
r">\s*(\d{4,})\s*(?:max|limit|token)", # "250000 tokens > 200000 maximum"
r"(\d{4,})\s*(?:max(?:imum)?)\b", # "200000 maximum"
]
for pattern in patterns:
match = re.search(pattern, error_lower)
@@ -218,7 +218,7 @@ def estimate_tokens_rough(text: str) -> int:
return len(text) // 4
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
def estimate_messages_tokens_rough(messages: list[dict[str, Any]]) -> int:
"""Rough token estimate for a message list (pre-flight only)."""
total_chars = sum(len(str(msg)) for msg in messages)
return total_chars // 4

View File

@@ -8,7 +8,6 @@ import logging
import os
import re
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
@@ -18,21 +17,29 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
_CONTEXT_THREAT_PATTERNS = [
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
(r'<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->', "html_comment_injection"),
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
(r"system\s+prompt\s+override", "sys_prompt_override"),
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
(r"<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->", "html_comment_injection"),
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
(r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"),
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
(r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", "translate_execute"),
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
]
_CONTEXT_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
"\u200b",
"\u200c",
"\u200d",
"\u2060",
"\ufeff",
"\u202a",
"\u202b",
"\u202c",
"\u202d",
"\u202e",
}
@@ -52,10 +59,13 @@ def _scan_context_content(content: str, filename: str) -> str:
if findings:
logger.warning("Context file %s blocked: %s", filename, ", ".join(findings))
return f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
return (
f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
)
return content
# =========================================================================
# Constants
# =========================================================================
@@ -131,10 +141,7 @@ PLATFORM_HINTS = {
"files arrive as downloadable documents. You can also include image "
"URLs in markdown format ![alt](url) and they will be sent as photos."
),
"cli": (
"You are a CLI AI Agent. Try not to use markdown but simple text "
"renderable inside a terminal."
),
"cli": ("You are a CLI AI Agent. Try not to use markdown but simple text renderable inside a terminal."),
}
CONTEXT_FILE_MAX_CHARS = 20_000
@@ -146,18 +153,20 @@ CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
# Skills index
# =========================================================================
def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
"""Read the description from a SKILL.md frontmatter, capped at max_chars."""
try:
raw = skill_file.read_text(encoding="utf-8")[:2000]
match = re.search(
r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---",
raw, re.MULTILINE | re.DOTALL,
raw,
re.MULTILINE | re.DOTALL,
)
if match:
desc = match.group(1).strip().strip("'\"")
if len(desc) > max_chars:
desc = desc[:max_chars - 3] + "..."
desc = desc[: max_chars - 3] + "..."
return desc
except Exception:
pass
@@ -172,6 +181,7 @@ def _skill_is_platform_compatible(skill_file: Path) -> bool:
"""
try:
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = _parse_frontmatter(raw)
return skill_matches_platform(frontmatter)
@@ -260,8 +270,7 @@ def build_skills_system_prompt() -> str:
"load it with skill_view(name) and follow its instructions. "
"If a skill has issues, fix it with skill_manage(action='patch').\n"
"\n"
"<available_skills>\n"
+ "\n".join(index_lines) + "\n"
"<available_skills>\n" + "\n".join(index_lines) + "\n"
"</available_skills>\n"
"\n"
"If none match, proceed normally without loading a skill."
@@ -272,6 +281,7 @@ def build_skills_system_prompt() -> str:
# Context files (SOUL.md, AGENTS.md, .cursorrules)
# =========================================================================
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
"""Head/tail truncation with a marker in the middle."""
if len(content) <= max_chars:
@@ -284,7 +294,7 @@ def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE
return head + marker + tail
def build_context_files_prompt(cwd: Optional[str] = None) -> str:
def build_context_files_prompt(cwd: str | None = None) -> str:
"""Discover and load context files for the system prompt.
Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc,
@@ -307,7 +317,9 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
if top_level_agents:
agents_files = []
for root, dirs, files in os.walk(cwd_path):
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
dirs[:] = [
d for d in dirs if not d.startswith(".") and d not in ("node_modules", "__pycache__", "venv", ".venv")
]
for f in files:
if f.lower() == "agents.md":
agents_files.append(Path(root) / f)
@@ -384,4 +396,7 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
if not sections:
return ""
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
return (
"# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n"
+ "\n".join(sections)
)

View File

@@ -9,7 +9,7 @@ Pure functions -- no class state, no AIAgent dependency.
"""
import copy
from typing import Any, Dict, List
from typing import Any
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
@@ -36,9 +36,9 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
def apply_anthropic_cache_control(
api_messages: List[Dict[str, Any]],
api_messages: list[dict[str, Any]],
cache_ttl: str = "5m",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Apply system_and_3 caching strategy to messages for Anthropic models.
Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages.

View File

@@ -10,34 +10,33 @@ the first 6 and last 4 characters for debuggability.
import logging
import os
import re
from typing import Optional
logger = logging.getLogger(__name__)
# Known API key prefixes -- match the prefix + contiguous token chars
_PREFIX_PATTERNS = [
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
r"npm_[A-Za-z0-9]{10,}", # npm access token
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
r"npm_[A-Za-z0-9]{10,}", # npm access token
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
]
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
@@ -66,9 +65,7 @@ _TELEGRAM_RE = re.compile(
)
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
_PRIVATE_KEY_RE = re.compile(
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
)
_PRIVATE_KEY_RE = re.compile(r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----")
# Database connection strings: protocol://user:PASSWORD@host
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
@@ -82,9 +79,7 @@ _DB_CONNSTR_RE = re.compile(
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
# Compile known prefix patterns into one alternation
_PREFIX_RE = re.compile(
r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])"
)
_PREFIX_RE = re.compile(r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])")
def _mask_token(token: str) -> str:
@@ -112,12 +107,14 @@ def redact_sensitive_text(text: str) -> str:
def _redact_env(m):
name, quote, value = m.group(1), m.group(2), m.group(3)
return f"{name}={quote}{_mask_token(value)}{quote}"
text = _ENV_ASSIGN_RE.sub(_redact_env, text)
# JSON fields: "apiKey": "value"
def _redact_json(m):
key, value = m.group(1), m.group(2)
return f'{key}: "{_mask_token(value)}"'
text = _JSON_FIELD_RE.sub(_redact_json, text)
# Authorization headers
@@ -131,6 +128,7 @@ def redact_sensitive_text(text: str) -> str:
prefix = m.group(1) or ""
digits = m.group(2)
return f"{prefix}{digits}:***"
text = _TELEGRAM_RE.sub(_redact_telegram, text)
# Private key blocks
@@ -145,6 +143,7 @@ def redact_sensitive_text(text: str) -> str:
if len(phone) <= 8:
return phone[:2] + "****" + phone[-2:]
return phone[:4] + "****" + phone[-4:]
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
return text
@@ -153,7 +152,7 @@ def redact_sensitive_text(text: str) -> str:
class RedactingFormatter(logging.Formatter):
"""Log formatter that redacts secrets from all log messages."""
def __init__(self, fmt=None, datefmt=None, style='%', **kwargs):
def __init__(self, fmt=None, datefmt=None, style="%", **kwargs):
super().__init__(fmt, datefmt, style, **kwargs)
def format(self, record: logging.LogRecord) -> str:

View File

@@ -6,14 +6,14 @@ can invoke skills via /skill-name commands.
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any
logger = logging.getLogger(__name__)
_skill_commands: Dict[str, Dict[str, Any]] = {}
_skill_commands: dict[str, dict[str, Any]] = {}
def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
def scan_skill_commands() -> dict[str, dict[str, Any]]:
"""Scan ~/.hermes/skills/ and return a mapping of /command -> skill info.
Returns:
@@ -23,26 +23,27 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
_skill_commands = {}
try:
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
if not SKILLS_DIR.exists():
return _skill_commands
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
continue
try:
content = skill_md.read_text(encoding='utf-8')
content = skill_md.read_text(encoding="utf-8")
frontmatter, body = _parse_frontmatter(content)
# Skip skills incompatible with the current OS platform
if not skill_matches_platform(frontmatter):
continue
name = frontmatter.get('name', skill_md.parent.name)
description = frontmatter.get('description', '')
name = frontmatter.get("name", skill_md.parent.name)
description = frontmatter.get("description", "")
if not description:
for line in body.strip().split('\n'):
for line in body.strip().split("\n"):
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
description = line[:80]
break
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
cmd_name = name.lower().replace(" ", "-").replace("_", "-")
_skill_commands[f"/{cmd_name}"] = {
"name": name,
"description": description or f"Invoke the {name} skill",
@@ -56,14 +57,14 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
return _skill_commands
def get_skill_commands() -> Dict[str, Dict[str, Any]]:
def get_skill_commands() -> dict[str, dict[str, Any]]:
"""Return the current skill commands mapping (scan first if empty)."""
if not _skill_commands:
scan_skill_commands()
return _skill_commands
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> Optional[str]:
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> str | None:
"""Build the user message content for a skill slash command invocation.
Args:
@@ -83,7 +84,7 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
skill_name = skill_info["name"]
try:
content = skill_md_path.read_text(encoding='utf-8')
content = skill_md_path.read_text(encoding="utf-8")
except Exception:
return f"[Failed to load skill: {skill_name}]"
@@ -111,6 +112,8 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
if user_instruction:
parts.append("")
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
parts.append(
f"The user has provided the following instruction alongside the skill invocation: {user_instruction}"
)
return "\n".join(parts)

View File

@@ -8,7 +8,7 @@ the file-write logic live here.
import json
import logging
from datetime import datetime
from typing import Any, Dict, List
from typing import Any
logger = logging.getLogger(__name__)
@@ -27,8 +27,7 @@ def has_incomplete_scratchpad(content: str) -> bool:
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
def save_trajectory(trajectory: List[Dict[str, Any]], model: str,
completed: bool, filename: str = None):
def save_trajectory(trajectory: list[dict[str, Any]], model: str, completed: bool, filename: str = None):
"""Append a trajectory entry to a JSONL file.
Args:

File diff suppressed because it is too large Load Diff

1190
cli.py

File diff suppressed because it is too large Load Diff

View File

@@ -15,18 +15,18 @@ duplicate execution if multiple processes overlap.
"""
from cron.jobs import (
JOBS_FILE,
create_job,
get_job,
list_jobs,
remove_job,
update_job,
JOBS_FILE,
)
from cron.scheduler import tick
__all__ = [
"create_job",
"get_job",
"get_job",
"list_jobs",
"remove_job",
"update_job",

View File

@@ -6,18 +6,19 @@ Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
"""
import json
import tempfile
import os
import re
import tempfile
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, List, Any
from typing import Any
from hermes_time import now as _hermes_now
try:
from croniter import croniter
HAS_CRONITER = True
except ImportError:
HAS_CRONITER = False
@@ -42,37 +43,38 @@ def ensure_dirs():
# Schedule Parsing
# =============================================================================
def parse_duration(s: str) -> int:
"""
Parse duration string into minutes.
Examples:
"30m" → 30
"2h" → 120
"1d" → 1440
"""
s = s.strip().lower()
match = re.match(r'^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$', s)
match = re.match(r"^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$", s)
if not match:
raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'")
value = int(match.group(1))
unit = match.group(2)[0] # First char: m, h, or d
multipliers = {'m': 1, 'h': 60, 'd': 1440}
multipliers = {"m": 1, "h": 60, "d": 1440}
return value * multipliers[unit]
def parse_schedule(schedule: str) -> Dict[str, Any]:
def parse_schedule(schedule: str) -> dict[str, Any]:
"""
Parse schedule string into structured format.
Returns dict with:
- kind: "once" | "interval" | "cron"
- For "once": "run_at" (ISO timestamp)
- For "interval": "minutes" (int)
- For "cron": "expr" (cron expression)
Examples:
"30m" → once in 30 minutes
"2h" → once in 2 hours
@@ -84,23 +86,17 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
schedule = schedule.strip()
original = schedule
schedule_lower = schedule.lower()
# "every X" pattern → recurring interval
if schedule_lower.startswith("every "):
duration_str = schedule[6:].strip()
minutes = parse_duration(duration_str)
return {
"kind": "interval",
"minutes": minutes,
"display": f"every {minutes}m"
}
return {"kind": "interval", "minutes": minutes, "display": f"every {minutes}m"}
# Check for cron expression (5 or 6 space-separated fields)
# Cron fields: minute hour day month weekday [year]
parts = schedule.split()
if len(parts) >= 5 and all(
re.match(r'^[\d\*\-,/]+$', p) for p in parts[:5]
):
if len(parts) >= 5 and all(re.match(r"^[\d\*\-,/]+$", p) for p in parts[:5]):
if not HAS_CRONITER:
raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter")
# Validate cron expression
@@ -108,37 +104,25 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
croniter(schedule)
except Exception as e:
raise ValueError(f"Invalid cron expression '{schedule}': {e}")
return {
"kind": "cron",
"expr": schedule,
"display": schedule
}
return {"kind": "cron", "expr": schedule, "display": schedule}
# ISO timestamp (contains T or looks like date)
if 'T' in schedule or re.match(r'^\d{4}-\d{2}-\d{2}', schedule):
if "T" in schedule or re.match(r"^\d{4}-\d{2}-\d{2}", schedule):
try:
# Parse and validate
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
return {
"kind": "once",
"run_at": dt.isoformat(),
"display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"
}
dt = datetime.fromisoformat(schedule.replace("Z", "+00:00"))
return {"kind": "once", "run_at": dt.isoformat(), "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"}
except ValueError as e:
raise ValueError(f"Invalid timestamp '{schedule}': {e}")
# Duration like "30m", "2h", "1d" → one-shot from now
try:
minutes = parse_duration(schedule)
run_at = _hermes_now() + timedelta(minutes=minutes)
return {
"kind": "once",
"run_at": run_at.isoformat(),
"display": f"once in {original}"
}
return {"kind": "once", "run_at": run_at.isoformat(), "display": f"once in {original}"}
except ValueError:
pass
raise ValueError(
f"Invalid schedule '{original}'. Use:\n"
f" - Duration: '30m', '2h', '1d' (one-shot)\n"
@@ -161,7 +145,7 @@ def _ensure_aware(dt: datetime) -> datetime:
return dt
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -> str | None:
"""
Compute the next run time for a schedule.
@@ -199,26 +183,27 @@ def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None
# Job CRUD Operations
# =============================================================================
def load_jobs() -> List[Dict[str, Any]]:
def load_jobs() -> list[dict[str, Any]]:
"""Load all jobs from storage."""
ensure_dirs()
if not JOBS_FILE.exists():
return []
try:
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
with open(JOBS_FILE, encoding="utf-8") as f:
data = json.load(f)
return data.get("jobs", [])
except (json.JSONDecodeError, IOError):
except (OSError, json.JSONDecodeError):
return []
def save_jobs(jobs: List[Dict[str, Any]]):
def save_jobs(jobs: list[dict[str, Any]]):
"""Save all jobs to storage."""
ensure_dirs()
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix=".tmp", prefix=".jobs_")
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
f.flush()
os.fsync(f.fileno())
@@ -234,14 +219,14 @@ def save_jobs(jobs: List[Dict[str, Any]]):
def create_job(
prompt: str,
schedule: str,
name: Optional[str] = None,
repeat: Optional[int] = None,
deliver: Optional[str] = None,
origin: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
name: str | None = None,
repeat: int | None = None,
deliver: str | None = None,
origin: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Create a new cron job.
Args:
prompt: The prompt to run (must be self-contained)
schedule: Schedule string (see parse_schedule)
@@ -249,23 +234,23 @@ def create_job(
repeat: How many times to run (None = forever, 1 = once)
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
origin: Source info where job was created (for "origin" delivery)
Returns:
The created job dict
"""
parsed_schedule = parse_schedule(schedule)
# Auto-set repeat=1 for one-shot schedules if not specified
if parsed_schedule["kind"] == "once" and repeat is None:
repeat = 1
# Default delivery to origin if available, otherwise local
if deliver is None:
deliver = "origin" if origin else "local"
job_id = uuid.uuid4().hex[:12]
now = _hermes_now().isoformat()
job = {
"id": job_id,
"name": name or prompt[:50].strip(),
@@ -274,7 +259,7 @@ def create_job(
"schedule_display": parsed_schedule.get("display", schedule),
"repeat": {
"times": repeat, # None = forever
"completed": 0
"completed": 0,
},
"enabled": True,
"created_at": now,
@@ -286,15 +271,15 @@ def create_job(
"deliver": deliver,
"origin": origin, # Tracks where job was created for "origin" delivery
}
jobs = load_jobs()
jobs.append(job)
save_jobs(jobs)
return job
def get_job(job_id: str) -> Optional[Dict[str, Any]]:
def get_job(job_id: str) -> dict[str, Any] | None:
"""Get a job by ID."""
jobs = load_jobs()
for job in jobs:
@@ -303,7 +288,7 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]:
return None
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]:
"""List all jobs, optionally including disabled ones."""
jobs = load_jobs()
if not include_disabled:
@@ -311,7 +296,7 @@ def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
return jobs
def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
def update_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
"""Update a job by ID."""
jobs = load_jobs()
for i, job in enumerate(jobs):
@@ -333,10 +318,10 @@ def remove_job(job_id: str) -> bool:
return False
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
def mark_job_run(job_id: str, success: bool, error: str | None = None):
"""
Mark a job as having been run.
Updates last_run_at, last_status, increments completed count,
computes next_run_at, and auto-deletes if repeat limit reached.
"""
@@ -347,11 +332,11 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
job["last_run_at"] = now
job["last_status"] = "ok" if success else "error"
job["last_error"] = error if not success else None
# Increment completed count
if job.get("repeat"):
job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1
# Check if we've hit the repeat limit
times = job["repeat"].get("times")
completed = job["repeat"]["completed"]
@@ -360,38 +345,38 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
jobs.pop(i)
save_jobs(jobs)
return
# Compute next run
job["next_run_at"] = compute_next_run(job["schedule"], now)
# If no next run (one-shot completed), disable
if job["next_run_at"] is None:
job["enabled"] = False
save_jobs(jobs)
return
save_jobs(jobs)
def get_due_jobs() -> List[Dict[str, Any]]:
def get_due_jobs() -> list[dict[str, Any]]:
"""Get all jobs that are due to run now."""
now = _hermes_now()
jobs = load_jobs()
due = []
for job in jobs:
if not job.get("enabled", True):
continue
next_run = job.get("next_run_at")
if not next_run:
continue
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
if next_run_dt <= now:
due.append(job)
return due
@@ -400,11 +385,11 @@ def save_job_output(job_id: str, output: str):
ensure_dirs()
job_output_dir = OUTPUT_DIR / job_id
job_output_dir.mkdir(parents=True, exist_ok=True)
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
output_file = job_output_dir / f"{timestamp}.md"
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(output)
return output_file

View File

@@ -23,9 +23,7 @@ except ImportError:
import msvcrt
except ImportError:
msvcrt = None
from datetime import datetime
from pathlib import Path
from typing import Optional
from hermes_time import now as _hermes_now
@@ -44,7 +42,7 @@ _LOCK_DIR = _hermes_home / "cron"
_LOCK_FILE = _LOCK_DIR / ".tick.lock"
def _resolve_origin(job: dict) -> Optional[dict]:
def _resolve_origin(job: dict) -> dict | None:
"""Extract origin info from a job, returning {platform, chat_id, chat_name} or None."""
origin = job.get("origin")
if not origin:
@@ -87,11 +85,16 @@ def _deliver_result(job: dict, content: str) -> None:
# Fall back to home channel
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
if not chat_id:
logger.warning("Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>", job["id"], deliver, platform_name.upper())
logger.warning(
"Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>",
job["id"],
deliver,
platform_name.upper(),
)
return
from gateway.config import Platform, load_gateway_config
from tools.send_message_tool import _send_to_platform
from gateway.config import load_gateway_config, Platform
platform_map = {
"telegram": Platform.TELEGRAM,
@@ -123,6 +126,7 @@ def _deliver_result(job: dict, content: str) -> None:
# asyncio.run() fails if there's already a running loop in this thread;
# spin up a new thread to avoid that.
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content))
result = future.result(timeout=30)
@@ -137,25 +141,26 @@ def _deliver_result(job: dict, content: str) -> None:
# Mirror the delivered content into the target's gateway session
try:
from gateway.mirror import mirror_to_session
mirror_to_session(platform_name, chat_id, content, source_label="cron")
except Exception:
pass
def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
def run_job(job: dict) -> tuple[bool, str, str, str | None]:
"""
Execute a single cron job.
Returns:
Tuple of (success, full_output_doc, final_response, error_message)
"""
from run_agent import AIAgent
job_id = job["id"]
job_name = job["name"]
prompt = job["prompt"]
origin = _resolve_origin(job)
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
logger.info("Prompt: %s", prompt[:100])
@@ -170,6 +175,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
# Re-read .env and config.yaml fresh every run so provider/key
# changes take effect without a gateway restart.
from dotenv import load_dotenv
try:
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="utf-8")
except UnicodeDecodeError:
@@ -181,6 +187,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
_cfg = {}
try:
import yaml
_cfg_path = str(_hermes_home / "config.yaml")
if os.path.exists(_cfg_path):
with open(_cfg_path) as _f:
@@ -210,12 +217,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
if prefill_file:
import json as _json
pfpath = Path(prefill_file).expanduser()
if not pfpath.is_absolute():
pfpath = _hermes_home / pfpath
if pfpath.exists():
try:
with open(pfpath, "r", encoding="utf-8") as _pf:
with open(pfpath, encoding="utf-8") as _pf:
prefill_messages = _json.load(_pf)
if not isinstance(prefill_messages, list):
prefill_messages = None
@@ -229,9 +237,10 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
pr = _cfg.get("provider_routing", {})
from hermes_cli.runtime_provider import (
resolve_runtime_provider,
format_runtime_provider_error,
resolve_runtime_provider,
)
try:
runtime = resolve_runtime_provider(
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
@@ -254,20 +263,20 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
quiet_mode=True,
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
)
result = agent.run_conversation(prompt)
final_response = result.get("final_response", "")
if not final_response:
final_response = "(No response generated)"
output = f"""# Cron Job: {job_name}
**Job ID:** {job_id}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
**Schedule:** {job.get("schedule_display", "N/A")}
## Prompt
@@ -277,19 +286,19 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
{final_response}
"""
logger.info("Job '%s' completed successfully", job_name)
return True, output, final_response, None
except Exception as e:
error_msg = f"{type(e).__name__}: {str(e)}"
logger.error("Job '%s' failed: %s", job_name, error_msg)
output = f"""# Cron Job: {job_name} (FAILED)
**Job ID:** {job_id}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
**Schedule:** {job.get("schedule_display", "N/A")}
## Prompt
@@ -314,13 +323,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
def tick(verbose: bool = True) -> int:
"""
Check and run all due jobs.
Uses a file lock so only one tick runs at a time, even if the gateway's
in-process ticker and a standalone daemon or manual tick overlap.
Args:
verbose: Whether to print status messages
Returns:
Number of jobs executed (0 if another tick is already running)
"""
@@ -334,7 +343,7 @@ def tick(verbose: bool = True) -> int:
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
elif msvcrt:
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
except (OSError, IOError):
except OSError:
logger.debug("Tick skipped — another instance holds the lock")
if lock_fd is not None:
lock_fd.close()
@@ -344,11 +353,11 @@ def tick(verbose: bool = True) -> int:
due_jobs = get_due_jobs()
if verbose and not due_jobs:
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
logger.info("%s - No jobs due", _hermes_now().strftime("%H:%M:%S"))
return 0
if verbose:
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
logger.info("%s - %s job(s) due", _hermes_now().strftime("%H:%M:%S"), len(due_jobs))
executed = 0
for job in due_jobs:
@@ -360,7 +369,9 @@ def tick(verbose: bool = True) -> int:
logger.info("Output saved to: %s", output_file)
# Deliver the final response to the origin/target chat
deliver_content = final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
deliver_content = (
final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
)
if deliver_content:
try:
_deliver_result(job, deliver_content)
@@ -371,7 +382,7 @@ def tick(verbose: bool = True) -> int:
executed += 1
except Exception as e:
logger.error("Error processing job %s: %s", job['id'], e)
logger.error("Error processing job %s: %s", job["id"], e)
mark_job_run(job["id"], False, str(e))
return executed
@@ -381,7 +392,7 @@ def tick(verbose: bool = True) -> int:
elif msvcrt:
try:
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_UNLCK, 1)
except (OSError, IOError):
except OSError:
pass
lock_fd.close()

View File

@@ -9,19 +9,18 @@ to various messaging platforms (Telegram, Discord, WhatsApp) with:
- Platform-specific toolsets (different capabilities per platform)
"""
from .config import GatewayConfig, PlatformConfig, HomeChannel, load_gateway_config
from .config import GatewayConfig, HomeChannel, PlatformConfig, SessionResetPolicy, load_gateway_config
from .delivery import DeliveryRouter, DeliveryTarget
from .session import (
SessionContext,
SessionStore,
SessionResetPolicy,
build_session_context_prompt,
)
from .delivery import DeliveryRouter, DeliveryTarget
__all__ = [
# Config
"GatewayConfig",
"PlatformConfig",
"PlatformConfig",
"HomeChannel",
"load_gateway_config",
# Session

View File

@@ -10,7 +10,7 @@ import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@@ -21,7 +21,8 @@ DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json"
# Build / refresh
# ---------------------------------------------------------------------------
def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
"""
Build a channel directory from connected platform adapters and session data.
@@ -29,7 +30,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
"""
from gateway.config import Platform
platforms: Dict[str, List[Dict[str, str]]] = {}
platforms: dict[str, list[dict[str, str]]] = {}
for platform, adapter in adapters.items():
try:
@@ -60,7 +61,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
return directory
def _build_discord(adapter) -> List[Dict[str, str]]:
def _build_discord(adapter) -> list[dict[str, str]]:
"""Enumerate all text channels the Discord bot can see."""
channels = []
client = getattr(adapter, "_client", None)
@@ -74,12 +75,14 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
for guild in client.guilds:
for ch in guild.text_channels:
channels.append({
"id": str(ch.id),
"name": ch.name,
"guild": guild.name,
"type": "channel",
})
channels.append(
{
"id": str(ch.id),
"name": ch.name,
"guild": guild.name,
"type": "channel",
}
)
# Also include DM-capable users we've interacted with is not
# feasible via guild enumeration; those come from sessions.
@@ -88,7 +91,7 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
return channels
def _build_slack(adapter) -> List[Dict[str, str]]:
def _build_slack(adapter) -> list[dict[str, str]]:
"""List Slack channels the bot has joined."""
channels = []
# Slack adapter may expose a web client
@@ -97,7 +100,6 @@ def _build_slack(adapter) -> List[Dict[str, str]]:
return _build_from_sessions("slack")
try:
import asyncio
from tools.send_message_tool import _send_slack # noqa: F401
# Use the Slack Web API directly if available
except Exception:
@@ -107,7 +109,7 @@ def _build_slack(adapter) -> List[Dict[str, str]]:
return _build_from_sessions("slack")
def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
"""Pull known channels/contacts from sessions.json origin data."""
sessions_path = Path.home() / ".hermes" / "sessions" / "sessions.json"
if not sessions_path.exists():
@@ -127,11 +129,13 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
if not chat_id or chat_id in seen_ids:
continue
seen_ids.add(chat_id)
entries.append({
"id": str(chat_id),
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
"type": session.get("chat_type", "dm"),
})
entries.append(
{
"id": str(chat_id),
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
"type": session.get("chat_type", "dm"),
}
)
except Exception as e:
logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e)
@@ -142,7 +146,8 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
# Read / resolve
# ---------------------------------------------------------------------------
def load_directory() -> Dict[str, Any]:
def load_directory() -> dict[str, Any]:
"""Load the cached channel directory from disk."""
if not DIRECTORY_PATH.exists():
return {"updated_at": None, "platforms": {}}
@@ -153,7 +158,7 @@ def load_directory() -> Dict[str, Any]:
return {"updated_at": None, "platforms": {}}
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
def resolve_channel_name(platform_name: str, name: str) -> str | None:
"""
Resolve a human-friendly channel name to a numeric ID.
@@ -206,8 +211,8 @@ def format_directory_for_display() -> str:
# Group Discord channels by guild
if plat_name == "discord":
guilds: Dict[str, List] = {}
dms: List = []
guilds: dict[str, list] = {}
dms: list = []
for ch in channels:
guild = ch.get("guild")
if guild:

View File

@@ -8,19 +8,20 @@ Handles loading and validating configuration for:
- Delivery preferences
"""
import json
import logging
import os
import json
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from enum import Enum
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
class Platform(Enum):
"""Supported messaging platforms."""
LOCAL = "local"
TELEGRAM = "telegram"
DISCORD = "discord"
@@ -34,23 +35,24 @@ class Platform(Enum):
class HomeChannel:
"""
Default destination for a platform.
When a cron job specifies deliver="telegram" without a specific chat ID,
messages are sent to this home channel.
"""
platform: Platform
chat_id: str
name: str # Human-readable name for display
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"platform": self.platform.value,
"chat_id": self.chat_id,
"name": self.name,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HomeChannel":
def from_dict(cls, data: dict[str, Any]) -> "HomeChannel":
return cls(
platform=Platform(data["platform"]),
chat_id=str(data["chat_id"]),
@@ -62,26 +64,27 @@ class HomeChannel:
class SessionResetPolicy:
"""
Controls when sessions reset (lose context).
Modes:
- "daily": Reset at a specific hour each day
- "idle": Reset after N minutes of inactivity
- "both": Whichever triggers first (daily boundary OR idle timeout)
- "none": Never auto-reset (context managed only by compression)
"""
mode: str = "both" # "daily", "idle", "both", or "none"
at_hour: int = 4 # Hour for daily reset (0-23, local time)
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"mode": self.mode,
"at_hour": self.at_hour,
"idle_minutes": self.idle_minutes,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
def from_dict(cls, data: dict[str, Any]) -> "SessionResetPolicy":
return cls(
mode=data.get("mode", "both"),
at_hour=data.get("at_hour", 4),
@@ -92,15 +95,16 @@ class SessionResetPolicy:
@dataclass
class PlatformConfig:
"""Configuration for a single messaging platform."""
enabled: bool = False
token: Optional[str] = None # Bot token (Telegram, Discord)
api_key: Optional[str] = None # API key if different from token
home_channel: Optional[HomeChannel] = None
token: str | None = None # Bot token (Telegram, Discord)
api_key: str | None = None # API key if different from token
home_channel: HomeChannel | None = None
# Platform-specific settings
extra: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
extra: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
result = {
"enabled": self.enabled,
"extra": self.extra,
@@ -112,13 +116,13 @@ class PlatformConfig:
if self.home_channel:
result["home_channel"] = self.home_channel.to_dict()
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PlatformConfig":
def from_dict(cls, data: dict[str, Any]) -> "PlatformConfig":
home_channel = None
if "home_channel" in data:
home_channel = HomeChannel.from_dict(data["home_channel"])
return cls(
enabled=data.get("enabled", False),
token=data.get("token"),
@@ -132,89 +136,80 @@ class PlatformConfig:
class GatewayConfig:
"""
Main gateway configuration.
Manages all platform connections, session policies, and delivery settings.
"""
# Platform configurations
platforms: Dict[Platform, PlatformConfig] = field(default_factory=dict)
platforms: dict[Platform, PlatformConfig] = field(default_factory=dict)
# Session reset policies by type
default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy)
reset_by_type: Dict[str, SessionResetPolicy] = field(default_factory=dict)
reset_by_platform: Dict[Platform, SessionResetPolicy] = field(default_factory=dict)
reset_by_type: dict[str, SessionResetPolicy] = field(default_factory=dict)
reset_by_platform: dict[Platform, SessionResetPolicy] = field(default_factory=dict)
# Reset trigger commands
reset_triggers: List[str] = field(default_factory=lambda: ["/new", "/reset"])
reset_triggers: list[str] = field(default_factory=lambda: ["/new", "/reset"])
# Storage paths
sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions")
# Delivery settings
always_log_local: bool = True # Always save cron outputs to local files
def get_connected_platforms(self) -> List[Platform]:
def get_connected_platforms(self) -> list[Platform]:
"""Return list of platforms that are enabled and configured."""
connected = []
for platform, config in self.platforms.items():
if not config.enabled:
continue
# Platforms that use token/api_key auth
if config.token or config.api_key:
connected.append(platform)
# WhatsApp uses enabled flag only (bridge handles auth)
elif platform == Platform.WHATSAPP:
connected.append(platform)
# Signal uses extra dict for config (http_url + account)
elif platform == Platform.SIGNAL and config.extra.get("http_url"):
if (
config.token
or config.api_key
or platform == Platform.WHATSAPP
or platform == Platform.SIGNAL
and config.extra.get("http_url")
):
connected.append(platform)
return connected
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
def get_home_channel(self, platform: Platform) -> HomeChannel | None:
"""Get the home channel for a platform."""
config = self.platforms.get(platform)
if config:
return config.home_channel
return None
def get_reset_policy(
self,
platform: Optional[Platform] = None,
session_type: Optional[str] = None
) -> SessionResetPolicy:
def get_reset_policy(self, platform: Platform | None = None, session_type: str | None = None) -> SessionResetPolicy:
"""
Get the appropriate reset policy for a session.
Priority: platform override > type override > default
"""
# Platform-specific override takes precedence
if platform and platform in self.reset_by_platform:
return self.reset_by_platform[platform]
# Type-specific override (dm, group, thread)
if session_type and session_type in self.reset_by_type:
return self.reset_by_type[session_type]
return self.default_reset_policy
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
"platforms": {
p.value: c.to_dict() for p, c in self.platforms.items()
},
"platforms": {p.value: c.to_dict() for p, c in self.platforms.items()},
"default_reset_policy": self.default_reset_policy.to_dict(),
"reset_by_type": {
k: v.to_dict() for k, v in self.reset_by_type.items()
},
"reset_by_platform": {
p.value: v.to_dict() for p, v in self.reset_by_platform.items()
},
"reset_by_type": {k: v.to_dict() for k, v in self.reset_by_type.items()},
"reset_by_platform": {p.value: v.to_dict() for p, v in self.reset_by_platform.items()},
"reset_triggers": self.reset_triggers,
"sessions_dir": str(self.sessions_dir),
"always_log_local": self.always_log_local,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GatewayConfig":
def from_dict(cls, data: dict[str, Any]) -> "GatewayConfig":
platforms = {}
for platform_name, platform_data in data.get("platforms", {}).items():
try:
@@ -222,11 +217,11 @@ class GatewayConfig:
platforms[platform] = PlatformConfig.from_dict(platform_data)
except ValueError:
pass # Skip unknown platforms
reset_by_type = {}
for type_name, policy_data in data.get("reset_by_type", {}).items():
reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data)
reset_by_platform = {}
for platform_name, policy_data in data.get("reset_by_platform", {}).items():
try:
@@ -234,15 +229,15 @@ class GatewayConfig:
reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data)
except ValueError:
pass
default_policy = SessionResetPolicy()
if "default_reset_policy" in data:
default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"])
sessions_dir = Path.home() / ".hermes" / "sessions"
if "sessions_dir" in data:
sessions_dir = Path(data["sessions_dir"])
return cls(
platforms=platforms,
default_reset_policy=default_policy,
@@ -257,7 +252,7 @@ class GatewayConfig:
def load_gateway_config() -> GatewayConfig:
"""
Load gateway configuration from multiple sources.
Priority (highest to lowest):
1. Environment variables
2. ~/.hermes/gateway.json
@@ -265,22 +260,23 @@ def load_gateway_config() -> GatewayConfig:
4. Defaults
"""
config = GatewayConfig()
# Try loading from ~/.hermes/gateway.json
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
if gateway_config_path.exists():
try:
with open(gateway_config_path, "r") as f:
with open(gateway_config_path) as f:
data = json.load(f)
config = GatewayConfig.from_dict(data)
except Exception as e:
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
# Bridge session_reset from config.yaml (the user-facing config file)
# into the gateway config. config.yaml takes precedence over gateway.json
# for session reset policy since that's where hermes setup writes it.
try:
import yaml
config_yaml_path = Path.home() / ".hermes" / "config.yaml"
if config_yaml_path.exists():
with open(config_yaml_path) as f:
@@ -293,14 +289,12 @@ def load_gateway_config() -> GatewayConfig:
# Override with environment variables
_apply_env_overrides(config)
# --- Validate loaded values ---
policy = config.default_reset_policy
if not (0 <= policy.at_hour <= 23):
logger.warning(
"Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour
)
logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour)
policy.at_hour = 4
if policy.idle_minutes is None or policy.idle_minutes <= 0:
@@ -323,9 +317,9 @@ def load_gateway_config() -> GatewayConfig:
env_name = _token_env_names.get(platform)
if env_name and pconfig.token is not None and not pconfig.token.strip():
logger.warning(
"%s is enabled but %s is empty. "
"The adapter will likely fail to connect.",
platform.value, env_name,
"%s is enabled but %s is empty. The adapter will likely fail to connect.",
platform.value,
env_name,
)
return config
@@ -333,7 +327,7 @@ def load_gateway_config() -> GatewayConfig:
def _apply_env_overrides(config: GatewayConfig) -> None:
"""Apply environment variable overrides to config."""
# Telegram
telegram_token = os.getenv("TELEGRAM_BOT_TOKEN")
if telegram_token:
@@ -341,7 +335,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
config.platforms[Platform.TELEGRAM] = PlatformConfig()
config.platforms[Platform.TELEGRAM].enabled = True
config.platforms[Platform.TELEGRAM].token = telegram_token
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
if telegram_home and Platform.TELEGRAM in config.platforms:
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
@@ -349,7 +343,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
chat_id=telegram_home,
name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"),
)
# Discord
discord_token = os.getenv("DISCORD_BOT_TOKEN")
if discord_token:
@@ -357,7 +351,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
config.platforms[Platform.DISCORD] = PlatformConfig()
config.platforms[Platform.DISCORD].enabled = True
config.platforms[Platform.DISCORD].token = discord_token
discord_home = os.getenv("DISCORD_HOME_CHANNEL")
if discord_home and Platform.DISCORD in config.platforms:
config.platforms[Platform.DISCORD].home_channel = HomeChannel(
@@ -365,14 +359,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
chat_id=discord_home,
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
)
# WhatsApp (typically uses different auth mechanism)
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
if whatsapp_enabled:
if Platform.WHATSAPP not in config.platforms:
config.platforms[Platform.WHATSAPP] = PlatformConfig()
config.platforms[Platform.WHATSAPP].enabled = True
# Slack
slack_token = os.getenv("SLACK_BOT_TOKEN")
if slack_token:
@@ -388,7 +382,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
chat_id=slack_home,
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
)
# Signal
signal_url = os.getenv("SIGNAL_HTTP_URL")
signal_account = os.getenv("SIGNAL_ACCOUNT")
@@ -396,11 +390,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
if Platform.SIGNAL not in config.platforms:
config.platforms[Platform.SIGNAL] = PlatformConfig()
config.platforms[Platform.SIGNAL].enabled = True
config.platforms[Platform.SIGNAL].extra.update({
"http_url": signal_url,
"account": signal_account,
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
})
config.platforms[Platform.SIGNAL].extra.update(
{
"http_url": signal_url,
"account": signal_account,
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
}
)
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
if signal_home:
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
@@ -427,7 +423,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
config.default_reset_policy.idle_minutes = int(idle_minutes)
except ValueError:
pass
reset_hour = os.getenv("SESSION_RESET_HOUR")
if reset_hour:
try:
@@ -440,6 +436,6 @@ def save_gateway_config(config: GatewayConfig) -> None:
"""Save gateway configuration to ~/.hermes/gateway.json."""
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
with open(gateway_config_path, "w") as f:
json.dump(config.to_dict(), f, indent=2)

View File

@@ -9,18 +9,17 @@ Routes messages to the appropriate destination based on:
"""
import logging
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Union
from enum import Enum
from datetime import datetime
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
MAX_PLATFORM_OUTPUT = 4000
TRUNCATED_VISIBLE = 3800
from .config import Platform, GatewayConfig
from .config import GatewayConfig, Platform
from .session import SessionSource
@@ -28,23 +27,24 @@ from .session import SessionSource
class DeliveryTarget:
"""
A single delivery target.
Represents where a message should be sent:
- "origin" → back to source
- "local" → save to local files
- "telegram" → Telegram home channel
- "telegram:123456" → specific Telegram chat
"""
platform: Platform
chat_id: Optional[str] = None # None means use home channel
chat_id: str | None = None # None means use home channel
is_origin: bool = False
is_explicit: bool = False # True if chat_id was explicitly specified
@classmethod
def parse(cls, target: str, origin: Optional[SessionSource] = None) -> "DeliveryTarget":
def parse(cls, target: str, origin: SessionSource | None = None) -> "DeliveryTarget":
"""
Parse a delivery target string.
Formats:
- "origin" → back to source
- "local" → local files only
@@ -52,7 +52,7 @@ class DeliveryTarget:
- "telegram:123456" → specific Telegram chat
"""
target = target.strip().lower()
if target == "origin":
if origin:
return cls(
@@ -63,10 +63,10 @@ class DeliveryTarget:
else:
# Fallback to local if no origin
return cls(platform=Platform.LOCAL, is_origin=True)
if target == "local":
return cls(platform=Platform.LOCAL)
# Check for platform:chat_id format
if ":" in target:
platform_str, chat_id = target.split(":", 1)
@@ -76,7 +76,7 @@ class DeliveryTarget:
except ValueError:
# Unknown platform, treat as local
return cls(platform=Platform.LOCAL)
# Just a platform name (use home channel)
try:
platform = Platform(target)
@@ -84,7 +84,7 @@ class DeliveryTarget:
except ValueError:
# Unknown platform, treat as local
return cls(platform=Platform.LOCAL)
def to_string(self) -> str:
"""Convert back to string format."""
if self.is_origin:
@@ -99,15 +99,15 @@ class DeliveryTarget:
class DeliveryRouter:
"""
Routes messages to appropriate destinations.
Handles the logic of resolving delivery targets and dispatching
messages to the right platform adapters.
"""
def __init__(self, config: GatewayConfig, adapters: Dict[Platform, Any] = None):
def __init__(self, config: GatewayConfig, adapters: dict[Platform, Any] = None):
"""
Initialize the delivery router.
Args:
config: Gateway configuration
adapters: Dict mapping platforms to their adapter instances
@@ -115,31 +115,27 @@ class DeliveryRouter:
self.config = config
self.adapters = adapters or {}
self.output_dir = Path.home() / ".hermes" / "cron" / "output"
def resolve_targets(
self,
deliver: Union[str, List[str]],
origin: Optional[SessionSource] = None
) -> List[DeliveryTarget]:
def resolve_targets(self, deliver: str | list[str], origin: SessionSource | None = None) -> list[DeliveryTarget]:
"""
Resolve delivery specification to concrete targets.
Args:
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
origin: The source where the request originated (for "origin" target)
Returns:
List of resolved delivery targets
"""
if isinstance(deliver, str):
deliver = [deliver]
targets = []
seen_platforms = set()
for target_str in deliver:
target = DeliveryTarget.parse(target_str, origin)
# Resolve home channel if needed
if target.chat_id is None and target.platform != Platform.LOCAL:
home = self.config.get_home_channel(target.platform)
@@ -148,109 +144,96 @@ class DeliveryRouter:
else:
# No home channel configured, skip this platform
continue
# Deduplicate
key = (target.platform, target.chat_id)
if key not in seen_platforms:
seen_platforms.add(key)
targets.append(target)
# Always include local if configured
if self.config.always_log_local:
local_key = (Platform.LOCAL, None)
if local_key not in seen_platforms:
targets.append(DeliveryTarget(platform=Platform.LOCAL))
return targets
async def deliver(
self,
content: str,
targets: List[DeliveryTarget],
job_id: Optional[str] = None,
job_name: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
targets: list[DeliveryTarget],
job_id: str | None = None,
job_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Deliver content to all specified targets.
Args:
content: The message/output to deliver
targets: List of delivery targets
job_id: Optional job ID (for cron jobs)
job_name: Optional job name
metadata: Additional metadata to include
Returns:
Dict with delivery results per target
"""
results = {}
for target in targets:
try:
if target.platform == Platform.LOCAL:
result = self._deliver_local(content, job_id, job_name, metadata)
else:
result = await self._deliver_to_platform(target, content, metadata)
results[target.to_string()] = {
"success": True,
"result": result
}
results[target.to_string()] = {"success": True, "result": result}
except Exception as e:
results[target.to_string()] = {
"success": False,
"error": str(e)
}
results[target.to_string()] = {"success": False, "error": str(e)}
return results
def _deliver_local(
self,
content: str,
job_id: Optional[str],
job_name: Optional[str],
metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
self, content: str, job_id: str | None, job_name: str | None, metadata: dict[str, Any] | None
) -> dict[str, Any]:
"""Save content to local files."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if job_id:
output_path = self.output_dir / job_id / f"{timestamp}.md"
else:
output_path = self.output_dir / "misc" / f"{timestamp}.md"
output_path.parent.mkdir(parents=True, exist_ok=True)
# Build the output document
lines = []
if job_name:
lines.append(f"# {job_name}")
else:
lines.append("# Delivery Output")
lines.append("")
lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
if job_id:
lines.append(f"**Job ID:** {job_id}")
if metadata:
for key, value in metadata.items():
lines.append(f"**{key}:** {value}")
lines.append("")
lines.append("---")
lines.append("")
lines.append(content)
output_path.write_text("\n".join(lines))
return {
"path": str(output_path),
"timestamp": timestamp
}
return {"path": str(output_path), "timestamp": timestamp}
def _save_full_output(self, content: str, job_id: str) -> Path:
"""Save full cron output to disk and return the file path."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -261,41 +244,33 @@ class DeliveryRouter:
return path
async def _deliver_to_platform(
self,
target: DeliveryTarget,
content: str,
metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
self, target: DeliveryTarget, content: str, metadata: dict[str, Any] | None
) -> dict[str, Any]:
"""Deliver content to a messaging platform."""
adapter = self.adapters.get(target.platform)
if not adapter:
raise ValueError(f"No adapter configured for {target.platform.value}")
if not target.chat_id:
raise ValueError(f"No chat ID for {target.platform.value} delivery")
# Guard: truncate oversized cron output to stay within platform limits
if len(content) > MAX_PLATFORM_OUTPUT:
job_id = (metadata or {}).get("job_id", "unknown")
saved_path = self._save_full_output(content, job_id)
logger.info("Cron output truncated (%d chars) — full output: %s", len(content), saved_path)
content = (
content[:TRUNCATED_VISIBLE]
+ f"\n\n... [truncated, full output saved to {saved_path}]"
)
content = content[:TRUNCATED_VISIBLE] + f"\n\n... [truncated, full output saved to {saved_path}]"
return await adapter.send(target.chat_id, content, metadata=metadata)
def parse_deliver_spec(
deliver: Optional[Union[str, List[str]]],
origin: Optional[SessionSource] = None,
default: str = "origin"
) -> Union[str, List[str]]:
deliver: str | list[str] | None, origin: SessionSource | None = None, default: str = "origin"
) -> str | list[str]:
"""
Normalize a delivery specification.
If None or empty, returns the default.
"""
if not deliver:
@@ -303,17 +278,14 @@ def parse_deliver_spec(
return deliver
def build_delivery_context_for_tool(
config: GatewayConfig,
origin: Optional[SessionSource] = None
) -> Dict[str, Any]:
def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource | None = None) -> dict[str, Any]:
"""
Build context for the schedule_cronjob tool to understand delivery options.
This is passed to the tool so it can validate and explain delivery targets.
"""
connected = config.get_connected_platforms()
options = {
"origin": {
"description": "Back to where this job was created",
@@ -322,9 +294,9 @@ def build_delivery_context_for_tool(
"local": {
"description": "Save to local files only",
"available": True,
}
},
}
for platform in connected:
home = config.get_home_channel(platform)
options[platform.value] = {
@@ -332,7 +304,7 @@ def build_delivery_context_for_tool(
"available": True,
"home_channel": home.to_dict() if home else None,
}
return {
"origin": origin.to_dict() if origin else None,
"options": options,

View File

@@ -21,12 +21,12 @@ Errors in hooks are caught and logged but never block the main pipeline.
import asyncio
import importlib.util
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any
import yaml
HOOKS_DIR = Path(os.path.expanduser("~/.hermes/hooks"))
@@ -42,11 +42,11 @@ class HookRegistry:
def __init__(self):
# event_type -> [handler_fn, ...]
self._handlers: Dict[str, List[Callable]] = {}
self._loaded_hooks: List[dict] = [] # metadata for listing
self._handlers: dict[str, list[Callable]] = {}
self._loaded_hooks: list[dict] = [] # metadata for listing
@property
def loaded_hooks(self) -> List[dict]:
def loaded_hooks(self) -> list[dict]:
"""Return metadata about all loaded hooks."""
return list(self._loaded_hooks)
@@ -84,9 +84,7 @@ class HookRegistry:
continue
# Dynamically load the handler module
spec = importlib.util.spec_from_file_location(
f"hermes_hook_{hook_name}", handler_path
)
spec = importlib.util.spec_from_file_location(f"hermes_hook_{hook_name}", handler_path)
if spec is None or spec.loader is None:
print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True)
continue
@@ -103,19 +101,21 @@ class HookRegistry:
for event in events:
self._handlers.setdefault(event, []).append(handle_fn)
self._loaded_hooks.append({
"name": hook_name,
"description": manifest.get("description", ""),
"events": events,
"path": str(hook_dir),
})
self._loaded_hooks.append(
{
"name": hook_name,
"description": manifest.get("description", ""),
"events": events,
"path": str(hook_dir),
}
)
print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True)
except Exception as e:
print(f"[hooks] Error loading hook {hook_dir.name}: {e}", flush=True)
async def emit(self, event_type: str, context: Optional[Dict[str, Any]] = None) -> None:
async def emit(self, event_type: str, context: dict[str, Any] | None = None) -> None:
"""
Fire all handlers registered for an event.

View File

@@ -13,7 +13,6 @@ import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
@@ -61,7 +60,7 @@ def mirror_to_session(
return False
def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
def _find_session_id(platform: str, chat_id: str) -> str | None:
"""
Find the active session_id for a platform + chat_id pair.
@@ -113,6 +112,7 @@ def _append_to_sqlite(session_id: str, message: dict) -> None:
"""Append a message to the SQLite session database."""
try:
from hermes_state import SessionDB
db = SessionDB()
db.append_message(
session_id=session_id,

View File

@@ -23,21 +23,19 @@ import os
import secrets
import time
from pathlib import Path
from typing import Optional
# Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion
ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
CODE_LENGTH = 8
# Timing constants
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
# Limits
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
PAIRING_DIR = Path(os.path.expanduser("~/.hermes/pairing"))
@@ -123,9 +121,7 @@ class PairingStore:
# ----- Pending codes -----
def generate_code(
self, platform: str, user_id: str, user_name: str = ""
) -> Optional[str]:
def generate_code(self, platform: str, user_id: str, user_name: str = "") -> str | None:
"""
Generate a pairing code for a new user.
@@ -165,7 +161,7 @@ class PairingStore:
return code
def approve_code(self, platform: str, code: str) -> Optional[dict]:
def approve_code(self, platform: str, code: str) -> dict | None:
"""
Approve a pairing code. Adds the user to the approved list.
@@ -199,13 +195,15 @@ class PairingStore:
pending = self._load_json(self._pending_path(p))
for code, info in pending.items():
age_min = int((time.time() - info["created_at"]) / 60)
results.append({
"platform": p,
"code": code,
"user_id": info["user_id"],
"user_name": info.get("user_name", ""),
"age_minutes": age_min,
})
results.append(
{
"platform": p,
"code": code,
"user_id": info["user_id"],
"user_name": info.get("user_name", ""),
"age_minutes": age_min,
}
)
return results
def clear_pending(self, platform: str = None) -> int:
@@ -251,8 +249,11 @@ class PairingStore:
lockout_key = f"_lockout:{platform}"
limits[lockout_key] = time.time() + LOCKOUT_SECONDS
limits[fail_key] = 0 # Reset counter
print(f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
f"after {MAX_FAILED_ATTEMPTS} failed attempts", flush=True)
print(
f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
f"after {MAX_FAILED_ATTEMPTS} failed attempts",
flush=True,
)
self._save_json(self._rate_limit_path(), limits)
# ----- Cleanup -----
@@ -262,10 +263,7 @@ class PairingStore:
path = self._pending_path(platform)
pending = self._load_json(path)
now = time.time()
expired = [
code for code, info in pending.items()
if (now - info["created_at"]) > CODE_TTL_SECONDS
]
expired = [code for code, info in pending.items() if (now - info["created_at"]) > CODE_TTL_SECONDS]
if expired:
for code in expired:
del pending[code]

View File

@@ -303,8 +303,8 @@ Optional but valuable:
After implementing everything, verify with:
```bash
# All tests pass
python -m pytest tests/ -q
# All checks pass (lint + test)
make check
# Grep for your platform name to find any missed integration points
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \

View File

@@ -13,20 +13,20 @@ import uuid
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
import sys
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
from enum import Enum
import sys
from pathlib import Path
from pathlib import Path as _Path
from typing import Any
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.session import SessionSource
# ---------------------------------------------------------------------------
# Image cache utilities
#
@@ -251,6 +251,7 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int:
class MessageType(Enum):
"""Types of incoming messages."""
TEXT = "text"
LOCATION = "location"
PHOTO = "photo"
@@ -266,42 +267,43 @@ class MessageType(Enum):
class MessageEvent:
"""
Incoming message from a platform.
Normalized representation that all adapters produce.
"""
# Message content
text: str
message_type: MessageType = MessageType.TEXT
# Source information
source: SessionSource = None
# Original platform data
raw_message: Any = None
message_id: Optional[str] = None
message_id: str | None = None
# Media attachments
media_urls: List[str] = field(default_factory=list)
media_types: List[str] = field(default_factory=list)
media_urls: list[str] = field(default_factory=list)
media_types: list[str] = field(default_factory=list)
# Reply context
reply_to_message_id: Optional[str] = None
reply_to_message_id: str | None = None
# Timestamps
timestamp: datetime = field(default_factory=datetime.now)
def is_command(self) -> bool:
"""Check if this is a command message (e.g., /new, /reset)."""
return self.text.startswith("/")
def get_command(self) -> Optional[str]:
def get_command(self) -> str | None:
"""Extract command name if this is a command message."""
if not self.is_command():
return None
# Split on space and get first word, strip the /
parts = self.text.split(maxsplit=1)
return parts[0][1:].lower() if parts else None
def get_command_args(self) -> str:
"""Get the arguments after a command."""
if not self.is_command():
@@ -310,91 +312,88 @@ class MessageEvent:
return parts[1] if len(parts) > 1 else ""
@dataclass
@dataclass
class SendResult:
"""Result of sending a message."""
success: bool
message_id: Optional[str] = None
error: Optional[str] = None
message_id: str | None = None
error: str | None = None
raw_response: Any = None
# Type for message handlers
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
MessageHandler = Callable[[MessageEvent], Awaitable[str | None]]
class BasePlatformAdapter(ABC):
"""
Base class for platform adapters.
Subclasses implement platform-specific logic for:
- Connecting and authenticating
- Receiving messages
- Sending messages/responses
- Handling media
"""
def __init__(self, config: PlatformConfig, platform: Platform):
self.config = config
self.platform = platform
self._message_handler: Optional[MessageHandler] = None
self._message_handler: MessageHandler | None = None
self._running = False
# Track active message handlers per session for interrupt support
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
self._active_sessions: Dict[str, asyncio.Event] = {}
self._pending_messages: Dict[str, MessageEvent] = {}
self._active_sessions: dict[str, asyncio.Event] = {}
self._pending_messages: dict[str, MessageEvent] = {}
@property
def name(self) -> str:
"""Human-readable name for this adapter."""
return self.platform.value.title()
@property
def is_connected(self) -> bool:
"""Check if adapter is currently connected."""
return self._running
def set_message_handler(self, handler: MessageHandler) -> None:
"""
Set the handler for incoming messages.
The handler receives a MessageEvent and should return
an optional response string.
"""
self._message_handler = handler
@abstractmethod
async def connect(self) -> bool:
"""
Connect to the platform and start receiving messages.
Returns True if connection was successful.
"""
pass
@abstractmethod
async def disconnect(self) -> None:
"""Disconnect from the platform."""
pass
@abstractmethod
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
) -> SendResult:
"""
Send a message to a chat.
Args:
chat_id: The chat/channel ID to send to
content: Message content (may be markdown)
reply_to: Optional message ID to reply to
metadata: Additional platform-specific options
Returns:
SendResult with success status and message ID
"""
@@ -416,21 +415,21 @@ class BasePlatformAdapter(ABC):
async def send_typing(self, chat_id: str) -> None:
"""
Send a typing indicator.
Override in subclasses if the platform supports it.
"""
pass
async def send_image(
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send an image natively via the platform API.
Override in subclasses to send images as proper attachments
instead of plain-text URLs. Default falls back to sending the
URL as a text message.
@@ -438,87 +437,91 @@ class BasePlatformAdapter(ABC):
# Fallback: send URL as text (subclasses override for native images)
text = f"{caption}\n{image_url}" if caption else image_url
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
async def send_animation(
self,
chat_id: str,
animation_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send an animated GIF natively via the platform API.
Override in subclasses to send GIFs as proper animations
(e.g., Telegram send_animation) so they auto-play inline.
Default falls back to send_image.
"""
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
@staticmethod
def _is_animation_url(url: str) -> bool:
"""Check if a URL points to an animated GIF (vs a static image)."""
lower = url.lower().split('?')[0] # Strip query params
return lower.endswith('.gif')
lower = url.lower().split("?")[0] # Strip query params
return lower.endswith(".gif")
@staticmethod
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
def extract_images(content: str) -> tuple[list[tuple[str, str]], str]:
"""
Extract image URLs from markdown and HTML image tags in a response.
Finds patterns like:
- ![alt text](https://example.com/image.png)
- <img src="https://example.com/image.png">
- <img src="https://example.com/image.png"></img>
Args:
content: The response text to scan.
Returns:
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
"""
images = []
cleaned = content
# Match markdown images: ![alt](url)
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)"
for match in re.finditer(md_pattern, content):
alt_text = match.group(1)
url = match.group(2)
# Only extract URLs that look like actual images
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
if any(
url.lower().endswith(ext) or ext in url.lower()
for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"]
):
images.append((url, alt_text))
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
for match in re.finditer(html_pattern, content):
url = match.group(1)
images.append((url, ""))
# Remove only the matched image tags from content (not all markdown images)
if images:
extracted_urls = {url for url, _ in images}
def _remove_if_extracted(match):
url = match.group(2) if match.lastindex >= 2 else match.group(1)
return '' if url in extracted_urls else match.group(0)
return "" if url in extracted_urls else match.group(0)
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
# Clean up leftover blank lines
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
return images, cleaned
async def send_voice(
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send an audio file as a native voice message via the platform API.
Override in subclasses to send audio as voice bubbles (Telegram)
or file attachments (Discord). Default falls back to sending the
file path as text.
@@ -532,8 +535,8 @@ class BasePlatformAdapter(ABC):
self,
chat_id: str,
video_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send a video natively via the platform API.
@@ -550,9 +553,9 @@ class BasePlatformAdapter(ABC):
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
file_name: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send a document/file natively via the platform API.
@@ -569,8 +572,8 @@ class BasePlatformAdapter(ABC):
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""
Send a local image file natively via the platform API.
@@ -585,45 +588,45 @@ class BasePlatformAdapter(ABC):
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
@staticmethod
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]:
"""
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
The TTS tool returns responses like:
[[audio_as_voice]]
MEDIA:/path/to/audio.ogg
Args:
content: The response text to scan.
Returns:
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
"""
media = []
cleaned = content
# Check for [[audio_as_voice]] directive
has_voice_tag = "[[audio_as_voice]]" in content
cleaned = cleaned.replace("[[audio_as_voice]]", "")
# Extract MEDIA:<path> tags (path may contain spaces)
media_pattern = r'MEDIA:(\S+)'
media_pattern = r"MEDIA:(\S+)"
for match in re.finditer(media_pattern, content):
path = match.group(1).strip()
if path:
media.append((path, has_voice_tag))
# Remove MEDIA tags from content
if media:
cleaned = re.sub(media_pattern, '', cleaned)
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
cleaned = re.sub(media_pattern, "", cleaned)
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
return media, cleaned
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
"""
Continuously send typing indicator until cancelled.
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
to recover quickly after progress messages interrupt it.
"""
@@ -633,20 +636,20 @@ class BasePlatformAdapter(ABC):
await asyncio.sleep(interval)
except asyncio.CancelledError:
pass # Normal cancellation when handler completes
async def handle_message(self, event: MessageEvent) -> None:
"""
Process an incoming message.
This method returns quickly by spawning background tasks.
This allows new messages to be processed even while an agent is running,
enabling interruption support.
"""
if not self._message_handler:
return
session_key = event.source.chat_id
# Check if there's already an active handler for this session
if session_key in self._active_sessions:
# Store this as a pending message - it will interrupt the running agent
@@ -655,10 +658,10 @@ class BasePlatformAdapter(ABC):
# Signal the interrupt (the processing task checks this)
self._active_sessions[session_key].set()
return # Don't process now - will be handled after current task finishes
# Spawn background task to process this message
asyncio.create_task(self._process_message_background(event, session_key))
@staticmethod
def _get_human_delay() -> float:
"""
@@ -685,35 +688,40 @@ class BasePlatformAdapter(ABC):
# Create interrupt event for this session
interrupt_event = asyncio.Event()
self._active_sessions[session_key] = interrupt_event
# Start continuous typing indicator (refreshes every 2 seconds)
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
try:
# Call the handler (this can take a while with tool calls)
response = await self._message_handler(event)
# Send response if any
if not response:
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
if response:
# Extract MEDIA:<path> tags (from TTS tool) before other processing
media_files, response = self.extract_media(response)
# Extract image URLs and send them as native platform attachments
images, text_content = self.extract_images(response)
if images:
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
logger.info(
"[%s] extract_images found %d image(s) in response (%d chars)",
self.name,
len(images),
len(response),
)
# Send the text portion first (if any remains after extractions)
if text_content:
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
result = await self.send(
chat_id=event.source.chat_id,
content=text_content,
reply_to=event.message_id
logger.info(
"[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id
)
result = await self.send(
chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id
)
# Log send failures (don't raise - user already saw tool progress)
if not result.success:
print(f"[{self.name}] Failed to send response: {result.error}")
@@ -721,14 +729,14 @@ class BasePlatformAdapter(ABC):
fallback_result = await self.send(
chat_id=event.source.chat_id,
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
reply_to=event.message_id
reply_to=event.message_id,
)
if not fallback_result.success:
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
# Human-like pacing delay between text and media
human_delay = self._get_human_delay()
# Send extracted images as native attachments
if images:
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
@@ -736,7 +744,12 @@ class BasePlatformAdapter(ABC):
if human_delay > 0:
await asyncio.sleep(human_delay)
try:
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
logger.info(
"[%s] Sending image: %s (alt=%s)",
self.name,
image_url[:80],
alt_text[:30] if alt_text else "",
)
# Route animated GIFs through send_animation for proper playback
if self._is_animation_url(image_url):
img_result = await self.send_animation(
@@ -754,11 +767,11 @@ class BasePlatformAdapter(ABC):
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
except Exception as img_err:
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
# Send extracted media files — route by file type
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
for media_path, is_voice in media_files:
if human_delay > 0:
@@ -790,7 +803,7 @@ class BasePlatformAdapter(ABC):
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
except Exception as media_err:
print(f"[{self.name}] Error sending media: {media_err}")
# Check if there's a pending message that was queued during our processing
if session_key in self._pending_messages:
pending_event = self._pending_messages.pop(session_key)
@@ -806,10 +819,11 @@ class BasePlatformAdapter(ABC):
# Process pending message in new background task
await self._process_message_background(pending_event, session_key)
return # Already cleaned up
except Exception as e:
print(f"[{self.name}] Error handling message: {e}")
import traceback
traceback.print_exc()
finally:
# Stop typing indicator
@@ -821,26 +835,26 @@ class BasePlatformAdapter(ABC):
# Clean up session tracking
if session_key in self._active_sessions:
del self._active_sessions[session_key]
def has_pending_interrupt(self, session_key: str) -> bool:
"""Check if there's a pending interrupt for a session."""
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
def get_pending_message(self, session_key: str) -> MessageEvent | None:
"""Get and clear any pending message for a session."""
return self._pending_messages.pop(session_key, None)
def build_source(
self,
chat_id: str,
chat_name: Optional[str] = None,
chat_name: str | None = None,
chat_type: str = "dm",
user_id: Optional[str] = None,
user_name: Optional[str] = None,
thread_id: Optional[str] = None,
chat_topic: Optional[str] = None,
user_id_alt: Optional[str] = None,
chat_id_alt: Optional[str] = None,
user_id: str | None = None,
user_name: str | None = None,
thread_id: str | None = None,
chat_topic: str | None = None,
user_id_alt: str | None = None,
chat_id_alt: str | None = None,
) -> SessionSource:
"""Helper to build a SessionSource for this platform."""
# Normalize empty topic to None
@@ -858,30 +872,30 @@ class BasePlatformAdapter(ABC):
user_id_alt=user_id_alt,
chat_id_alt=chat_id_alt,
)
@abstractmethod
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""
Get information about a chat/channel.
Returns dict with at least:
- name: Chat name
- type: "dm", "group", "channel"
"""
pass
def format_message(self, content: str) -> str:
"""
Format a message for this platform.
Override in subclasses to handle platform-specific formatting
(e.g., Telegram MarkdownV2, Discord markdown).
Default implementation returns content as-is.
"""
return content
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
def truncate_message(self, content: str, max_length: int = 4096) -> list[str]:
"""
Split a long message into chunks, preserving code block boundaries.
@@ -900,14 +914,14 @@ class BasePlatformAdapter(ABC):
if len(content) <= max_length:
return [content]
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
FENCE_CLOSE = "\n```"
chunks: List[str] = []
chunks: list[str] = []
remaining = content
# When the previous chunk ended mid-code-block, this holds the
# language tag (possibly "") so we can reopen the fence.
carry_lang: Optional[str] = None
carry_lang: str | None = None
while remaining:
# If we're continuing a code block from the previous chunk,
@@ -965,8 +979,6 @@ class BasePlatformAdapter(ABC):
# Append chunk indicators when the response spans multiple messages
if len(chunks) > 1:
total = len(chunks)
chunks = [
f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)
]
chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)]
return chunks

View File

@@ -10,14 +10,16 @@ Uses discord.py library for:
import asyncio
import logging
import os
from typing import Dict, List, Optional, Any
from typing import Any
logger = logging.getLogger(__name__)
try:
import discord
from discord import Message as DiscordMessage, Intents
from discord import Intents
from discord import Message as DiscordMessage
from discord.ext import commands
DISCORD_AVAILABLE = True
except ImportError:
DISCORD_AVAILABLE = False
@@ -28,6 +30,7 @@ except ImportError:
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
@@ -36,8 +39,8 @@ from gateway.platforms.base import (
MessageEvent,
MessageType,
SendResult,
cache_image_from_url,
cache_audio_from_url,
cache_image_from_url,
)
@@ -49,7 +52,7 @@ def check_discord_requirements() -> bool:
class DiscordAdapter(BasePlatformAdapter):
"""
Discord bot adapter.
Handles:
- Receiving messages from servers and DMs
- Sending responses with Discord markdown
@@ -59,26 +62,26 @@ class DiscordAdapter(BasePlatformAdapter):
- Auto-threading for long conversations
- Reaction-based feedback
"""
# Discord message limits
MAX_MESSAGE_LENGTH = 2000
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.DISCORD)
self._client: Optional[commands.Bot] = None
self._client: commands.Bot | None = None
self._ready_event = asyncio.Event()
self._allowed_user_ids: set = set() # For button approval authorization
async def connect(self) -> bool:
"""Connect to Discord and start receiving events."""
if not DISCORD_AVAILABLE:
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
return False
if not self.config.token:
print(f"[{self.name}] No bot token configured")
return False
try:
# Set up intents -- members intent needed for username-to-ID resolution
intents = Intents.default()
@@ -86,30 +89,28 @@ class DiscordAdapter(BasePlatformAdapter):
intents.dm_messages = True
intents.guild_messages = True
intents.members = True
# Create bot
self._client = commands.Bot(
command_prefix="!", # Not really used, we handle raw messages
intents=intents,
)
# Parse allowed user entries (may contain usernames or IDs)
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
if allowed_env:
self._allowed_user_ids = {
uid.strip() for uid in allowed_env.split(",") if uid.strip()
}
self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()}
adapter_self = self # capture for closure
# Register event handlers
@self._client.event
async def on_ready():
print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}")
# Resolve any usernames in the allowed list to numeric IDs
await adapter_self._resolve_allowed_usernames()
# Sync slash commands with Discord
try:
synced = await adapter_self._client.tree.sync()
@@ -117,33 +118,33 @@ class DiscordAdapter(BasePlatformAdapter):
except Exception as e:
print(f"[{adapter_self.name}] Slash command sync failed: {e}")
adapter_self._ready_event.set()
@self._client.event
async def on_message(message: DiscordMessage):
# Ignore bot's own messages
if message.author == self._client.user:
return
await self._handle_message(message)
# Register slash commands
self._register_slash_commands()
# Start the bot in background
asyncio.create_task(self._client.start(self.config.token))
# Wait for ready
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
self._running = True
return True
except asyncio.TimeoutError:
except TimeoutError:
print(f"[{self.name}] Timeout waiting for connection")
return False
except Exception as e:
print(f"[{self.name}] Failed to connect: {e}")
return False
async def disconnect(self) -> None:
"""Disconnect from Discord."""
if self._client:
@@ -151,59 +152,55 @@ class DiscordAdapter(BasePlatformAdapter):
await self._client.close()
except Exception as e:
print(f"[{self.name}] Error during disconnect: {e}")
self._running = False
self._client = None
self._ready_event.clear()
print(f"[{self.name}] Disconnected")
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
) -> SendResult:
"""Send a message to a Discord channel."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
# Get the channel
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
# Format and split message if needed
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
message_ids = []
reference = None
if reply_to:
try:
ref_msg = await channel.fetch_message(int(reply_to))
reference = ref_msg
except Exception as e:
logger.debug("Could not fetch reply-to message: %s", e)
for i, chunk in enumerate(chunks):
msg = await channel.send(
content=chunk,
reference=reference if i == 0 else None,
)
message_ids.append(str(msg.id))
return SendResult(
success=True,
message_id=message_ids[0] if message_ids else None,
raw_response={"message_ids": message_ids}
raw_response={"message_ids": message_ids},
)
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -223,7 +220,7 @@ class DiscordAdapter(BasePlatformAdapter):
msg = await channel.fetch_message(int(message_id))
formatted = self.format_message(content)
if len(formatted) > self.MAX_MESSAGE_LENGTH:
formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..."
formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..."
await msg.edit(content=formatted)
return SendResult(success=True, message_id=message_id)
except Exception as e:
@@ -233,28 +230,28 @@ class DiscordAdapter(BasePlatformAdapter):
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send audio as a Discord file attachment."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
import io
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
if not os.path.exists(audio_path):
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
# Determine filename from path
filename = os.path.basename(audio_path)
with open(audio_path, "rb") as f:
file = discord.File(io.BytesIO(f.read()), filename=filename)
msg = await channel.send(
@@ -262,36 +259,36 @@ class DiscordAdapter(BasePlatformAdapter):
file=file,
)
return SendResult(success=True, message_id=str(msg.id))
except Exception as e:
print(f"[{self.name}] Failed to send audio: {e}")
return await super().send_voice(chat_id, audio_path, caption, reply_to)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a local image file natively as a Discord file attachment."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
import io
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
filename = os.path.basename(image_path)
with open(image_path, "rb") as f:
file = discord.File(io.BytesIO(f.read()), filename=filename)
msg = await channel.send(
@@ -299,7 +296,7 @@ class DiscordAdapter(BasePlatformAdapter):
file=file,
)
return SendResult(success=True, message_id=str(msg.id))
except Exception as e:
print(f"[{self.name}] Failed to send local image: {e}")
return await super().send_image_file(chat_id, image_path, caption, reply_to)
@@ -308,31 +305,31 @@ class DiscordAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send an image natively as a Discord file attachment."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
import aiohttp
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
# Download the image and send as a Discord file attachment
# (Discord renders attachments inline, unlike plain URLs)
async with aiohttp.ClientSession() as session:
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status != 200:
raise Exception(f"Failed to download image: HTTP {resp.status}")
image_data = await resp.read()
# Determine filename from URL or content type
content_type = resp.headers.get("content-type", "image/png")
ext = "png"
@@ -342,23 +339,24 @@ class DiscordAdapter(BasePlatformAdapter):
ext = "gif"
elif "webp" in content_type:
ext = "webp"
import io
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
msg = await channel.send(
content=caption if caption else None,
file=file,
)
return SendResult(success=True, message_id=str(msg.id))
except ImportError:
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
return await super().send_image(chat_id, image_url, caption, reply_to)
except Exception as e:
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
return await super().send_image(chat_id, image_url, caption, reply_to)
async def send_typing(self, chat_id: str) -> None:
"""Send typing indicator."""
if self._client:
@@ -368,20 +366,20 @@ class DiscordAdapter(BasePlatformAdapter):
await channel.typing()
except Exception:
pass # Ignore typing indicator failures
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Get information about a Discord channel."""
if not self._client:
return {"name": "Unknown", "type": "dm"}
try:
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return {"name": str(chat_id), "type": "dm"}
# Determine channel type
if isinstance(channel, discord.DMChannel):
chat_type = "dm"
@@ -397,7 +395,7 @@ class DiscordAdapter(BasePlatformAdapter):
else:
chat_type = "channel"
name = getattr(channel, "name", str(chat_id))
return {
"name": name,
"type": chat_type,
@@ -406,7 +404,7 @@ class DiscordAdapter(BasePlatformAdapter):
}
except Exception as e:
return {"name": str(chat_id), "type": "dm", "error": str(e)}
async def _resolve_allowed_usernames(self) -> None:
"""
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
@@ -453,8 +451,10 @@ class DiscordAdapter(BasePlatformAdapter):
uid = str(member.id)
numeric_ids.add(uid)
resolved_count += 1
matched_name = name_lower if name_lower in to_resolve else (
display_lower if display_lower in to_resolve else global_lower
matched_name = (
name_lower
if name_lower in to_resolve
else (display_lower if display_lower in to_resolve else global_lower)
)
to_resolve.discard(matched_name)
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
@@ -474,12 +474,12 @@ class DiscordAdapter(BasePlatformAdapter):
def format_message(self, content: str) -> str:
"""
Format message for Discord.
Discord uses its own markdown variant.
"""
# Discord markdown is fairly standard, no special escaping needed
return content
def _register_slash_commands(self) -> None:
"""Register Discord slash commands on the command tree."""
if not self._client:
@@ -694,7 +694,7 @@ class DiscordAdapter(BasePlatformAdapter):
chat_name = interaction.channel.name
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
# Get channel topic (if available)
chat_topic = getattr(interaction.channel, "topic", None)
@@ -715,9 +715,7 @@ class DiscordAdapter(BasePlatformAdapter):
raw_message=interaction,
)
async def send_exec_approval(
self, chat_id: str, command: str, approval_id: str
) -> SendResult:
async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult:
"""
Send a button-based exec approval prompt for a dangerous command.
@@ -759,28 +757,28 @@ class DiscordAdapter(BasePlatformAdapter):
# bot responds to every message without needing a mention.
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
# globally (all channels become free-response). Default: "true".
if not isinstance(message.channel, discord.DMChannel):
# Check if this channel is in the free-response list
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
channel_id = str(message.channel.id)
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
is_free_channel = channel_id in free_channels
if require_mention and not is_free_channel:
# Must be @mentioned to respond
if self._client.user not in message.mentions:
return # Silently ignore messages that don't mention the bot
# Strip the bot mention from the message text so the agent sees clean input
if self._client.user and self._client.user in message.mentions:
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
# Determine message type
msg_type = MessageType.TEXT
if message.content.startswith("/"):
@@ -798,7 +796,7 @@ class DiscordAdapter(BasePlatformAdapter):
else:
msg_type = MessageType.DOCUMENT
break
# Determine chat type
if isinstance(message.channel, discord.DMChannel):
chat_type = "dm"
@@ -811,15 +809,15 @@ class DiscordAdapter(BasePlatformAdapter):
chat_name = getattr(message.channel, "name", str(message.channel.id))
if hasattr(message.channel, "guild") and message.channel.guild:
chat_name = f"{message.channel.guild.name} / #{chat_name}"
# Get thread ID if in a thread
thread_id = None
if isinstance(message.channel, discord.Thread):
thread_id = str(message.channel.id)
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
chat_topic = getattr(message.channel, "topic", None)
# Build source
source = self.build_source(
chat_id=str(message.channel.id),
@@ -830,7 +828,7 @@ class DiscordAdapter(BasePlatformAdapter):
thread_id=thread_id,
chat_topic=chat_topic,
)
# Build media URLs -- download image attachments to local cache so the
# vision tool can access them reliably (Discord CDN URLs can expire).
media_urls = []
@@ -869,7 +867,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Other attachments: keep the original URL
media_urls.append(att.url)
media_types.append(content_type)
event = MessageEvent(
text=message.content,
message_type=msg_type,
@@ -881,7 +879,7 @@ class DiscordAdapter(BasePlatformAdapter):
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
timestamp=message.created_at,
)
await self.handle_message(event)
@@ -911,20 +909,14 @@ if DISCORD_AVAILABLE:
return True # No allowlist = anyone can approve
return str(interaction.user.id) in self.allowed_user_ids
async def _resolve(
self, interaction: discord.Interaction, action: str, color: discord.Color
):
async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color):
"""Resolve the approval and update the message."""
if self.resolved:
await interaction.response.send_message(
"This approval has already been resolved~", ephemeral=True
)
await interaction.response.send_message("This approval has already been resolved~", ephemeral=True)
return
if not self._check_auth(interaction):
await interaction.response.send_message(
"You're not authorized to approve commands~", ephemeral=True
)
await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True)
return
self.resolved = True
@@ -944,6 +936,7 @@ if DISCORD_AVAILABLE:
# Store the approval decision
try:
from tools.approval import approve_permanent
if action == "allow_once":
pass # One-time approval handled by gateway
elif action == "allow_always":
@@ -952,21 +945,15 @@ if DISCORD_AVAILABLE:
pass
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
async def allow_once(
self, interaction: discord.Interaction, button: discord.ui.Button
):
async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button):
await self._resolve(interaction, "allow_once", discord.Color.green())
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
async def allow_always(
self, interaction: discord.Interaction, button: discord.ui.Button
):
async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button):
await self._resolve(interaction, "allow_always", discord.Color.blue())
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
async def deny(
self, interaction: discord.Interaction, button: discord.ui.Button
):
async def deny(self, interaction: discord.Interaction, button: discord.ui.Button):
await self._resolve(interaction, "deny", discord.Color.red())
async def on_timeout(self):

View File

@@ -19,10 +19,11 @@ import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
from typing import Any
try:
import aiohttp
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
@@ -66,10 +67,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
super().__init__(config, Platform.HOMEASSISTANT)
# Connection state
self._session: Optional["aiohttp.ClientSession"] = None
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
self._rest_session: Optional["aiohttp.ClientSession"] = None
self._listen_task: Optional[asyncio.Task] = None
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._rest_session: aiohttp.ClientSession | None = None
self._listen_task: asyncio.Task | None = None
self._msg_id: int = 0
# Configuration from extra
@@ -80,13 +81,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
self._hass_token: str = token
# Event filtering
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
self._watch_domains: set[str] = set(extra.get("watch_domains", []))
self._watch_entities: set[str] = set(extra.get("watch_entities", []))
self._ignore_entities: set[str] = set(extra.get("ignore_entities", []))
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
# Cooldown tracking: entity_id -> last_event_timestamp
self._last_event_time: Dict[str, float] = {}
self._last_event_time: dict[str, float] = {}
def _next_id(self) -> int:
"""Return the next WebSocket message ID."""
@@ -141,10 +142,12 @@ class HomeAssistantAdapter(BasePlatformAdapter):
return False
# Step 2: Send auth
await self._ws.send_json({
"type": "auth",
"access_token": self._hass_token,
})
await self._ws.send_json(
{
"type": "auth",
"access_token": self._hass_token,
}
)
# Step 3: Wait for auth_ok
msg = await self._ws.receive_json()
@@ -155,11 +158,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
# Step 4: Subscribe to state_changed events
sub_id = self._next_id()
await self._ws.send_json({
"id": sub_id,
"type": "subscribe_events",
"event_type": "state_changed",
})
await self._ws.send_json(
{
"id": sub_id,
"type": "subscribe_events",
"event_type": "state_changed",
}
)
# Verify subscription acknowledgement
msg = await self._ws.receive_json()
@@ -245,7 +250,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
async def _handle_ha_event(self, event: dict[str, Any]) -> None:
"""Process a state_changed event from Home Assistant."""
event_data = event.get("data", {})
entity_id: str = event_data.get("entity_id", "")
@@ -302,9 +307,9 @@ class HomeAssistantAdapter(BasePlatformAdapter):
@staticmethod
def _format_state_change(
entity_id: str,
old_state: Dict[str, Any],
new_state: Dict[str, Any],
) -> Optional[str]:
old_state: dict[str, Any],
new_state: dict[str, Any],
) -> str | None:
"""Convert a state_changed event into a human-readable description."""
if not new_state:
return None
@@ -331,10 +336,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
if domain == "sensor":
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
return (
f"[Home Assistant] {friendly_name}: changed from "
f"{old_val}{unit} to {new_val}{unit}"
)
return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}"
if domain == "binary_sensor":
return (
@@ -344,22 +346,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
)
if domain in ("light", "switch", "fan"):
return (
f"[Home Assistant] {friendly_name}: turned "
f"{'on' if new_val == 'on' else 'off'}"
)
return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}"
if domain == "alarm_control_panel":
return (
f"[Home Assistant] {friendly_name}: alarm state changed from "
f"'{old_val}' to '{new_val}'"
)
return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'"
# Generic fallback
return (
f"[Home Assistant] {friendly_name} ({entity_id}): "
f"changed from '{old_val}' to '{new_val}'"
)
return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'"
# ------------------------------------------------------------------
# Outbound messaging
@@ -369,8 +362,8 @@ class HomeAssistantAdapter(BasePlatformAdapter):
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
reply_to: str | None = None,
metadata: dict[str, Any] | None = None,
) -> SendResult:
"""Send a notification via HA REST API (persistent_notification.create).
@@ -384,7 +377,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
}
payload = {
"title": "Hermes Agent",
"message": content[:self.MAX_MESSAGE_LENGTH],
"message": content[: self.MAX_MESSAGE_LENGTH],
}
try:
@@ -401,20 +394,22 @@ class HomeAssistantAdapter(BasePlatformAdapter):
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
else:
async with aiohttp.ClientSession() as session:
async with session.post(
async with (
aiohttp.ClientSession() as session,
session.post(
url,
headers=headers,
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status < 300:
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
else:
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
) as resp,
):
if resp.status < 300:
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
else:
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
except asyncio.TimeoutError:
except TimeoutError:
return SendResult(success=False, error="Timeout sending notification to HA")
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -423,7 +418,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
"""No typing indicator for Home Assistant."""
pass
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Return basic info about the HA event channel."""
return {
"name": "Home Assistant Events",

View File

@@ -19,9 +19,9 @@ import os
import random
import re
import time
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from typing import Any
from urllib.parse import unquote
import httpx
@@ -32,9 +32,9 @@ from gateway.platforms.base import (
MessageEvent,
MessageType,
SendResult,
cache_image_from_bytes,
cache_audio_from_bytes,
cache_document_from_bytes,
cache_image_from_bytes,
cache_image_from_url,
)
@@ -59,6 +59,7 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
# Helpers
# ---------------------------------------------------------------------------
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
if not phone:
@@ -68,7 +69,7 @@ def _redact_phone(phone: str) -> str:
return phone[:4] + "****" + phone[-4:]
def _parse_comma_list(value: str) -> List[str]:
def _parse_comma_list(value: str) -> list[str]:
"""Split a comma-separated string into a list, stripping whitespace."""
return [v.strip() for v in value.split(",") if v.strip()]
@@ -110,7 +111,7 @@ def _render_mentions(text: str, mentions: list) -> str:
Signal encodes @mentions as the Unicode object replacement character
with out-of-band metadata containing the mentioned user's UUID/number.
"""
if not mentions or "\uFFFC" not in text:
if not mentions or "\ufffc" not in text:
return text
# Sort mentions by start position (reverse) to replace from end to start
# so indices don't shift as we replace
@@ -121,7 +122,7 @@ def _render_mentions(text: str, mentions: list) -> str:
# Use the mention's number or UUID as the replacement
identifier = mention.get("number") or mention.get("uuid") or "user"
replacement = f"@{identifier}"
text = text[:start] + replacement + text[start + length:]
text = text[:start] + replacement + text[start + length :]
return text
@@ -134,6 +135,7 @@ def check_signal_requirements() -> bool:
# Signal Adapter
# ---------------------------------------------------------------------------
class SignalAdapter(BasePlatformAdapter):
"""Signal messenger adapter using signal-cli HTTP daemon."""
@@ -152,22 +154,25 @@ class SignalAdapter(BasePlatformAdapter):
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
# HTTP client
self.client: Optional[httpx.AsyncClient] = None
self.client: httpx.AsyncClient | None = None
# Background tasks
self._sse_task: Optional[asyncio.Task] = None
self._health_monitor_task: Optional[asyncio.Task] = None
self._typing_tasks: Dict[str, asyncio.Task] = {}
self._sse_task: asyncio.Task | None = None
self._health_monitor_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._running = False
self._last_sse_activity = 0.0
self._sse_response: Optional[httpx.Response] = None
self._sse_response: httpx.Response | None = None
# Normalize account for self-message filtering
self._account_normalized = self.account.strip()
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
self.http_url, _redact_phone(self.account),
"enabled" if self.group_allow_from else "disabled")
logger.info(
"Signal adapter initialized: url=%s account=%s groups=%s",
self.http_url,
_redact_phone(self.account),
"enabled" if self.group_allow_from else "disabled",
)
# ------------------------------------------------------------------
# Lifecycle
@@ -241,7 +246,8 @@ class SignalAdapter(BasePlatformAdapter):
try:
logger.debug("Signal SSE: connecting to %s", url)
async with self.client.stream(
"GET", url,
"GET",
url,
headers={"Accept": "text/event-stream"},
timeout=None,
) as response:
@@ -306,9 +312,7 @@ class SignalAdapter(BasePlatformAdapter):
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
try:
resp = await self.client.get(
f"{self.http_url}/api/v1/check", timeout=10.0
)
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
if resp.status_code == 200:
# Daemon is alive but SSE is idle — update activity to
# avoid repeated warnings (connection may just be quiet)
@@ -345,11 +349,7 @@ class SignalAdapter(BasePlatformAdapter):
return
# Extract sender info
sender = (
envelope_data.get("sourceNumber")
or envelope_data.get("sourceUuid")
or envelope_data.get("source")
)
sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source")
sender_name = envelope_data.get("sourceName", "")
sender_uuid = envelope_data.get("sourceUuid", "")
@@ -367,10 +367,7 @@ class SignalAdapter(BasePlatformAdapter):
# Get data message — also check editMessage (edited messages contain
# their updated dataMessage inside editMessage.dataMessage)
data_message = (
envelope_data.get("dataMessage")
or (envelope_data.get("editMessage") or {}).get("dataMessage")
)
data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage")
if not data_message:
return
@@ -451,11 +448,11 @@ class SignalAdapter(BasePlatformAdapter):
ts_ms = envelope_data.get("timestamp", 0)
if ts_ms:
try:
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC)
except (ValueError, OSError):
timestamp = datetime.now(tz=timezone.utc)
timestamp = datetime.now(tz=UTC)
else:
timestamp = datetime.now(tz=timezone.utc)
timestamp = datetime.now(tz=UTC)
# Build and dispatch event
event = MessageEvent(
@@ -468,8 +465,7 @@ class SignalAdapter(BasePlatformAdapter):
timestamp=timestamp,
)
logger.debug("Signal: message from %s in %s: %s",
_redact_phone(sender), chat_id[:20], (text or "")[:50])
logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50])
await self.handle_message(event)
@@ -479,10 +475,13 @@ class SignalAdapter(BasePlatformAdapter):
async def _fetch_attachment(self, attachment_id: str) -> tuple:
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
result = await self._rpc("getAttachment", {
"account": self.account,
"attachmentId": attachment_id,
})
result = await self._rpc(
"getAttachment",
{
"account": self.account,
"attachmentId": attachment_id,
},
)
if not result:
return None, ""
@@ -547,13 +546,13 @@ class SignalAdapter(BasePlatformAdapter):
self,
chat_id: str,
text: str,
reply_to_message_id: Optional[str] = None,
reply_to_message_id: str | None = None,
**kwargs,
) -> SendResult:
"""Send a text message."""
await self._stop_typing_indicator(chat_id)
params: Dict[str, Any] = {
params: dict[str, Any] = {
"account": self.account,
"message": text,
}
@@ -571,7 +570,7 @@ class SignalAdapter(BasePlatformAdapter):
async def send_typing(self, chat_id: str) -> None:
"""Send a typing indicator."""
params: Dict[str, Any] = {
params: dict[str, Any] = {
"account": self.account,
}
@@ -586,7 +585,7 @@ class SignalAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
caption: str | None = None,
**kwargs,
) -> SendResult:
"""Send an image. Supports http(s):// and file:// URLs."""
@@ -611,7 +610,7 @@ class SignalAdapter(BasePlatformAdapter):
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
params: Dict[str, Any] = {
params: dict[str, Any] = {
"account": self.account,
"message": caption or "",
"attachments": [file_path],
@@ -631,8 +630,8 @@ class SignalAdapter(BasePlatformAdapter):
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
filename: Optional[str] = None,
caption: str | None = None,
filename: str | None = None,
**kwargs,
) -> SendResult:
"""Send a document/file attachment."""
@@ -641,7 +640,7 @@ class SignalAdapter(BasePlatformAdapter):
if not Path(file_path).exists():
return SendResult(success=False, error="File not found")
params: Dict[str, Any] = {
params: dict[str, Any] = {
"account": self.account,
"message": caption or "",
"attachments": [file_path],
@@ -690,7 +689,7 @@ class SignalAdapter(BasePlatformAdapter):
# Chat Info
# ------------------------------------------------------------------
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Get information about a chat/contact."""
if chat_id.startswith("group:"):
return {
@@ -700,10 +699,13 @@ class SignalAdapter(BasePlatformAdapter):
}
# Try to resolve contact name
result = await self._rpc("getContact", {
"account": self.account,
"contactAddress": chat_id,
})
result = await self._rpc(
"getContact",
{
"account": self.account,
"contactAddress": chat_id,
},
)
name = chat_id
if result and isinstance(result, dict):

View File

@@ -11,12 +11,13 @@ Uses slack-bolt (Python) with Socket Mode for:
import asyncio
import os
import re
from typing import Dict, List, Optional, Any
from typing import Any
try:
from slack_bolt.async_app import AsyncApp
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp
from slack_sdk.web.async_client import AsyncWebClient
SLACK_AVAILABLE = True
except ImportError:
SLACK_AVAILABLE = False
@@ -26,18 +27,17 @@ except ImportError:
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
SUPPORTED_DOCUMENT_TYPES,
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
cache_document_from_bytes,
cache_image_from_url,
cache_audio_from_url,
)
@@ -66,9 +66,9 @@ class SlackAdapter(BasePlatformAdapter):
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.SLACK)
self._app: Optional[AsyncApp] = None
self._handler: Optional[AsyncSocketModeHandler] = None
self._bot_user_id: Optional[str] = None
self._app: AsyncApp | None = None
self._handler: AsyncSocketModeHandler | None = None
self._bot_user_id: str | None = None
async def connect(self) -> bool:
"""Connect to Slack via Socket Mode."""
@@ -135,8 +135,8 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
reply_to: str | None = None,
metadata: dict[str, Any] | None = None,
) -> SendResult:
"""Send a message to a Slack channel or DM."""
if not self._app:
@@ -193,8 +193,8 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a local image file to Slack by uploading it."""
if not self._app:
@@ -202,6 +202,7 @@ class SlackAdapter(BasePlatformAdapter):
try:
import os
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
@@ -222,8 +223,8 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send an image to Slack by uploading the URL as a file."""
if not self._app:
@@ -247,7 +248,7 @@ class SlackAdapter(BasePlatformAdapter):
return SendResult(success=True, raw_response=result)
except Exception as e:
except Exception:
# Fall back to sending the URL as text
text = f"{caption}\n{image_url}" if caption else image_url
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
@@ -256,8 +257,8 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send an audio file to Slack."""
if not self._app:
@@ -280,8 +281,8 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
video_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a video file to Slack."""
if not self._app:
@@ -308,9 +309,9 @@ class SlackAdapter(BasePlatformAdapter):
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
file_name: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a document/file attachment to Slack."""
if not self._app:
@@ -335,7 +336,7 @@ class SlackAdapter(BasePlatformAdapter):
print(f"[{self.name}] Failed to send document: {e}")
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Get information about a Slack channel."""
if not self._app:
return {"name": chat_id, "type": "unknown"}
@@ -442,9 +443,7 @@ class SlackAdapter(BasePlatformAdapter):
# Download and cache
raw_bytes = await self._download_slack_file_bytes(url)
cached_path = cache_document_from_bytes(
raw_bytes, original_filename or f"document{ext}"
)
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
media_urls.append(cached_path)
media_types.append(doc_mime)
@@ -457,7 +456,7 @@ class SlackAdapter(BasePlatformAdapter):
try:
text_content = raw_bytes.decode("utf-8")
display_name = original_filename or f"document{ext}"
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
injection = f"[Content of {display_name}]:\n{text_content}"
if text:
text = f"{injection}\n\n{text}"
@@ -499,16 +498,20 @@ class SlackAdapter(BasePlatformAdapter):
# Map subcommands to gateway commands
subcommand_map = {
"new": "/reset", "reset": "/reset",
"status": "/status", "stop": "/stop",
"new": "/reset",
"reset": "/reset",
"status": "/status",
"stop": "/stop",
"help": "/help",
"model": "/model", "personality": "/personality",
"retry": "/retry", "undo": "/undo",
"model": "/model",
"personality": "/personality",
"retry": "/retry",
"undo": "/undo",
}
first_word = text.split()[0] if text else ""
if first_word in subcommand_map:
# Preserve arguments after the subcommand
rest = text[len(first_word):].strip()
rest = text[len(first_word) :].strip()
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
elif text:
pass # Treat as a regular question
@@ -544,9 +547,11 @@ class SlackAdapter(BasePlatformAdapter):
if audio:
from gateway.platforms.base import cache_audio_from_bytes
return cache_audio_from_bytes(response.content, ext)
else:
from gateway.platforms.base import cache_image_from_bytes
return cache_image_from_bytes(response.content, ext)
async def _download_slack_file_bytes(self, url: str) -> bytes:

View File

@@ -7,24 +7,26 @@ Uses python-telegram-bot library for:
- Handling media and commands
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional, Any
from typing import Any
logger = logging.getLogger(__name__)
try:
from telegram import Update, Bot, Message
from telegram import Bot, Message, Update
from telegram.constants import ChatType, ParseMode
from telegram.ext import (
Application,
CommandHandler,
MessageHandler as TelegramMessageHandler,
ContextTypes,
filters,
)
from telegram.constants import ParseMode, ChatType
from telegram.ext import (
MessageHandler as TelegramMessageHandler,
)
TELEGRAM_AVAILABLE = True
except ImportError:
TELEGRAM_AVAILABLE = False
@@ -42,22 +44,24 @@ except ImportError:
# don't crash during class definition when the library isn't installed.
class _MockContextTypes:
DEFAULT_TYPE = Any
ContextTypes = _MockContextTypes
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
SUPPORTED_DOCUMENT_TYPES,
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
cache_image_from_bytes,
cache_audio_from_bytes,
cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES,
cache_image_from_bytes,
)
@@ -68,12 +72,12 @@ def check_telegram_requirements() -> bool:
# Matches every character that MarkdownV2 requires to be backslash-escaped
# when it appears outside a code span or fenced code block.
_MDV2_ESCAPE_RE = re.compile(r'([_*\[\]()~`>#\+\-=|{}.!\\])')
_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])")
def _escape_mdv2(text: str) -> str:
"""Escape Telegram MarkdownV2 special characters with a preceding backslash."""
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
return _MDV2_ESCAPE_RE.sub(r"\\\1", text)
def _strip_mdv2(text: str) -> str:
@@ -83,103 +87,108 @@ def _strip_mdv2(text: str) -> str:
doesn't show stray asterisks from header/bold conversion.
"""
# Remove escape backslashes before special characters
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text)
# Remove MarkdownV2 bold markers that format_message converted from **bold**
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned)
return cleaned
class TelegramAdapter(BasePlatformAdapter):
"""
Telegram bot adapter.
Handles:
- Receiving messages from users and groups
- Sending responses with Telegram markdown
- Forum topics (thread_id support)
- Media messages
"""
# Telegram message limits
MAX_MESSAGE_LENGTH = 4096
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.TELEGRAM)
self._app: Optional[Application] = None
self._bot: Optional[Bot] = None
self._app: Application | None = None
self._bot: Bot | None = None
async def connect(self) -> bool:
"""Connect to Telegram and start polling for updates."""
if not TELEGRAM_AVAILABLE:
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
return False
if not self.config.token:
print(f"[{self.name}] No bot token configured")
return False
try:
# Build the application
self._app = Application.builder().token(self.config.token).build()
self._bot = self._app.bot
# Register handlers
self._app.add_handler(TelegramMessageHandler(
filters.TEXT & ~filters.COMMAND,
self._handle_text_message
))
self._app.add_handler(TelegramMessageHandler(
filters.COMMAND,
self._handle_command
))
self._app.add_handler(TelegramMessageHandler(
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION),
self._handle_location_message
))
self._app.add_handler(TelegramMessageHandler(
filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL,
self._handle_media_message
))
self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message))
self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command))
self._app.add_handler(
TelegramMessageHandler(
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message
)
)
self._app.add_handler(
TelegramMessageHandler(
filters.PHOTO
| filters.VIDEO
| filters.AUDIO
| filters.VOICE
| filters.Document.ALL
| filters.Sticker.ALL,
self._handle_media_message,
)
)
# Start polling in background
await self._app.initialize()
await self._app.start()
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
# Register bot commands so Telegram shows a hint menu when users type /
try:
from telegram import BotCommand
await self._bot.set_my_commands([
BotCommand("new", "Start a new conversation"),
BotCommand("reset", "Reset conversation history"),
BotCommand("model", "Show or change the model"),
BotCommand("personality", "Set a personality"),
BotCommand("retry", "Retry your last message"),
BotCommand("undo", "Remove the last exchange"),
BotCommand("status", "Show session info"),
BotCommand("stop", "Stop the running agent"),
BotCommand("sethome", "Set this chat as the home channel"),
BotCommand("compress", "Compress conversation context"),
BotCommand("title", "Set or show the session title"),
BotCommand("resume", "Resume a previously-named session"),
BotCommand("usage", "Show token usage for this session"),
BotCommand("provider", "Show available providers"),
BotCommand("insights", "Show usage insights and analytics"),
BotCommand("update", "Update Hermes to the latest version"),
BotCommand("reload_mcp", "Reload MCP servers from config"),
BotCommand("help", "Show available commands"),
])
await self._bot.set_my_commands(
[
BotCommand("new", "Start a new conversation"),
BotCommand("reset", "Reset conversation history"),
BotCommand("model", "Show or change the model"),
BotCommand("personality", "Set a personality"),
BotCommand("retry", "Retry your last message"),
BotCommand("undo", "Remove the last exchange"),
BotCommand("status", "Show session info"),
BotCommand("stop", "Stop the running agent"),
BotCommand("sethome", "Set this chat as the home channel"),
BotCommand("compress", "Compress conversation context"),
BotCommand("title", "Set or show the session title"),
BotCommand("resume", "Resume a previously-named session"),
BotCommand("usage", "Show token usage for this session"),
BotCommand("provider", "Show available providers"),
BotCommand("insights", "Show usage insights and analytics"),
BotCommand("update", "Update Hermes to the latest version"),
BotCommand("reload_mcp", "Reload MCP servers from config"),
BotCommand("help", "Show available commands"),
]
)
except Exception as e:
print(f"[{self.name}] Could not register command menu: {e}")
self._running = True
print(f"[{self.name}] Connected and polling for updates")
return True
except Exception as e:
print(f"[{self.name}] Failed to connect: {e}")
return False
async def disconnect(self) -> None:
"""Stop polling and disconnect."""
if self._app:
@@ -189,31 +198,27 @@ class TelegramAdapter(BasePlatformAdapter):
await self._app.shutdown()
except Exception as e:
print(f"[{self.name}] Error during disconnect: {e}")
self._running = False
self._app = None
self._bot = None
print(f"[{self.name}] Disconnected")
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
) -> SendResult:
"""Send a message to a Telegram chat."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
# Format and split message if needed
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
message_ids = []
thread_id = metadata.get("thread_id") if metadata else None
for i, chunk in enumerate(chunks):
# Try Markdown first, fall back to plain text if it fails
try:
@@ -227,7 +232,9 @@ class TelegramAdapter(BasePlatformAdapter):
except Exception as md_error:
# Markdown parsing failed, try plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
logger.warning(
"[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error
)
# Strip MDV2 escape backslashes so the user doesn't
# see raw backslashes littered through the message.
plain_chunk = _strip_mdv2(chunk)
@@ -241,13 +248,13 @@ class TelegramAdapter(BasePlatformAdapter):
else:
raise # Re-raise if not a parse error
message_ids.append(str(msg.message_id))
return SendResult(
success=True,
message_id=message_ids[0] if message_ids else None,
raw_response={"message_ids": message_ids}
raw_response={"message_ids": message_ids},
)
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -284,18 +291,19 @@ class TelegramAdapter(BasePlatformAdapter):
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send audio as a native Telegram voice message or audio file."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
import os
if not os.path.exists(audio_path):
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
with open(audio_path, "rb") as audio_file:
# .ogg files -> send as voice (round playable bubble)
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
@@ -317,23 +325,24 @@ class TelegramAdapter(BasePlatformAdapter):
except Exception as e:
print(f"[{self.name}] Failed to send voice/audio: {e}")
return await super().send_voice(chat_id, audio_path, caption, reply_to)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a local image file natively as a Telegram photo."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
import os
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
with open(image_path, "rb") as image_file:
msg = await self._bot.send_photo(
chat_id=int(chat_id),
@@ -350,17 +359,17 @@ class TelegramAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send an image natively as a Telegram photo.
Tries URL-based send first (fast, works for <5MB images).
Falls back to downloading and uploading as file (supports up to 10MB).
"""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
# Telegram can send photos directly from URLs (up to ~5MB)
msg = await self._bot.send_photo(
@@ -375,11 +384,12 @@ class TelegramAdapter(BasePlatformAdapter):
# Fallback: download and upload as file (supports up to 10MB)
try:
import httpx
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.get(image_url)
resp.raise_for_status()
image_data = resp.content
msg = await self._bot.send_photo(
chat_id=int(chat_id),
photo=image_data,
@@ -391,18 +401,18 @@ class TelegramAdapter(BasePlatformAdapter):
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
# Final fallback: send URL as text
return await super().send_image(chat_id, image_url, caption, reply_to)
async def send_animation(
self,
chat_id: str,
animation_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send an animated GIF natively as a Telegram animation (auto-plays inline)."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
msg = await self._bot.send_animation(
chat_id=int(chat_id),
@@ -420,21 +430,18 @@ class TelegramAdapter(BasePlatformAdapter):
"""Send typing indicator."""
if self._bot:
try:
await self._bot.send_chat_action(
chat_id=int(chat_id),
action="typing"
)
await self._bot.send_chat_action(chat_id=int(chat_id), action="typing")
except Exception:
pass # Ignore typing indicator failures
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Get information about a Telegram chat."""
if not self._bot:
return {"name": "Unknown", "type": "dm"}
try:
chat = await self._bot.get_chat(int(chat_id))
chat_type = "dm"
if chat.type == ChatType.GROUP:
chat_type = "group"
@@ -444,7 +451,7 @@ class TelegramAdapter(BasePlatformAdapter):
chat_type = "forum"
elif chat.type == ChatType.CHANNEL:
chat_type = "channel"
return {
"name": chat.title or chat.full_name or str(chat_id),
"type": chat_type,
@@ -453,7 +460,7 @@ class TelegramAdapter(BasePlatformAdapter):
}
except Exception as e:
return {"name": str(chat_id), "type": "dm", "error": str(e)}
def format_message(self, content: str) -> str:
"""
Convert standard markdown to Telegram MarkdownV2 format.
@@ -480,38 +487,36 @@ class TelegramAdapter(BasePlatformAdapter):
# 1) Protect fenced code blocks (``` ... ```)
text = re.sub(
r'(```(?:[^\n]*\n)?[\s\S]*?```)',
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
lambda m: _ph(m.group(0)),
text,
)
# 2) Protect inline code (`...`)
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text)
# 3) Convert markdown links escape the display text; inside the URL
# only ')' and '\' need escaping per the MarkdownV2 spec.
def _convert_link(m):
display = _escape_mdv2(m.group(1))
url = m.group(2).replace('\\', '\\\\').replace(')', '\\)')
return _ph(f'[{display}]({url})')
url = m.group(2).replace("\\", "\\\\").replace(")", "\\)")
return _ph(f"[{display}]({url})")
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _convert_link, text)
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text)
# 4) Convert markdown headers (## Title) → bold *Title*
def _convert_header(m):
inner = m.group(1).strip()
# Strip redundant bold markers that may appear inside a header
inner = re.sub(r'\*\*(.+?)\*\*', r'\1', inner)
return _ph(f'*{_escape_mdv2(inner)}*')
inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner)
return _ph(f"*{_escape_mdv2(inner)}*")
text = re.sub(
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
)
text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE)
# 5) Convert bold: **text** → *text* (MarkdownV2 bold)
text = re.sub(
r'\*\*(.+?)\*\*',
lambda m: _ph(f'*{_escape_mdv2(m.group(1))}*'),
r"\*\*(.+?)\*\*",
lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"),
text,
)
@@ -519,8 +524,8 @@ class TelegramAdapter(BasePlatformAdapter):
# [^*\n]+ prevents matching across newlines (which would corrupt
# bullet lists using * markers and multi-line content).
text = re.sub(
r'\*([^*\n]+)\*',
lambda m: _ph(f'_{_escape_mdv2(m.group(1))}_'),
r"\*([^*\n]+)\*",
lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"),
text,
)
@@ -533,23 +538,23 @@ class TelegramAdapter(BasePlatformAdapter):
text = text.replace(key, placeholders[key])
return text
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming text messages."""
if not update.message or not update.message.text:
return
event = self._build_message_event(update.message, MessageType.TEXT)
await self.handle_message(event)
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming command messages."""
if not update.message or not update.message.text:
return
event = self._build_message_event(update.message, MessageType.COMMAND)
await self.handle_message(event)
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming location/venue pin messages."""
if not update.message:
@@ -589,9 +594,9 @@ class TelegramAdapter(BasePlatformAdapter):
"""Handle incoming media messages, downloading images to local cache."""
if not update.message:
return
msg = update.message
# Determine media type
if msg.sticker:
msg_type = MessageType.STICKER
@@ -607,19 +612,19 @@ class TelegramAdapter(BasePlatformAdapter):
msg_type = MessageType.DOCUMENT
else:
msg_type = MessageType.DOCUMENT
event = self._build_message_event(msg, msg_type)
# Add caption as text
if msg.caption:
event.text = msg.caption
# Handle stickers: describe via vision tool with caching
if msg.sticker:
await self._handle_sticker(msg, event)
await self.handle_message(event)
return
# Download photo to local image cache so the vision tool can access it
# even after Telegram's ephemeral file URLs expire (~1 hour).
if msg.photo:
@@ -643,7 +648,7 @@ class TelegramAdapter(BasePlatformAdapter):
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
except Exception as e:
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
# Download voice/audio messages to cache for STT transcription
if msg.voice:
try:
@@ -685,10 +690,7 @@ class TelegramAdapter(BasePlatformAdapter):
# Check if supported
if ext not in SUPPORTED_DOCUMENT_TYPES:
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
event.text = (
f"Unsupported document type '{ext or 'unknown'}'. "
f"Supported types: {supported_list}"
)
event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}"
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
await self.handle_message(event)
return
@@ -696,10 +698,7 @@ class TelegramAdapter(BasePlatformAdapter):
# Check file size (Telegram Bot API limit: 20 MB)
MAX_DOC_BYTES = 20 * 1024 * 1024
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
event.text = (
"The document is too large or its size could not be verified. "
"Maximum: 20 MB."
)
event.text = "The document is too large or its size could not be verified. Maximum: 20 MB."
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
await self.handle_message(event)
return
@@ -720,20 +719,20 @@ class TelegramAdapter(BasePlatformAdapter):
try:
text_content = raw_bytes.decode("utf-8")
display_name = original_filename or f"document{ext}"
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
injection = f"[Content of {display_name}]:\n{text_content}"
if event.text:
event.text = f"{injection}\n\n{event.text}"
else:
event.text = injection
except UnicodeDecodeError:
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
except Exception as e:
print(f"[Telegram] Failed to cache document: {e}", flush=True)
await self.handle_message(event)
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
"""
Describe a Telegram sticker via vision analysis, with caching.
@@ -743,11 +742,11 @@ class TelegramAdapter(BasePlatformAdapter):
a placeholder noting the emoji.
"""
from gateway.sticker_cache import (
get_cached_description,
cache_sticker_description,
build_sticker_injection,
build_animated_sticker_injection,
STICKER_VISION_PROMPT,
build_animated_sticker_injection,
build_sticker_injection,
cache_sticker_description,
get_cached_description,
)
sticker = msg.sticker
@@ -775,9 +774,10 @@ class TelegramAdapter(BasePlatformAdapter):
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp")
print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True)
from tools.vision_tools import vision_analyze_tool
import json as _json
from tools.vision_tools import vision_analyze_tool
result_json = await vision_analyze_tool(
image_url=cached_path,
user_prompt=STICKER_VISION_PROMPT,
@@ -792,27 +792,29 @@ class TelegramAdapter(BasePlatformAdapter):
# Vision failed -- use emoji as fallback
event.text = build_sticker_injection(
f"a sticker with emoji {emoji}" if emoji else "a sticker",
emoji, set_name,
emoji,
set_name,
)
except Exception as e:
print(f"[Telegram] Sticker analysis error: {e}", flush=True)
event.text = build_sticker_injection(
f"a sticker with emoji {emoji}" if emoji else "a sticker",
emoji, set_name,
emoji,
set_name,
)
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
"""Build a MessageEvent from a Telegram message."""
chat = message.chat
user = message.from_user
# Determine chat type
chat_type = "dm"
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
chat_type = "group"
elif chat.type == ChatType.CHANNEL:
chat_type = "channel"
# Build source
source = self.build_source(
chat_id=str(chat.id),
@@ -822,7 +824,7 @@ class TelegramAdapter(BasePlatformAdapter):
user_name=user.full_name if user else None,
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
)
return MessageEvent(
text=message.text or "",
message_type=msg_type,

View File

@@ -16,7 +16,6 @@ with different backends via a bridge pattern.
"""
import asyncio
import json
import logging
import os
import platform
@@ -24,7 +23,7 @@ import subprocess
_IS_WINDOWS = platform.system() == "Windows"
from pathlib import Path
from typing import Dict, List, Optional, Any
from typing import Any
logger = logging.getLogger(__name__)
@@ -36,7 +35,9 @@ def _kill_port_process(port: int) -> None:
# Use netstat to find the PID bound to this port, then taskkill
result = subprocess.run(
["netstat", "-ano", "-p", "TCP"],
capture_output=True, text=True, timeout=5,
capture_output=True,
text=True,
timeout=5,
)
for line in result.stdout.splitlines():
parts = line.split()
@@ -46,24 +47,29 @@ def _kill_port_process(port: int) -> None:
try:
subprocess.run(
["taskkill", "/PID", parts[4], "/F"],
capture_output=True, timeout=5,
capture_output=True,
timeout=5,
)
except subprocess.SubprocessError:
pass
else:
result = subprocess.run(
["fuser", f"{port}/tcp"],
capture_output=True, timeout=5,
capture_output=True,
timeout=5,
)
if result.returncode == 0:
subprocess.run(
["fuser", "-k", f"{port}/tcp"],
capture_output=True, timeout=5,
capture_output=True,
timeout=5,
)
except Exception:
pass
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
@@ -72,25 +78,20 @@ from gateway.platforms.base import (
MessageEvent,
MessageType,
SendResult,
cache_image_from_url,
cache_audio_from_url,
cache_image_from_url,
)
def check_whatsapp_requirements() -> bool:
"""
Check if WhatsApp dependencies are available.
WhatsApp requires a Node.js bridge for most implementations.
"""
# Check for Node.js
try:
result = subprocess.run(
["node", "--version"],
capture_output=True,
text=True,
timeout=5
)
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5)
return result.returncode == 0
except Exception:
return False
@@ -99,62 +100,61 @@ def check_whatsapp_requirements() -> bool:
class WhatsAppAdapter(BasePlatformAdapter):
"""
WhatsApp adapter.
This implementation uses a simple HTTP bridge pattern where:
1. A Node.js process runs the WhatsApp Web client
2. Messages are forwarded via HTTP/IPC to this Python adapter
3. Responses are sent back through the bridge
The actual Node.js bridge implementation can vary:
- whatsapp-web.js based
- Baileys based
- Business API based
Configuration:
- bridge_script: Path to the Node.js bridge script
- bridge_port: Port for HTTP communication (default: 3000)
- session_path: Path to store WhatsApp session data
"""
# WhatsApp message limits
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
# Default bridge location relative to the hermes-agent install
_DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge"
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.WHATSAPP)
self._bridge_process: Optional[subprocess.Popen] = None
self._bridge_process: subprocess.Popen | None = None
self._bridge_port: int = config.extra.get("bridge_port", 3000)
self._bridge_script: Optional[str] = config.extra.get(
self._bridge_script: str | None = config.extra.get(
"bridge_script",
str(self._DEFAULT_BRIDGE_DIR / "bridge.js"),
)
self._session_path: Path = Path(config.extra.get(
"session_path",
Path.home() / ".hermes" / "whatsapp" / "session"
))
self._session_path: Path = Path(
config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session")
)
self._message_queue: asyncio.Queue = asyncio.Queue()
self._bridge_log_fh = None
self._bridge_log: Optional[Path] = None
self._bridge_log: Path | None = None
async def connect(self) -> bool:
"""
Start the WhatsApp bridge.
This launches the Node.js bridge process and waits for it to be ready.
"""
if not check_whatsapp_requirements():
logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name)
return False
bridge_path = Path(self._bridge_script)
if not bridge_path.exists():
logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path)
return False
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
# Auto-install npm dependencies if node_modules doesn't exist
bridge_dir = bridge_path.parent
if not (bridge_dir / "node_modules").exists():
@@ -174,16 +174,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
except Exception as e:
print(f"[{self.name}] Failed to install dependencies: {e}")
return False
try:
# Ensure session directory exists
self._session_path.mkdir(parents=True, exist_ok=True)
# Kill any orphaned bridge from a previous gateway run
_kill_port_process(self._bridge_port)
import time
time.sleep(1)
# Start the bridge process in its own process group.
# Route output to a log file so QR codes, errors, and reconnection
# messages are preserved for troubleshooting.
@@ -195,19 +196,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
[
"node",
str(bridge_path),
"--port", str(self._bridge_port),
"--session", str(self._session_path),
"--mode", whatsapp_mode,
"--port",
str(self._bridge_port),
"--session",
str(self._session_path),
"--mode",
whatsapp_mode,
],
stdout=bridge_log_fh,
stderr=bridge_log_fh,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
# Wait for the bridge to connect to WhatsApp.
# Phase 1: wait for the HTTP server to come up (up to 15s).
# Phase 2: wait for WhatsApp status: connected (up to 15s more).
import aiohttp
http_ready = False
data = {}
for attempt in range(15):
@@ -218,17 +223,18 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._close_bridge_log()
return False
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://localhost:{self._bridge_port}/health",
timeout=aiohttp.ClientTimeout(total=2)
) as resp:
if resp.status == 200:
http_ready = True
data = await resp.json()
if data.get("status") == "connected":
print(f"[{self.name}] Bridge ready (status: connected)")
break
async with (
aiohttp.ClientSession() as session,
session.get(
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
) as resp,
):
if resp.status == 200:
http_ready = True
data = await resp.json()
if data.get("status") == "connected":
print(f"[{self.name}] Bridge ready (status: connected)")
break
except Exception:
continue
@@ -237,7 +243,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
print(f"[{self.name}] Check log: {self._bridge_log}")
self._close_bridge_log()
return False
# Phase 2: HTTP is up but WhatsApp may still be connecting.
# Give it more time to authenticate with saved credentials.
if data.get("status") != "connected":
@@ -250,16 +256,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._close_bridge_log()
return False
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://localhost:{self._bridge_port}/health",
timeout=aiohttp.ClientTimeout(total=2)
) as resp:
if resp.status == 200:
data = await resp.json()
if data.get("status") == "connected":
print(f"[{self.name}] Bridge ready (status: connected)")
break
async with (
aiohttp.ClientSession() as session,
session.get(
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
) as resp,
):
if resp.status == 200:
data = await resp.json()
if data.get("status") == "connected":
print(f"[{self.name}] Bridge ready (status: connected)")
break
except Exception:
continue
else:
@@ -268,19 +275,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
print(f"[{self.name}] ⚠ WhatsApp not connected after 30s")
print(f"[{self.name}] Bridge log: {self._bridge_log}")
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
# Start message polling task
asyncio.create_task(self._poll_messages())
self._running = True
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
return True
except Exception as e:
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
self._close_bridge_log()
return False
def _close_bridge_log(self) -> None:
"""Close the bridge log file handle if open."""
if self._bridge_log_fh:
@@ -296,6 +303,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
try:
# Kill the entire process group so child node processes die too
import signal
try:
if _IS_WINDOWS:
self._bridge_process.terminate()
@@ -314,29 +322,25 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._bridge_process.kill()
except Exception as e:
print(f"[{self.name}] Error stopping bridge: {e}")
# Also kill any orphaned bridge processes on our port
_kill_port_process(self._bridge_port)
self._running = False
self._bridge_process = None
self._close_bridge_log()
print(f"[{self.name}] Disconnected")
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
) -> SendResult:
"""Send a message via the WhatsApp bridge."""
if not self._running:
return SendResult(success=False, error="Not connected")
try:
import aiohttp
async with aiohttp.ClientSession() as session:
payload = {
"chatId": chat_id,
@@ -344,28 +348,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
}
if reply_to:
payload["replyTo"] = reply_to
async with session.post(
f"http://localhost:{self._bridge_port}/send",
json=payload,
timeout=aiohttp.ClientTimeout(total=30)
f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30)
) as resp:
if resp.status == 200:
data = await resp.json()
return SendResult(
success=True,
message_id=data.get("messageId"),
raw_response=data
)
return SendResult(success=True, message_id=data.get("messageId"), raw_response=data)
else:
error = await resp.text()
return SendResult(success=False, error=error)
except ImportError:
return SendResult(
success=False,
error="aiohttp not installed. Run: pip install aiohttp"
)
return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp")
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -380,21 +375,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
return SendResult(success=False, error="Not connected")
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.post(
async with (
aiohttp.ClientSession() as session,
session.post(
f"http://localhost:{self._bridge_port}/edit",
json={
"chatId": chat_id,
"messageId": message_id,
"message": content,
},
timeout=aiohttp.ClientTimeout(total=15)
) as resp:
if resp.status == 200:
return SendResult(success=True, message_id=message_id)
else:
error = await resp.text()
return SendResult(success=False, error=error)
timeout=aiohttp.ClientTimeout(total=15),
) as resp,
):
if resp.status == 200:
return SendResult(success=True, message_id=message_id)
else:
error = await resp.text()
return SendResult(success=False, error=error)
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -403,8 +401,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
chat_id: str,
file_path: str,
media_type: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
caption: str | None = None,
file_name: str | None = None,
) -> SendResult:
"""Send any media file via bridge /send-media endpoint."""
if not self._running:
@@ -415,7 +413,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
if not os.path.exists(file_path):
return SendResult(success=False, error=f"File not found: {file_path}")
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
"chatId": chat_id,
"filePath": file_path,
"mediaType": media_type,
@@ -425,22 +423,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
if file_name:
payload["fileName"] = file_name
async with aiohttp.ClientSession() as session:
async with session.post(
async with (
aiohttp.ClientSession() as session,
session.post(
f"http://localhost:{self._bridge_port}/send-media",
json=payload,
timeout=aiohttp.ClientTimeout(total=120),
) as resp:
if resp.status == 200:
data = await resp.json()
return SendResult(
success=True,
message_id=data.get("messageId"),
raw_response=data,
)
else:
error = await resp.text()
return SendResult(success=False, error=error)
) as resp,
):
if resp.status == 200:
data = await resp.json()
return SendResult(
success=True,
message_id=data.get("messageId"),
raw_response=data,
)
else:
error = await resp.text()
return SendResult(success=False, error=error)
except Exception as e:
return SendResult(success=False, error=str(e))
@@ -449,8 +449,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Download image URL to cache, send natively via bridge."""
try:
@@ -463,8 +463,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a local image file natively via bridge."""
return await self._send_media_to_bridge(chat_id, image_path, "image", caption)
@@ -473,8 +473,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
self,
chat_id: str,
video_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a video natively via bridge — plays inline in WhatsApp."""
return await self._send_media_to_bridge(chat_id, video_path, "video", caption)
@@ -483,13 +483,16 @@ class WhatsAppAdapter(BasePlatformAdapter):
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
reply_to: Optional[str] = None,
caption: str | None = None,
file_name: str | None = None,
reply_to: str | None = None,
) -> SendResult:
"""Send a document/file as a downloadable attachment via bridge."""
return await self._send_media_to_bridge(
chat_id, file_path, "document", caption,
chat_id,
file_path,
"document",
caption,
file_name or os.path.basename(file_path),
)
@@ -497,44 +500,45 @@ class WhatsAppAdapter(BasePlatformAdapter):
"""Send typing indicator via bridge."""
if not self._running:
return
try:
import aiohttp
async with aiohttp.ClientSession() as session:
await session.post(
f"http://localhost:{self._bridge_port}/typing",
json={"chatId": chat_id},
timeout=aiohttp.ClientTimeout(total=5)
timeout=aiohttp.ClientTimeout(total=5),
)
except Exception:
pass # Ignore typing indicator failures
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
"""Get information about a WhatsApp chat."""
if not self._running:
return {"name": "Unknown", "type": "dm"}
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
timeout=aiohttp.ClientTimeout(total=10)
) as resp:
if resp.status == 200:
data = await resp.json()
return {
"name": data.get("name", chat_id),
"type": "group" if data.get("isGroup") else "dm",
"participants": data.get("participants", []),
}
async with (
aiohttp.ClientSession() as session,
session.get(
f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10)
) as resp,
):
if resp.status == 200:
data = await resp.json()
return {
"name": data.get("name", chat_id),
"type": "group" if data.get("isGroup") else "dm",
"participants": data.get("participants", []),
}
except Exception as e:
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
return {"name": chat_id, "type": "dm"}
async def _poll_messages(self) -> None:
"""Poll the bridge for incoming messages."""
try:
@@ -542,29 +546,30 @@ class WhatsAppAdapter(BasePlatformAdapter):
except ImportError:
print(f"[{self.name}] aiohttp not installed, message polling disabled")
return
while self._running:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"http://localhost:{self._bridge_port}/messages",
timeout=aiohttp.ClientTimeout(total=30)
) as resp:
if resp.status == 200:
messages = await resp.json()
for msg_data in messages:
event = await self._build_message_event(msg_data)
if event:
await self.handle_message(event)
async with (
aiohttp.ClientSession() as session,
session.get(
f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30)
) as resp,
):
if resp.status == 200:
messages = await resp.json()
for msg_data in messages:
event = await self._build_message_event(msg_data)
if event:
await self.handle_message(event)
except asyncio.CancelledError:
break
except Exception as e:
print(f"[{self.name}] Poll error: {e}")
await asyncio.sleep(5)
await asyncio.sleep(1) # Poll interval
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None:
"""Build a MessageEvent from bridge message data, downloading images to cache."""
try:
# Determine message type
@@ -579,11 +584,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
msg_type = MessageType.VOICE
else:
msg_type = MessageType.DOCUMENT
# Determine chat type
is_group = data.get("isGroup", False)
chat_type = "group" if is_group else "dm"
# Build source
source = self.build_source(
chat_id=data.get("chatId", ""),
@@ -592,7 +597,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
user_id=data.get("senderId"),
user_name=data.get("senderName"),
)
# Download image media URLs to the local cache so the vision tool
# can access them reliably regardless of URL expiration.
raw_urls = data.get("mediaUrls", [])
@@ -622,7 +627,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
else:
cached_urls.append(url)
media_types.append("unknown")
return MessageEvent(
text=data.get("body", ""),
message_type=msg_type,
@@ -635,4 +640,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
except Exception as e:
print(f"[{self.name}] Error building event: {e}")
return None

File diff suppressed because it is too large Load Diff

View File

@@ -8,22 +8,20 @@ Handles:
- Dynamic system prompt injection (agent knows its context)
"""
import logging
import os
import json
import logging
import uuid
from pathlib import Path
from dataclasses import dataclass
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
from .config import (
Platform,
GatewayConfig,
SessionResetPolicy,
HomeChannel,
Platform,
)
@@ -31,29 +29,30 @@ from .config import (
class SessionSource:
"""
Describes where a message originated from.
This information is used to:
1. Route responses back to the right place
2. Inject context into the system prompt
3. Track origin for cron job delivery
"""
platform: Platform
chat_id: str
chat_name: Optional[str] = None
chat_name: str | None = None
chat_type: str = "dm" # "dm", "group", "channel", "thread"
user_id: Optional[str] = None
user_name: Optional[str] = None
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
chat_id_alt: Optional[str] = None # Signal group internal ID
user_id: str | None = None
user_name: str | None = None
thread_id: str | None = None # For forum topics, Discord threads, etc.
chat_topic: str | None = None # Channel topic/description (Discord, Slack)
user_id_alt: str | None = None # Signal UUID (alternative to phone number)
chat_id_alt: str | None = None # Signal group internal ID
@property
def description(self) -> str:
"""Human-readable description of the source."""
if self.platform == Platform.LOCAL:
return "CLI terminal"
parts = []
if self.chat_type == "dm":
parts.append(f"DM with {self.user_name or self.user_id or 'user'}")
@@ -63,13 +62,13 @@ class SessionSource:
parts.append(f"channel: {self.chat_name or self.chat_id}")
else:
parts.append(self.chat_name or self.chat_id)
if self.thread_id:
parts.append(f"thread: {self.thread_id}")
return ", ".join(parts)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
d = {
"platform": self.platform.value,
"chat_id": self.chat_id,
@@ -85,9 +84,9 @@ class SessionSource:
if self.chat_id_alt:
d["chat_id_alt"] = self.chat_id_alt
return d
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
def from_dict(cls, data: dict[str, Any]) -> "SessionSource":
return cls(
platform=Platform(data["platform"]),
chat_id=str(data["chat_id"]),
@@ -100,7 +99,7 @@ class SessionSource:
user_id_alt=data.get("user_id_alt"),
chat_id_alt=data.get("chat_id_alt"),
)
@classmethod
def local_cli(cls) -> "SessionSource":
"""Create a source representing the local CLI."""
@@ -116,29 +115,28 @@ class SessionSource:
class SessionContext:
"""
Full context for a session, used for dynamic system prompt injection.
The agent receives this information to understand:
- Where messages are coming from
- What platforms are available
- Where it can deliver scheduled task outputs
"""
source: SessionSource
connected_platforms: List[Platform]
home_channels: Dict[Platform, HomeChannel]
connected_platforms: list[Platform]
home_channels: dict[Platform, HomeChannel]
# Session metadata
session_key: str = ""
session_id: str = ""
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
def to_dict(self) -> Dict[str, Any]:
created_at: datetime | None = None
updated_at: datetime | None = None
def to_dict(self) -> dict[str, Any]:
return {
"source": self.source.to_dict(),
"connected_platforms": [p.value for p in self.connected_platforms],
"home_channels": {
p.value: hc.to_dict() for p, hc in self.home_channels.items()
},
"home_channels": {p.value: hc.to_dict() for p, hc in self.home_channels.items()},
"session_key": self.session_key,
"session_id": self.session_id,
"created_at": self.created_at.isoformat() if self.created_at else None,
@@ -149,7 +147,7 @@ class SessionContext:
def build_session_context_prompt(context: SessionContext) -> str:
"""
Build the dynamic system prompt section that tells the agent about its context.
This is injected into the system prompt so the agent knows:
- Where messages are coming from
- What platforms are connected
@@ -159,14 +157,14 @@ def build_session_context_prompt(context: SessionContext) -> str:
"## Current Session Context",
"",
]
# Source info
platform_name = context.source.platform.value.title()
if context.source.platform == Platform.LOCAL:
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
else:
lines.append(f"**Source:** {platform_name} ({context.source.description})")
# Channel topic (if available - provides context about the channel's purpose)
if context.source.chat_topic:
lines.append(f"**Channel Topic:** {context.source.chat_topic}")
@@ -176,43 +174,43 @@ def build_session_context_prompt(context: SessionContext) -> str:
lines.append(f"**User:** {context.source.user_name}")
elif context.source.user_id:
lines.append(f"**User ID:** {context.source.user_id}")
# Connected platforms
platforms_list = ["local (files on this machine)"]
for p in context.connected_platforms:
if p != Platform.LOCAL:
platforms_list.append(f"{p.value}: Connected ✓")
lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}")
# Home channels
if context.home_channels:
lines.append("")
lines.append("**Home Channels (default destinations):**")
for platform, home in context.home_channels.items():
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
# Delivery options for scheduled tasks
lines.append("")
lines.append("**Delivery options for scheduled tasks:**")
# Origin delivery
if context.source.platform == Platform.LOCAL:
lines.append("- `\"origin\"` → Local output (saved to files)")
lines.append('- `"origin"` → Local output (saved to files)')
else:
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
lines.append(f'- `"origin"` → Back to this chat ({context.source.chat_name or context.source.chat_id})')
# Local always available
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
lines.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)')
# Platform home channels
for platform, home in context.home_channels.items():
lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})")
lines.append(f'- `"{platform.value}"` → Home channel ({home.name})')
# Note about explicit targeting
lines.append("")
lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*")
lines.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*')
return "\n".join(lines)
@@ -220,32 +218,33 @@ def build_session_context_prompt(context: SessionContext) -> str:
class SessionEntry:
"""
Entry in the session store.
Maps a session key to its current session ID and metadata.
"""
session_key: str
session_id: str
created_at: datetime
updated_at: datetime
# Origin metadata for delivery routing
origin: Optional[SessionSource] = None
origin: SessionSource | None = None
# Display metadata
display_name: Optional[str] = None
platform: Optional[Platform] = None
display_name: str | None = None
platform: Platform | None = None
chat_type: str = "dm"
# Token tracking
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
# Set when a session was created because the previous one expired;
# consumed once by the message handler to inject a notice into context
was_auto_reset: bool = False
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
result = {
"session_key": self.session_key,
"session_id": self.session_id,
@@ -261,20 +260,20 @@ class SessionEntry:
if self.origin:
result["origin"] = self.origin.to_dict()
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
origin = None
if "origin" in data and data["origin"]:
origin = SessionSource.from_dict(data["origin"])
platform = None
if data.get("platform"):
try:
platform = Platform(data["platform"])
except ValueError:
pass
return cls(
session_key=data["session_key"],
session_id=data["session_id"],
@@ -307,66 +306,65 @@ def build_session_key(source: SessionSource) -> str:
class SessionStore:
"""
Manages session storage and retrieval.
Uses SQLite (via SessionDB) for session metadata and message transcripts.
Falls back to legacy JSONL files if SQLite is unavailable.
"""
def __init__(self, sessions_dir: Path, config: GatewayConfig,
has_active_processes_fn=None,
on_auto_reset=None):
def __init__(self, sessions_dir: Path, config: GatewayConfig, has_active_processes_fn=None, on_auto_reset=None):
self.sessions_dir = sessions_dir
self.config = config
self._entries: Dict[str, SessionEntry] = {}
self._entries: dict[str, SessionEntry] = {}
self._loaded = False
self._has_active_processes_fn = has_active_processes_fn
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
# Initialize SQLite session database
self._db = None
try:
from hermes_state import SessionDB
self._db = SessionDB()
except Exception as e:
print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
def _ensure_loaded(self) -> None:
"""Load sessions index from disk if not already loaded."""
if self._loaded:
return
self.sessions_dir.mkdir(parents=True, exist_ok=True)
sessions_file = self.sessions_dir / "sessions.json"
if sessions_file.exists():
try:
with open(sessions_file, "r", encoding="utf-8") as f:
with open(sessions_file, encoding="utf-8") as f:
data = json.load(f)
for key, entry_data in data.items():
self._entries[key] = SessionEntry.from_dict(entry_data)
except Exception as e:
print(f"[gateway] Warning: Failed to load sessions: {e}")
self._loaded = True
def _save(self) -> None:
"""Save sessions index to disk (kept for session key -> ID mapping)."""
self.sessions_dir.mkdir(parents=True, exist_ok=True)
sessions_file = self.sessions_dir / "sessions.json"
data = {key: entry.to_dict() for key, entry in self._entries.items()}
with open(sessions_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
def _generate_session_key(self, source: SessionSource) -> str:
"""Generate a session key from a source."""
return build_session_key(source)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.
Works from the entry alone — no SessionSource needed.
Used by the background expiry watcher to proactively flush memories.
Sessions with active background processes are never considered expired.
@@ -393,7 +391,9 @@ class SessionStore:
if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0, second=0, microsecond=0,
minute=0,
second=0,
microsecond=0,
)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
@@ -405,7 +405,7 @@ class SessionStore:
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
"""
Check if a session should be reset based on policy.
Sessions with active background processes are never reset.
"""
if self._has_active_processes_fn:
@@ -413,36 +413,28 @@ class SessionStore:
if self._has_active_processes_fn(session_key):
return False
policy = self.config.get_reset_policy(
platform=source.platform,
session_type=source.chat_type
)
policy = self.config.get_reset_policy(platform=source.platform, session_type=source.chat_type)
if policy.mode == "none":
return False
now = datetime.now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
if now > idle_deadline:
return True
if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0,
second=0,
microsecond=0
)
today_reset = now.replace(hour=policy.at_hour, minute=0, second=0, microsecond=0)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
if entry.updated_at < today_reset:
return True
return False
def has_any_sessions(self) -> bool:
"""Check if any sessions have ever been created (across all platforms).
@@ -463,26 +455,22 @@ class SessionStore:
# This covers the rare case where the DB is unavailable.
self._ensure_loaded()
return len(self._entries) > 1
def get_or_create_session(
self,
source: SessionSource,
force_new: bool = False
) -> SessionEntry:
def get_or_create_session(self, source: SessionSource, force_new: bool = False) -> SessionEntry:
"""
Get an existing session or create a new one.
Evaluates reset policy to determine if the existing session is stale.
Creates a session record in SQLite when a new session starts.
"""
self._ensure_loaded()
session_key = self._generate_session_key(source)
now = datetime.now()
if session_key in self._entries and not force_new:
entry = self._entries[session_key]
if not self._should_reset(entry, source):
entry.updated_at = now
self._save()
@@ -500,10 +488,10 @@ class SessionStore:
logger.debug("Session DB operation failed: %s", e)
else:
was_auto_reset = False
# Create new session
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
entry = SessionEntry(
session_key=session_key,
session_id=session_id,
@@ -515,10 +503,10 @@ class SessionStore:
chat_type=source.chat_type,
was_auto_reset=was_auto_reset,
)
self._entries[session_key] = entry
self._save()
# Create session in SQLite
if self._db:
try:
@@ -529,18 +517,13 @@ class SessionStore:
)
except Exception as e:
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
return entry
def update_session(
self,
session_key: str,
input_tokens: int = 0,
output_tokens: int = 0
) -> None:
def update_session(self, session_key: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
"""Update a session's metadata after an interaction."""
self._ensure_loaded()
if session_key in self._entries:
entry = self._entries[session_key]
entry.updated_at = datetime.now()
@@ -548,34 +531,32 @@ class SessionStore:
entry.output_tokens += output_tokens
entry.total_tokens = entry.input_tokens + entry.output_tokens
self._save()
if self._db:
try:
self._db.update_token_counts(
entry.session_id, input_tokens, output_tokens
)
self._db.update_token_counts(entry.session_id, input_tokens, output_tokens)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
def reset_session(self, session_key: str) -> SessionEntry | None:
"""Force reset a session, creating a new session ID."""
self._ensure_loaded()
if session_key not in self._entries:
return None
old_entry = self._entries[session_key]
# End old session in SQLite
if self._db:
try:
self._db.end_session(old_entry.session_id, "session_reset")
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
now = datetime.now()
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
new_entry = SessionEntry(
session_key=session_key,
session_id=session_id,
@@ -586,10 +567,10 @@ class SessionStore:
platform=old_entry.platform,
chat_type=old_entry.chat_type,
)
self._entries[session_key] = new_entry
self._save()
# Create new session in SQLite
if self._db:
try:
@@ -600,10 +581,10 @@ class SessionStore:
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
return new_entry
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
def switch_session(self, session_key: str, target_session_id: str) -> SessionEntry | None:
"""Switch a session key to point at an existing session ID.
Used by ``/resume`` to restore a previously-named session.
@@ -645,25 +626,25 @@ class SessionStore:
self._save()
return new_entry
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
def list_sessions(self, active_minutes: int | None = None) -> list[SessionEntry]:
"""List all sessions, optionally filtered by activity."""
self._ensure_loaded()
entries = list(self._entries.values())
if active_minutes is not None:
cutoff = datetime.now() - timedelta(minutes=active_minutes)
entries = [e for e in entries if e.updated_at >= cutoff]
entries.sort(key=lambda e: e.updated_at, reverse=True)
return entries
def get_transcript_path(self, session_id: str) -> Path:
"""Get the path to a session's legacy transcript file."""
return self.sessions_dir / f"{session_id}.jsonl"
def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None:
def append_to_transcript(self, session_id: str, message: dict[str, Any]) -> None:
"""Append a message to a session's transcript (SQLite + legacy JSONL)."""
# Write to SQLite
if self._db:
@@ -678,15 +659,15 @@ class SessionStore:
)
except Exception as e:
logger.debug("Session DB operation failed: %s", e)
# Also write legacy JSONL (keeps existing tooling working during transition)
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "a", encoding="utf-8") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n")
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
def rewrite_transcript(self, session_id: str, messages: list[dict[str, Any]]) -> None:
"""Replace the entire transcript for a session with new messages.
Used by /retry, /undo, and /compress to persist modified conversation history.
Rewrites both SQLite and legacy JSONL storage.
"""
@@ -705,14 +686,14 @@ class SessionStore:
)
except Exception as e:
logger.debug("Failed to rewrite transcript in DB: %s", e)
# JSONL: overwrite the file
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "w", encoding="utf-8") as f:
for msg in messages:
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
def load_transcript(self, session_id: str) -> list[dict[str, Any]]:
"""Load all messages from a session's transcript."""
# Try SQLite first
if self._db:
@@ -722,51 +703,49 @@ class SessionStore:
return messages
except Exception as e:
logger.debug("Could not load messages from DB: %s", e)
# Fall back to legacy JSONL
transcript_path = self.get_transcript_path(session_id)
if not transcript_path.exists():
return []
messages = []
with open(transcript_path, "r", encoding="utf-8") as f:
with open(transcript_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
messages.append(json.loads(line))
return messages
def build_session_context(
source: SessionSource,
config: GatewayConfig,
session_entry: Optional[SessionEntry] = None
source: SessionSource, config: GatewayConfig, session_entry: SessionEntry | None = None
) -> SessionContext:
"""
Build a full session context from a source and config.
This is used to inject context into the agent's system prompt.
"""
connected = config.get_connected_platforms()
home_channels = {}
for platform in connected:
home = config.get_home_channel(platform)
if home:
home_channels[platform] = home
context = SessionContext(
source=source,
connected_platforms=connected,
home_channels=home_channels,
)
if session_entry:
context.session_key = session_entry.session_key
context.session_id = session_entry.session_id
context.created_at = session_entry.created_at
context.updated_at = session_entry.updated_at
return context

View File

@@ -13,7 +13,6 @@ concurrently under distinct configurations).
import os
from pathlib import Path
from typing import Optional
def _get_pid_path() -> Path:
@@ -37,7 +36,7 @@ def remove_pid_file() -> None:
pass
def get_running_pid() -> Optional[int]:
def get_running_pid() -> int | None:
"""Return the PID of a running gateway instance, or ``None``.
Checks the PID file and verifies the process is actually alive.

View File

@@ -12,8 +12,6 @@ import json
import os
import time
from pathlib import Path
from typing import Optional
CACHE_PATH = Path(os.path.expanduser("~/.hermes/sticker_cache.json"))
@@ -43,7 +41,7 @@ def _save_cache(cache: dict) -> None:
)
def get_cached_description(file_unique_id: str) -> Optional[dict]:
def get_cached_description(file_unique_id: str) -> dict | None:
"""
Look up a cached sticker description.
@@ -92,11 +90,11 @@ def build_sticker_injection(
"""
context = ""
if set_name and emoji:
context = f" {emoji} from \"{set_name}\""
context = f' {emoji} from "{set_name}"'
elif emoji:
context = f" {emoji}"
return f"[The user sent a sticker{context}~ It shows: \"{description}\" (=^.w.^=)]"
return f'[The user sent a sticker{context}~ It shows: "{description}" (=^.w.^=)]'
def build_animated_sticker_injection(emoji: str = "") -> str:

View File

@@ -5,7 +5,7 @@ Provides subcommands for:
- hermes chat - Interactive chat (same as ./hermes)
- hermes gateway - Run gateway in foreground
- hermes gateway start - Start gateway service
- hermes gateway stop - Stop gateway service
- hermes gateway stop - Stop gateway service
- hermes setup - Interactive setup wizard
- hermes status - Show status of all components
- hermes cron - Manage cron jobs

View File

@@ -15,27 +15,25 @@ Architecture:
from __future__ import annotations
import base64
import hashlib
import json
import logging
import os
import shutil
import stat
import base64
import hashlib
import subprocess
import time
import uuid
import webbrowser
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
import httpx
import yaml
from hermes_cli.config import get_hermes_home, get_config_path
from hermes_cli.config import get_config_path, get_hermes_home
from hermes_constants import OPENROUTER_BASE_URL
logger = logging.getLogger(__name__)
@@ -58,8 +56,8 @@ DEFAULT_NOUS_INFERENCE_URL = "https://inference-api.nousresearch.com/v1"
DEFAULT_NOUS_CLIENT_ID = "hermes-cli"
DEFAULT_NOUS_SCOPE = "inference:mint_agent_key"
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
@@ -70,9 +68,11 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
# Provider Registry
# =============================================================================
@dataclass
class ProviderConfig:
"""Describes a known inference provider."""
id: str
name: str
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
@@ -80,14 +80,14 @@ class ProviderConfig:
inference_base_url: str = ""
client_id: str = ""
scope: str = ""
extra: Dict[str, Any] = field(default_factory=dict)
extra: dict[str, Any] = field(default_factory=dict)
# For API-key providers: env vars to check (in priority order)
api_key_env_vars: tuple = ()
# Optional env var for base URL override
base_url_env_var: str = ""
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
PROVIDER_REGISTRY: dict[str, ProviderConfig] = {
"nous": ProviderConfig(
id="nous",
name="Nous Portal",
@@ -172,14 +172,14 @@ def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) ->
ZAI_ENDPOINTS = [
# (id, base_url, default_model, label)
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
]
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> dict[str, str] | None:
"""Probe z.ai endpoints to find one that accepts this API key.
Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
@@ -219,6 +219,7 @@ def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str
# Error Types
# =============================================================================
class AuthError(RuntimeError):
"""Structured auth error with UX mapping hints."""
@@ -227,7 +228,7 @@ class AuthError(RuntimeError):
message: str,
*,
provider: str = "",
code: Optional[str] = None,
code: str | None = None,
relogin_required: bool = False,
) -> None:
super().__init__(message)
@@ -245,16 +246,10 @@ def format_auth_error(error: Exception) -> str:
return f"{error} Run `hermes model` to re-authenticate."
if error.code == "subscription_required":
return (
"No active paid subscription found on Nous Portal. "
"Please purchase/activate a subscription, then retry."
)
return "No active paid subscription found on Nous Portal. Please purchase/activate a subscription, then retry."
if error.code == "insufficient_credits":
return (
"Subscription credits are exhausted. "
"Top up/renew credits in Nous Portal, then retry."
)
return "Subscription credits are exhausted. Top up/renew credits in Nous Portal, then retry."
if error.code == "temporarily_unavailable":
return f"{error} Please retry in a few seconds."
@@ -262,7 +257,7 @@ def format_auth_error(error: Exception) -> str:
return str(error)
def _token_fingerprint(token: Any) -> Optional[str]:
def _token_fingerprint(token: Any) -> str | None:
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
if not isinstance(token, str):
return None
@@ -277,10 +272,10 @@ def _oauth_trace_enabled() -> bool:
return raw in {"1", "true", "yes", "on"}
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
def _oauth_trace(event: str, *, sequence_id: str | None = None, **fields: Any) -> None:
if not _oauth_trace_enabled():
return
payload: Dict[str, Any] = {"event": event}
payload: dict[str, Any] = {"event": event}
if sequence_id:
payload["sequence_id"] = sequence_id
payload.update(fields)
@@ -291,6 +286,7 @@ def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any
# Auth Store — persistence layer for ~/.hermes/auth.json
# =============================================================================
def _auth_file_path() -> Path:
return get_hermes_home() / "auth.json"
@@ -326,7 +322,7 @@ def _auth_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]:
def _load_auth_store(auth_file: Path | None = None) -> dict[str, Any]:
auth_file = auth_file or _auth_file_path()
if not auth_file.exists():
return {"version": AUTH_STORE_VERSION, "providers": {}}
@@ -345,17 +341,16 @@ def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]:
providers = {}
if "nous_portal" in systems:
providers["nous"] = systems["nous_portal"]
return {"version": AUTH_STORE_VERSION, "providers": providers,
"active_provider": "nous" if providers else None}
return {"version": AUTH_STORE_VERSION, "providers": providers, "active_provider": "nous" if providers else None}
return {"version": AUTH_STORE_VERSION, "providers": {}}
def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
def _save_auth_store(auth_store: dict[str, Any]) -> Path:
auth_file = _auth_file_path()
auth_file.parent.mkdir(parents=True, exist_ok=True)
auth_store["version"] = AUTH_STORE_VERSION
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
auth_store["updated_at"] = datetime.now(UTC).isoformat()
payload = json.dumps(auth_store, indent=2) + "\n"
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
try:
@@ -387,7 +382,7 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
return auth_file
def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Optional[Dict[str, Any]]:
def _load_provider_state(auth_store: dict[str, Any], provider_id: str) -> dict[str, Any] | None:
providers = auth_store.get("providers")
if not isinstance(providers, dict):
return None
@@ -395,7 +390,7 @@ def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Option
return dict(state) if isinstance(state, dict) else None
def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Dict[str, Any]) -> None:
def _save_provider_state(auth_store: dict[str, Any], provider_id: str, state: dict[str, Any]) -> None:
providers = auth_store.setdefault("providers", {})
if not isinstance(providers, dict):
auth_store["providers"] = {}
@@ -404,19 +399,19 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
auth_store["active_provider"] = provider_id
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
def get_provider_auth_state(provider_id: str) -> dict[str, Any] | None:
"""Return persisted auth state for a provider, or None."""
auth_store = _load_auth_store()
return _load_provider_state(auth_store, provider_id)
def get_active_provider() -> Optional[str]:
def get_active_provider() -> str | None:
"""Return the currently active provider ID from auth store."""
auth_store = _load_auth_store()
return auth_store.get("active_provider")
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
def clear_provider_auth(provider_id: str | None = None) -> bool:
"""
Clear auth state for a provider. Used by `hermes logout`.
If provider_id is None, clears the active provider.
@@ -455,11 +450,12 @@ def deactivate_provider() -> None:
# Provider Resolution — picks which provider to use
# =============================================================================
def resolve_provider(
requested: Optional[str] = None,
requested: str | None = None,
*,
explicit_api_key: Optional[str] = None,
explicit_base_url: Optional[str] = None,
explicit_api_key: str | None = None,
explicit_base_url: str | None = None,
) -> str:
"""
Determine which inference provider to use.
@@ -475,9 +471,14 @@ def resolve_provider(
# Normalize provider aliases
_PROVIDER_ALIASES = {
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
"kimi": "kimi-coding", "moonshot": "kimi-coding",
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
"glm": "zai",
"z-ai": "zai",
"z.ai": "zai",
"zhipu": "zai",
"kimi": "kimi-coding",
"moonshot": "kimi-coding",
"minimax-china": "minimax-cn",
"minimax_cn": "minimax-cn",
}
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
@@ -524,7 +525,8 @@ def resolve_provider(
# Timestamp / TTL helpers
# =============================================================================
def _parse_iso_timestamp(value: Any) -> Optional[float]:
def _parse_iso_timestamp(value: Any) -> float | None:
if not isinstance(value, str) or not value:
return None
text = value.strip()
@@ -537,7 +539,7 @@ def _parse_iso_timestamp(value: Any) -> Optional[float]:
except Exception:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
parsed = parsed.replace(tzinfo=UTC)
return parsed.timestamp()
@@ -556,14 +558,14 @@ def _coerce_ttl_seconds(expires_in: Any) -> int:
return max(0, ttl)
def _optional_base_url(value: Any) -> Optional[str]:
def _optional_base_url(value: Any) -> str | None:
if not isinstance(value, str):
return None
cleaned = value.strip().rstrip("/")
return cleaned if cleaned else None
def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
def _decode_jwt_claims(token: Any) -> dict[str, Any]:
if not isinstance(token, str) or token.count(".") != 2:
return {}
payload = token.split(".")[1]
@@ -588,6 +590,7 @@ def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> boo
# SSH / remote session detection
# =============================================================================
def _is_remote_session() -> bool:
"""Detect if running in an SSH session where webbrowser.open() won't work."""
return bool(os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY"))
@@ -601,9 +604,10 @@ def _is_remote_session() -> bool:
# where one app's refresh invalidates the other's session.
# =============================================================================
def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]:
def _read_codex_tokens(*, _lock: bool = True) -> dict[str, Any]:
"""Read Codex OAuth tokens from Hermes auth store (~/.hermes/auth.json).
Returns dict with 'tokens' (access_token, refresh_token) and 'last_refresh'.
Raises AuthError if no Codex tokens are stored.
"""
@@ -650,10 +654,10 @@ def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]:
}
def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None:
def _save_codex_tokens(tokens: dict[str, str], last_refresh: str = None) -> None:
"""Save Codex OAuth tokens to Hermes auth store (~/.hermes/auth.json)."""
if last_refresh is None:
last_refresh = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
last_refresh = datetime.now(UTC).isoformat().replace("+00:00", "Z")
with _auth_store_lock():
auth_store = _load_auth_store()
state = _load_provider_state(auth_store, "openai-codex") or {}
@@ -665,11 +669,11 @@ def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None
def _refresh_codex_auth_tokens(
tokens: Dict[str, str],
tokens: dict[str, str],
timeout_seconds: float,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Refresh Codex access token using the refresh token.
Saves the new tokens to Hermes auth store automatically.
"""
refresh_token = tokens.get("refresh_token")
@@ -746,9 +750,9 @@ def _refresh_codex_auth_tokens(
return updated_tokens
def _import_codex_cli_tokens() -> Optional[Dict[str, str]]:
def _import_codex_cli_tokens() -> dict[str, str] | None:
"""Try to read tokens from ~/.codex/auth.json (Codex CLI shared file).
Returns tokens dict if valid, None otherwise. Does NOT write to the shared file.
"""
codex_home = os.getenv("CODEX_HOME", "").strip()
@@ -774,7 +778,7 @@ def resolve_codex_runtime_credentials(
force_refresh: bool = False,
refresh_if_expiring: bool = True,
refresh_skew_seconds: int = CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Resolve runtime credentials from Hermes's own Codex token store."""
try:
data = _read_codex_tokens()
@@ -817,10 +821,7 @@ def resolve_codex_runtime_credentials(
tokens = _refresh_codex_auth_tokens(tokens, refresh_timeout_seconds)
access_token = str(tokens.get("access_token", "") or "").strip()
base_url = (
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
or DEFAULT_CODEX_BASE_URL
)
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
return {
"provider": "openai-codex",
@@ -836,24 +837,19 @@ def resolve_codex_runtime_credentials(
# TLS verification helper
# =============================================================================
def _resolve_verify(
*,
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
auth_state: Optional[Dict[str, Any]] = None,
insecure: bool | None = None,
ca_bundle: str | None = None,
auth_state: dict[str, Any] | None = None,
) -> bool | str:
tls_state = auth_state.get("tls") if isinstance(auth_state, dict) else {}
tls_state = tls_state if isinstance(tls_state, dict) else {}
effective_insecure = (
bool(insecure) if insecure is not None
else bool(tls_state.get("insecure", False))
)
effective_insecure = bool(insecure) if insecure is not None else bool(tls_state.get("insecure", False))
effective_ca = (
ca_bundle
or tls_state.get("ca_bundle")
or os.getenv("HERMES_CA_BUNDLE")
or os.getenv("SSL_CERT_FILE")
ca_bundle or tls_state.get("ca_bundle") or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
)
if effective_insecure:
@@ -867,12 +863,13 @@ def _resolve_verify(
# OAuth Device Code Flow — generic, parameterized by provider
# =============================================================================
def _request_device_code(
client: httpx.Client,
portal_base_url: str,
client_id: str,
scope: Optional[str],
) -> Dict[str, Any]:
scope: str | None,
) -> dict[str, Any]:
"""POST to the device code endpoint. Returns device_code, user_code, etc."""
response = client.post(
f"{portal_base_url}/api/oauth/device/code",
@@ -885,8 +882,12 @@ def _request_device_code(
data = response.json()
required_fields = [
"device_code", "user_code", "verification_uri",
"verification_uri_complete", "expires_in", "interval",
"device_code",
"user_code",
"verification_uri",
"verification_uri_complete",
"expires_in",
"interval",
]
missing = [f for f in required_fields if f not in data]
if missing:
@@ -901,7 +902,7 @@ def _poll_for_token(
device_code: str,
expires_in: int,
poll_interval: int,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Poll the token endpoint until the user approves or the code expires."""
deadline = time.time() + max(1, expires_in)
current_interval = max(1, min(poll_interval, DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS))
@@ -947,13 +948,14 @@ def _poll_for_token(
# Nous Portal — token refresh, agent key minting, model discovery
# =============================================================================
def _refresh_access_token(
*,
client: httpx.Client,
portal_base_url: str,
client_id: str,
refresh_token: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
response = client.post(
f"{portal_base_url}/api/oauth/token",
data={
@@ -966,15 +968,15 @@ def _refresh_access_token(
if response.status_code == 200:
payload = response.json()
if "access_token" not in payload:
raise AuthError("Refresh response missing access_token",
provider="nous", code="invalid_token", relogin_required=True)
raise AuthError(
"Refresh response missing access_token", provider="nous", code="invalid_token", relogin_required=True
)
return payload
try:
error_payload = response.json()
except Exception as exc:
raise AuthError("Refresh token exchange failed",
provider="nous", relogin_required=True) from exc
raise AuthError("Refresh token exchange failed", provider="nous", relogin_required=True) from exc
code = str(error_payload.get("error", "invalid_grant"))
description = str(error_payload.get("error_description") or "Refresh token exchange failed")
@@ -988,7 +990,7 @@ def _mint_agent_key(
portal_base_url: str,
access_token: str,
min_ttl_seconds: int,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Mint (or reuse) a short-lived inference API key."""
response = client.post(
f"{portal_base_url}/api/oauth/agent-key",
@@ -999,15 +1001,13 @@ def _mint_agent_key(
if response.status_code == 200:
payload = response.json()
if "api_key" not in payload:
raise AuthError("Mint response missing api_key",
provider="nous", code="server_error")
raise AuthError("Mint response missing api_key", provider="nous", code="server_error")
return payload
try:
error_payload = response.json()
except Exception as exc:
raise AuthError("Agent key mint request failed",
provider="nous", code="server_error") from exc
raise AuthError("Agent key mint request failed", provider="nous", code="server_error") from exc
code = str(error_payload.get("error", "server_error"))
description = str(error_payload.get("error_description") or "Agent key mint request failed")
@@ -1021,7 +1021,7 @@ def fetch_nous_models(
api_key: str,
timeout_seconds: float = 15.0,
verify: bool | str = True,
) -> List[str]:
) -> list[str]:
"""Fetch available model IDs from the Nous inference API."""
timeout = httpx.Timeout(timeout_seconds)
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
@@ -1044,7 +1044,7 @@ def fetch_nous_models(
if not isinstance(data, list):
return []
model_ids: List[str] = []
model_ids: list[str] = []
for item in data:
if not isinstance(item, dict):
continue
@@ -1059,7 +1059,7 @@ def fetch_nous_models(
return list(dict.fromkeys(model_ids))
def _agent_key_is_usable(state: Dict[str, Any], min_ttl_seconds: int) -> bool:
def _agent_key_is_usable(state: dict[str, Any], min_ttl_seconds: int) -> bool:
key = state.get("agent_key")
if not isinstance(key, str) or not key.strip():
return False
@@ -1070,10 +1070,10 @@ def resolve_nous_runtime_credentials(
*,
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
timeout_seconds: float = 15.0,
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
insecure: bool | None = None,
ca_bundle: str | None = None,
force_mint: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Resolve Nous inference credentials for runtime use.
@@ -1092,8 +1092,7 @@ def resolve_nous_runtime_credentials(
state = _load_provider_state(auth_store, "nous")
if not state:
raise AuthError("Hermes is not logged into Nous Portal.",
provider="nous", relogin_required=True)
raise AuthError("Hermes is not logged into Nous Portal.", provider="nous", relogin_required=True)
portal_base_url = (
_optional_base_url(state.get("portal_base_url"))
@@ -1143,14 +1142,14 @@ def resolve_nous_runtime_credentials(
refresh_token = state.get("refresh_token")
if not isinstance(access_token, str) or not access_token:
raise AuthError("No access token found for Nous Portal login.",
provider="nous", relogin_required=True)
raise AuthError("No access token found for Nous Portal login.", provider="nous", relogin_required=True)
# Step 1: refresh access token if expiring
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
raise AuthError(
"Session expired and no refresh token is available.", provider="nous", relogin_required=True
)
_oauth_trace(
"refresh_start",
@@ -1159,10 +1158,12 @@ def resolve_nous_runtime_credentials(
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"]
@@ -1174,9 +1175,7 @@ def resolve_nous_runtime_credentials(
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
@@ -1191,7 +1190,7 @@ def resolve_nous_runtime_credentials(
# Step 2: mint agent key if missing/expiring
used_cached_key = False
mint_payload: Optional[Dict[str, Any]] = None
mint_payload: dict[str, Any] | None = None
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
used_cached_key = True
@@ -1204,8 +1203,10 @@ def resolve_nous_runtime_credentials(
access_token_fp=_token_fingerprint(access_token),
)
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
client=client,
portal_base_url=portal_base_url,
access_token=access_token,
min_ttl_seconds=min_key_ttl_seconds,
)
except AuthError as exc:
_oauth_trace(
@@ -1227,10 +1228,12 @@ def resolve_nous_runtime_credentials(
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=latest_refresh_token,
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
refresh_token=latest_refresh_token,
)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
@@ -1241,9 +1244,7 @@ def resolve_nous_runtime_credentials(
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
@@ -1257,14 +1258,16 @@ def resolve_nous_runtime_credentials(
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
client=client,
portal_base_url=portal_base_url,
access_token=access_token,
min_ttl_seconds=min_key_ttl_seconds,
)
else:
raise
if mint_payload is not None:
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
state["agent_key"] = mint_payload.get("api_key")
state["agent_key_id"] = mint_payload.get("key_id")
state["agent_key_expires_at"] = mint_payload.get("expires_at")
@@ -1293,8 +1296,7 @@ def resolve_nous_runtime_credentials(
api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key:
raise AuthError("Failed to resolve a Nous inference API key",
provider="nous", code="server_error")
raise AuthError("Failed to resolve a Nous inference API key", provider="nous", code="server_error")
expires_at = state.get("agent_key_expires_at")
expires_epoch = _parse_iso_timestamp(expires_at)
@@ -1319,7 +1321,8 @@ def resolve_nous_runtime_credentials(
# Status helpers
# =============================================================================
def get_nous_auth_status() -> Dict[str, Any]:
def get_nous_auth_status() -> dict[str, Any]:
"""Status snapshot for `hermes status` output."""
state = get_provider_auth_state("nous")
if not state:
@@ -1341,7 +1344,7 @@ def get_nous_auth_status() -> Dict[str, Any]:
}
def get_codex_auth_status() -> Dict[str, Any]:
def get_codex_auth_status() -> dict[str, Any]:
"""Status snapshot for Codex auth."""
try:
creds = resolve_codex_runtime_credentials()
@@ -1360,7 +1363,7 @@ def get_codex_auth_status() -> Dict[str, Any]:
}
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
def get_api_key_provider_status(provider_id: str) -> dict[str, Any]:
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
@@ -1396,7 +1399,7 @@ def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
}
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
def get_auth_status(provider_id: str | None = None) -> dict[str, Any]:
"""Generic auth status dispatcher."""
target = provider_id or get_active_provider()
if target == "nous":
@@ -1410,7 +1413,7 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
return {"logged_in": False}
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
def resolve_api_key_provider_credentials(provider_id: str) -> dict[str, Any]:
"""Resolve API key and base URL for an API-key provider.
Returns dict with: provider, api_key, base_url, source.
@@ -1455,7 +1458,8 @@ def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
# External credential detection
# =============================================================================
def detect_external_credentials() -> List[Dict[str, Any]]:
def detect_external_credentials() -> list[dict[str, Any]]:
"""Scan for credentials from other CLI tools that Hermes can reuse.
Returns a list of dicts, each with:
@@ -1463,17 +1467,19 @@ def detect_external_credentials() -> List[Dict[str, Any]]:
- path: str -- filesystem path where creds were found
- label: str -- human-friendly description for the setup UI
"""
found: List[Dict[str, Any]] = []
found: list[dict[str, Any]] = []
# Codex CLI: ~/.codex/auth.json (importable, not shared)
cli_tokens = _import_codex_cli_tokens()
if cli_tokens:
codex_path = Path.home() / ".codex" / "auth.json"
found.append({
"provider": "openai-codex",
"path": str(codex_path),
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
})
found.append(
{
"provider": "openai-codex",
"path": str(codex_path),
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
}
)
return found
@@ -1482,6 +1488,7 @@ def detect_external_credentials() -> List[Dict[str, Any]]:
# CLI Commands — login / logout
# =============================================================================
def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Path:
"""Update config.yaml and auth.json to reflect the active provider."""
# Set active_provider in auth.json so auto-resolution picks this provider
@@ -1494,7 +1501,7 @@ def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Pa
config_path = get_config_path()
config_path.parent.mkdir(parents=True, exist_ok=True)
config: Dict[str, Any] = {}
config: dict[str, Any] = {}
if config_path.exists():
try:
loaded = yaml.safe_load(config_path.read_text()) or {}
@@ -1542,7 +1549,7 @@ def _reset_config_provider() -> Path:
return config_path
def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Optional[str]:
def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> str | None:
"""Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None."""
# Reorder: current model first, then the rest (deduplicated)
ordered = []
@@ -1564,6 +1571,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op
# Try arrow-key menu first, fall back to number input
try:
from simple_term_menu import TerminalMenu
choices = [f" {_label(mid)}" for mid in ordered]
choices.append(" Enter custom model name")
choices.append(" Skip (keep current)")
@@ -1621,7 +1629,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op
def _save_model_choice(model_id: str) -> None:
"""Save the selected model to config.yaml and .env."""
from hermes_cli.config import save_config, load_config, save_env_value
from hermes_cli.config import load_config, save_config, save_env_value
config = load_config()
# Handle both string and dict model formats
@@ -1693,11 +1701,11 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None:
config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL))
print()
print("Login successful!")
print(f" Auth state: ~/.hermes/auth.json")
print(" Auth state: ~/.hermes/auth.json")
print(f" Config updated: {config_path} (model.provider=openai-codex)")
def _codex_device_code_login() -> Dict[str, Any]:
def _codex_device_code_login() -> dict[str, Any]:
"""Run the OpenAI device code login flow and return credentials dict."""
import time as _time
@@ -1715,13 +1723,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
except Exception as exc:
raise AuthError(
f"Failed to request device code: {exc}",
provider="openai-codex", code="device_code_request_failed",
provider="openai-codex",
code="device_code_request_failed",
)
if resp.status_code != 200:
raise AuthError(
f"Device code request returned status {resp.status_code}.",
provider="openai-codex", code="device_code_request_error",
provider="openai-codex",
code="device_code_request_error",
)
device_data = resp.json()
@@ -1732,14 +1742,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
if not user_code or not device_auth_id:
raise AuthError(
"Device code response missing required fields.",
provider="openai-codex", code="device_code_incomplete",
provider="openai-codex",
code="device_code_incomplete",
)
# Step 2: Show user the code
print("To continue, follow these steps:\n")
print(f" 1. Open this URL in your browser:")
print(" 1. Open this URL in your browser:")
print(f" \033[94m{issuer}/codex/device\033[0m\n")
print(f" 2. Enter this code:")
print(" 2. Enter this code:")
print(f" \033[94m{user_code}\033[0m\n")
print("Waiting for sign-in... (press Ctrl+C to cancel)")
@@ -1766,7 +1777,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
else:
raise AuthError(
f"Device auth polling returned status {poll_resp.status_code}.",
provider="openai-codex", code="device_code_poll_error",
provider="openai-codex",
code="device_code_poll_error",
)
except KeyboardInterrupt:
print("\nLogin cancelled.")
@@ -1775,7 +1787,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
if code_resp is None:
raise AuthError(
"Login timed out after 15 minutes.",
provider="openai-codex", code="device_code_timeout",
provider="openai-codex",
code="device_code_timeout",
)
# Step 4: Exchange authorization code for tokens
@@ -1786,7 +1799,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
if not authorization_code or not code_verifier:
raise AuthError(
"Device auth response missing authorization_code or code_verifier.",
provider="openai-codex", code="device_code_incomplete_exchange",
provider="openai-codex",
code="device_code_incomplete_exchange",
)
try:
@@ -1805,13 +1819,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
except Exception as exc:
raise AuthError(
f"Token exchange failed: {exc}",
provider="openai-codex", code="token_exchange_failed",
provider="openai-codex",
code="token_exchange_failed",
)
if token_resp.status_code != 200:
raise AuthError(
f"Token exchange returned status {token_resp.status_code}.",
provider="openai-codex", code="token_exchange_error",
provider="openai-codex",
code="token_exchange_error",
)
tokens = token_resp.json()
@@ -1821,14 +1837,12 @@ def _codex_device_code_login() -> Dict[str, Any]:
if not access_token:
raise AuthError(
"Token exchange did not return an access_token.",
provider="openai-codex", code="token_exchange_no_access_token",
provider="openai-codex",
code="token_exchange_no_access_token",
)
# Return tokens for the caller to persist (no longer writes to ~/.codex/)
base_url = (
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
or DEFAULT_CODEX_BASE_URL
)
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
return {
"tokens": {
@@ -1836,7 +1850,7 @@ def _codex_device_code_login() -> Dict[str, Any]:
"refresh_token": refresh_token,
},
"base_url": base_url,
"last_refresh": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
"last_refresh": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"auth_mode": "chatgpt",
"source": "device-code",
}
@@ -1851,9 +1865,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
or pconfig.portal_base_url
).rstrip("/")
requested_inference_url = (
getattr(args, "inference_url", None)
or os.getenv("NOUS_INFERENCE_BASE_URL")
or pconfig.inference_base_url
getattr(args, "inference_url", None) or os.getenv("NOUS_INFERENCE_BASE_URL") or pconfig.inference_base_url
).rstrip("/")
client_id = getattr(args, "client_id", None) or pconfig.client_id
scope = getattr(args, "scope", None) or pconfig.scope
@@ -1862,11 +1874,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
timeout = httpx.Timeout(timeout_seconds)
insecure = bool(getattr(args, "insecure", False))
ca_bundle = (
getattr(args, "ca_bundle", None)
or os.getenv("HERMES_CA_BUNDLE")
or os.getenv("SSL_CERT_FILE")
)
ca_bundle = getattr(args, "ca_bundle", None) or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
# Skip browser open in SSH sessions
@@ -1883,8 +1891,10 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
try:
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
device_data = _request_device_code(
client=client, portal_base_url=portal_base_url,
client_id=client_id, scope=scope,
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
verification_url = str(device_data["verification_uri_complete"])
@@ -1908,19 +1918,19 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
print(f"Waiting for approval (polling every {effective_interval}s)...")
token_data = _poll_for_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, device_code=str(device_data["device_code"]),
expires_in=expires_in, poll_interval=interval,
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
device_code=str(device_data["device_code"]),
expires_in=expires_in,
poll_interval=interval,
)
# Process token response
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
token_expires_in = _coerce_ttl_seconds(token_data.get("expires_in", 0))
expires_at = now.timestamp() + token_expires_in
inference_base_url = (
_optional_base_url(token_data.get("inference_base_url"))
or requested_inference_url
)
inference_base_url = _optional_base_url(token_data.get("inference_base_url")) or requested_inference_url
if inference_base_url != requested_inference_url:
print(f"Using portal-provided inference URL: {inference_base_url}")
@@ -1933,7 +1943,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
"access_token": token_data["access_token"],
"refresh_token": token_data.get("refresh_token"),
"obtained_at": now.isoformat(),
"expires_at": datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
"expires_at": datetime.fromtimestamp(expires_at, tz=UTC).isoformat(),
"expires_in": token_expires_in,
"tls": {
"insecure": verify is False,
@@ -1964,13 +1974,13 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
runtime_creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=5 * 60,
timeout_seconds=timeout_seconds,
insecure=insecure, ca_bundle=ca_bundle,
insecure=insecure,
ca_bundle=ca_bundle,
)
runtime_key = runtime_creds.get("api_key")
runtime_base_url = runtime_creds.get("base_url") or inference_base_url
if not isinstance(runtime_key, str) or not runtime_key:
raise AuthError("No runtime API key available to fetch models",
provider="nous", code="invalid_token")
raise AuthError("No runtime API key available to fetch models", provider="nous", code="invalid_token")
model_ids = fetch_nous_models(
inference_base_url=runtime_base_url,

View File

@@ -9,14 +9,12 @@ import os
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Any, Optional
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from prompt_toolkit import print_formatted_text as _pt_print
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
logger = logging.getLogger(__name__)
@@ -77,7 +75,8 @@ COMPACT_BANNER = """
# Skills scanning
# =========================================================================
def get_available_skills() -> Dict[str, List[str]]:
def get_available_skills() -> dict[str, list[str]]:
"""Scan ~/.hermes/skills/ and return skills grouped by category."""
import os
@@ -110,7 +109,7 @@ def get_available_skills() -> Dict[str, List[str]]:
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
def check_for_updates() -> Optional[int]:
def check_for_updates() -> int | None:
"""Check how many commits behind origin/main the local repo is.
Does a ``git fetch`` at most once every 6 hours (cached to
@@ -139,7 +138,8 @@ def check_for_updates() -> Optional[int]:
try:
subprocess.run(
["git", "fetch", "origin", "--quiet"],
capture_output=True, timeout=10,
capture_output=True,
timeout=10,
cwd=str(repo_dir),
)
except Exception:
@@ -149,7 +149,9 @@ def check_for_updates() -> Optional[int]:
try:
result = subprocess.run(
["git", "rev-list", "--count", "HEAD..origin/main"],
capture_output=True, text=True, timeout=5,
capture_output=True,
text=True,
timeout=5,
cwd=str(repo_dir),
)
if result.returncode == 0:
@@ -172,6 +174,7 @@ def check_for_updates() -> Optional[int]:
# Welcome banner
# =========================================================================
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
@@ -183,12 +186,16 @@ def _format_context_length(tokens: int) -> str:
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str,
tools: List[dict] = None,
enabled_toolsets: List[str] = None,
session_id: str = None,
get_toolset_for_tool=None,
context_length: int = None):
def build_welcome_banner(
console: Console,
model: str,
cwd: str,
tools: list[dict] = None,
enabled_toolsets: list[str] = None,
session_id: str = None,
get_toolset_for_tool=None,
context_length: int = None,
):
"""Build and print a welcome banner with caduceus on left and info on right.
Args:
@@ -201,7 +208,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
get_toolset_for_tool: Callable to map tool name -> toolset name.
context_length: Model's context window size in tokens.
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
from model_tools import check_tool_availability
if get_toolset_for_tool is None:
from model_tools import get_toolset_for_tool
@@ -221,7 +229,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
model_short = model.split("/")[-1] if "/" in model else model
if len(model_short) > 28:
model_short = model_short[:25] + "..."
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
ctx_str = (
f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
)
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
if session_id:
@@ -229,7 +239,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
left_content = "\n".join(left_lines)
right_lines = ["[bold #FFBF00]Available Tools[/]"]
toolsets_dict: Dict[str, list] = {}
toolsets_dict: dict[str, list] = {}
for tool in tools:
tool_name = tool["function"]["name"]
@@ -286,6 +296,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
# MCP Servers section (only if configured)
try:
from tools.mcp_tool import get_mcp_status
mcp_status = get_mcp_status()
except Exception:
mcp_status = []
@@ -300,10 +311,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]"
)
else:
right_lines.append(
f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] "
f"[red]— failed[/]"
)
right_lines.append(f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] [red]— failed[/]")
right_lines.append("")
right_lines.append("[bold #FFBF00]Available Skills[/]")

View File

@@ -9,7 +9,7 @@ with the TUI.
import queue
import time as _time
from hermes_cli.banner import cprint, _DIM, _RST
from hermes_cli.banner import _DIM, _RST, cprint
def clarify_callback(cli, question, choices):
@@ -33,7 +33,7 @@ def clarify_callback(cli, question, choices):
cli._clarify_deadline = _time.monotonic() + timeout
cli._clarify_freetext = is_open_ended
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
while True:
@@ -45,13 +45,13 @@ def clarify_callback(cli, question, choices):
remaining = cli._clarify_deadline - _time.monotonic()
if remaining <= 0:
break
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cli._clarify_state = None
cli._clarify_freetext = False
cli._clarify_deadline = 0
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cprint(f"\n{_DIM}(clarify timed out after {timeout}s — agent will decide){_RST}")
return (
@@ -71,7 +71,7 @@ def sudo_password_callback(cli) -> str:
cli._sudo_state = {"response_queue": response_queue}
cli._sudo_deadline = _time.monotonic() + timeout
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
while True:
@@ -79,7 +79,7 @@ def sudo_password_callback(cli) -> str:
result = response_queue.get(timeout=1)
cli._sudo_state = None
cli._sudo_deadline = 0
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
if result:
cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
@@ -90,12 +90,12 @@ def sudo_password_callback(cli) -> str:
remaining = cli._sudo_deadline - _time.monotonic()
if remaining <= 0:
break
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cli._sudo_state = None
cli._sudo_deadline = 0
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
return ""
@@ -119,7 +119,7 @@ def approval_callback(cli, command: str, description: str) -> str:
}
cli._approval_deadline = _time.monotonic() + timeout
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
while True:
@@ -127,19 +127,19 @@ def approval_callback(cli, command: str, description: str) -> str:
result = response_queue.get(timeout=1)
cli._approval_state = None
cli._approval_deadline = 0
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
return result
except queue.Empty:
remaining = cli._approval_deadline - _time.monotonic()
if remaining <= 0:
break
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cli._approval_state = None
cli._approval_deadline = 0
if hasattr(cli, '_app') and cli._app:
if hasattr(cli, "_app") and cli._app:
cli._app.invalidate()
cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
return "deny"

View File

@@ -51,6 +51,7 @@ def has_clipboard_image() -> bool:
# ── macOS ────────────────────────────────────────────────────────────────
def _macos_save(dest: Path) -> bool:
"""Try pngpaste first (fast, handles more formats), fall back to osascript."""
return _macos_pngpaste(dest) or _macos_osascript(dest)
@@ -61,7 +62,9 @@ def _macos_has_image() -> bool:
try:
info = subprocess.run(
["osascript", "-e", "clipboard info"],
capture_output=True, text=True, timeout=3,
capture_output=True,
text=True,
timeout=3,
)
return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
except Exception:
@@ -73,7 +76,8 @@ def _macos_pngpaste(dest: Path) -> bool:
try:
r = subprocess.run(
["pngpaste", str(dest)],
capture_output=True, timeout=3,
capture_output=True,
timeout=3,
)
if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
return True
@@ -91,19 +95,21 @@ def _macos_osascript(dest: Path) -> bool:
# Extract as PNG
script = (
'try\n'
' set imgData to the clipboard as «class PNGf»\n'
"try\n"
" set imgData to the clipboard as «class PNGf»\n"
f' set f to open for access POSIX file "{dest}" with write permission\n'
' write imgData to f\n'
' close access f\n'
'on error\n'
" write imgData to f\n"
" close access f\n"
"on error\n"
' return "fail"\n'
'end try\n'
"end try\n"
)
try:
r = subprocess.run(
["osascript", "-e", script],
capture_output=True, text=True, timeout=5,
capture_output=True,
text=True,
timeout=5,
)
if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
return True
@@ -114,13 +120,14 @@ def _macos_osascript(dest: Path) -> bool:
# ── Linux ────────────────────────────────────────────────────────────────
def _is_wsl() -> bool:
"""Detect if running inside WSL (1 or 2)."""
global _wsl_detected
if _wsl_detected is not None:
return _wsl_detected
try:
with open("/proc/version", "r") as f:
with open("/proc/version") as f:
_wsl_detected = "microsoft" in f.read().lower()
except Exception:
_wsl_detected = False
@@ -145,10 +152,7 @@ def _linux_save(dest: Path) -> bool:
# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
_PS_CHECK_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
"[System.Windows.Forms.Clipboard]::ContainsImage()"
)
_PS_CHECK_IMAGE = "Add-Type -AssemblyName System.Windows.Forms;[System.Windows.Forms.Clipboard]::ContainsImage()"
_PS_EXTRACT_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
@@ -165,9 +169,10 @@ def _wsl_has_image() -> bool:
"""Check if Windows clipboard has an image (via powershell.exe)."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_CHECK_IMAGE],
capture_output=True, text=True, timeout=8,
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
capture_output=True,
text=True,
timeout=8,
)
return r.returncode == 0 and "True" in r.stdout
except FileNotFoundError:
@@ -181,9 +186,10 @@ def _wsl_save(dest: Path) -> bool:
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_EXTRACT_IMAGE],
capture_output=True, text=True, timeout=15,
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
capture_output=True,
text=True,
timeout=15,
)
if r.returncode != 0:
return False
@@ -206,16 +212,17 @@ def _wsl_save(dest: Path) -> bool:
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
def _wayland_has_image() -> bool:
"""Check if Wayland clipboard has image content."""
try:
r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
)
return r.returncode == 0 and any(
t.startswith("image/") for t in r.stdout.splitlines()
capture_output=True,
text=True,
timeout=3,
)
return r.returncode == 0 and any(t.startswith("image/") for t in r.stdout.splitlines())
except FileNotFoundError:
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
except Exception:
@@ -229,7 +236,9 @@ def _wayland_save(dest: Path) -> bool:
# Check available MIME types
types_r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
capture_output=True,
text=True,
timeout=3,
)
if types_r.returncode != 0:
return False
@@ -237,8 +246,7 @@ def _wayland_save(dest: Path) -> bool:
# Prefer PNG, fall back to other image formats
mime = None
for preferred in ("image/png", "image/jpeg", "image/bmp",
"image/gif", "image/webp"):
for preferred in ("image/png", "image/jpeg", "image/bmp", "image/gif", "image/webp"):
if preferred in types:
mime = preferred
break
@@ -250,7 +258,10 @@ def _wayland_save(dest: Path) -> bool:
with open(dest, "wb") as f:
subprocess.run(
["wl-paste", "--type", mime],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
stdout=f,
stderr=subprocess.DEVNULL,
timeout=5,
check=True,
)
if not dest.exists() or dest.stat().st_size == 0:
@@ -276,6 +287,7 @@ def _convert_to_png(path: Path) -> bool:
# Try Pillow first (likely installed in the venv)
try:
from PIL import Image
img = Image.open(path)
img.save(path, "PNG")
return True
@@ -290,7 +302,8 @@ def _convert_to_png(path: Path) -> bool:
path.rename(tmp)
r = subprocess.run(
["convert", str(tmp), "png:" + str(path)],
capture_output=True, timeout=5,
capture_output=True,
timeout=5,
)
tmp.unlink(missing_ok=True)
if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
@@ -310,12 +323,15 @@ def _convert_to_png(path: Path) -> bool:
# ── X11 (xclip) ─────────────────────────────────────────────────────────
def _xclip_has_image() -> bool:
"""Check if X11 clipboard has image content."""
try:
r = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
capture_output=True,
text=True,
timeout=3,
)
return r.returncode == 0 and "image/png" in r.stdout
except FileNotFoundError:
@@ -331,7 +347,9 @@ def _xclip_save(dest: Path) -> bool:
try:
targets = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
capture_output=True,
text=True,
timeout=3,
)
if "image/png" not in targets.stdout:
return False
@@ -346,7 +364,10 @@ def _xclip_save(dest: Path) -> bool:
with open(dest, "wb") as f:
subprocess.run(
["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
stdout=f,
stderr=subprocess.DEVNULL,
timeout=5,
check=True,
)
if dest.exists() and dest.stat().st_size > 0:
return True

View File

@@ -4,14 +4,12 @@ from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import List, Optional
import os
from pathlib import Path
logger = logging.getLogger(__name__)
DEFAULT_CODEX_MODELS: List[str] = [
DEFAULT_CODEX_MODELS: list[str] = [
"gpt-5.3-codex",
"gpt-5.2-codex",
"gpt-5.1-codex-max",
@@ -19,10 +17,11 @@ DEFAULT_CODEX_MODELS: List[str] = [
]
def _fetch_models_from_api(access_token: str) -> List[str]:
def _fetch_models_from_api(access_token: str) -> list[str]:
"""Fetch available models from the Codex API. Returns visible models sorted by priority."""
try:
import httpx
resp = httpx.get(
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
headers={"Authorization": f"Bearer {access_token}"},
@@ -57,7 +56,7 @@ def _fetch_models_from_api(access_token: str) -> List[str]:
return [slug for _, slug in sortable]
def _read_default_model(codex_home: Path) -> Optional[str]:
def _read_default_model(codex_home: Path) -> str | None:
config_path = codex_home / "config.toml"
if not config_path.exists():
return None
@@ -75,7 +74,7 @@ def _read_default_model(codex_home: Path) -> Optional[str]:
return None
def _read_cache_models(codex_home: Path) -> List[str]:
def _read_cache_models(codex_home: Path) -> list[str]:
cache_path = codex_home / "models_cache.json"
if not cache_path.exists():
return []
@@ -104,22 +103,22 @@ def _read_cache_models(codex_home: Path) -> List[str]:
sortable.append((rank, slug))
sortable.sort(key=lambda item: (item[0], item[1]))
deduped: List[str] = []
deduped: list[str] = []
for _, slug in sortable:
if slug not in deduped:
deduped.append(slug)
return deduped
def get_codex_model_ids(access_token: Optional[str] = None) -> List[str]:
def get_codex_model_ids(access_token: str | None = None) -> list[str]:
"""Return available Codex model IDs, trying API first, then local sources.
Resolution order: API (live, if token provided) > config.toml default >
local cache > hardcoded defaults.
"""
codex_home_str = os.getenv("CODEX_HOME", "").strip() or str(Path.home() / ".codex")
codex_home = Path(codex_home_str).expanduser()
ordered: List[str] = []
ordered: list[str] = []
# Try live API if we have a token
if access_token:

View File

@@ -12,7 +12,6 @@ from typing import Any
from prompt_toolkit.completion import Completer, Completion
COMMANDS = {
"/help": "Show this help message",
"/tools": "List available tools",

File diff suppressed because it is too large Load Diff

View File

@@ -20,46 +20,46 @@ from hermes_cli.colors import Colors, color
def cron_list(show_all: bool = False):
"""List all scheduled jobs."""
from cron.jobs import list_jobs
jobs = list_jobs(include_disabled=show_all)
if not jobs:
print(color("No scheduled jobs.", Colors.DIM))
print(color("Create one with the /cron add command in chat, or via Telegram.", Colors.DIM))
return
print()
print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN))
print(color("│ Scheduled Jobs │", Colors.CYAN))
print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN))
print()
for job in jobs:
job_id = job.get("id", "?")[:8]
name = job.get("name", "(unnamed)")
schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?"))
enabled = job.get("enabled", True)
next_run = job.get("next_run_at", "?")
repeat_info = job.get("repeat", {})
repeat_times = repeat_info.get("times")
repeat_completed = repeat_info.get("completed", 0)
if repeat_times:
repeat_str = f"{repeat_completed}/{repeat_times}"
else:
repeat_str = ""
deliver = job.get("deliver", ["local"])
if isinstance(deliver, str):
deliver = [deliver]
deliver_str = ", ".join(deliver)
if not enabled:
status = color("[disabled]", Colors.RED)
else:
status = color("[active]", Colors.GREEN)
print(f" {color(job_id, Colors.YELLOW)} {status}")
print(f" Name: {name}")
print(f" Schedule: {schedule}")
@@ -67,9 +67,10 @@ def cron_list(show_all: bool = False):
print(f" Next run: {next_run}")
print(f" Deliver: {deliver_str}")
print()
# Warn if gateway isn't running
from hermes_cli.gateway import find_gateway_pids
if not find_gateway_pids():
print(color(" ⚠ Gateway is not running — jobs won't fire automatically.", Colors.YELLOW))
print(color(" Start it with: hermes gateway install", Colors.DIM))
@@ -79,6 +80,7 @@ def cron_list(show_all: bool = False):
def cron_tick():
"""Run due jobs once and exit."""
from cron.scheduler import tick
tick(verbose=True)
@@ -86,9 +88,9 @@ def cron_status():
"""Show cron execution status."""
from cron.jobs import list_jobs
from hermes_cli.gateway import find_gateway_pids
print()
pids = find_gateway_pids()
if pids:
print(color("✓ Gateway is running — cron jobs will fire automatically", Colors.GREEN))
@@ -99,9 +101,9 @@ def cron_status():
print(" To enable automatic execution:")
print(" hermes gateway install # Install as system service (recommended)")
print(" hermes gateway # Or run in foreground")
print()
jobs = list_jobs(include_disabled=False)
if jobs:
next_runs = [j.get("next_run_at") for j in jobs if j.get("next_run_at")]
@@ -110,24 +112,24 @@ def cron_status():
print(f" Next run: {min(next_runs)}")
else:
print(" No active jobs")
print()
def cron_command(args):
"""Handle cron subcommands."""
subcmd = getattr(args, 'cron_command', None)
subcmd = getattr(args, "cron_command", None)
if subcmd is None or subcmd == "list":
show_all = getattr(args, 'all', False)
show_all = getattr(args, "all", False)
cron_list(show_all)
elif subcmd == "tick":
cron_tick()
elif subcmd == "status":
cron_status()
else:
print(f"Unknown cron command: {subcmd}")
print("Usage: hermes cron [list|status|tick]")

View File

@@ -5,18 +5,18 @@ Diagnoses issues with Hermes Agent setup.
"""
import os
import sys
import subprocess
import shutil
from pathlib import Path
import subprocess
import sys
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
from hermes_cli.config import get_env_path, get_hermes_home, get_project_root
PROJECT_ROOT = get_project_root()
HERMES_HOME = get_hermes_home()
# Load environment variables from ~/.hermes/.env so API key checks work
from dotenv import load_dotenv
_env_path = get_env_path()
if _env_path.exists():
try:
@@ -33,7 +33,6 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
from hermes_cli.colors import Colors, color
from hermes_constants import OPENROUTER_MODELS_URL
_PROVIDER_ENV_HINTS = (
"OPENROUTER_API_KEY",
"OPENAI_API_KEY",
@@ -56,35 +55,38 @@ def _has_provider_env_config(content: str) -> bool:
def check_ok(text: str, detail: str = ""):
print(f" {color('', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
def check_warn(text: str, detail: str = ""):
print(f" {color('', Colors.YELLOW)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
def check_fail(text: str, detail: str = ""):
print(f" {color('', Colors.RED)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
def check_info(text: str):
print(f" {color('', Colors.CYAN)} {text}")
def run_doctor(args):
"""Run diagnostic checks."""
should_fix = getattr(args, 'fix', False)
should_fix = getattr(args, "fix", False)
issues = []
manual_issues = [] # issues that can't be auto-fixed
fixed_count = 0
print()
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
print(color("│ 🩺 Hermes Doctor │", Colors.CYAN))
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
# =========================================================================
# Check: Python version
# =========================================================================
print()
print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD))
py_version = sys.version_info
if py_version >= (3, 11):
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
@@ -96,20 +98,20 @@ def run_doctor(args):
else:
check_fail(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ required)")
issues.append("Upgrade Python to 3.10+")
# Check if in virtual environment
in_venv = sys.prefix != sys.base_prefix
if in_venv:
check_ok("Virtual environment active")
else:
check_warn("Not in virtual environment", "(recommended)")
# =========================================================================
# Check: Required packages
# =========================================================================
print()
print(color("◆ Required Packages", Colors.CYAN, Colors.BOLD))
required_packages = [
("openai", "OpenAI SDK"),
("rich", "Rich (terminal UI)"),
@@ -117,13 +119,13 @@ def run_doctor(args):
("yaml", "PyYAML"),
("httpx", "HTTPX"),
]
optional_packages = [
("croniter", "Croniter (cron expressions)"),
("telegram", "python-telegram-bot"),
("discord", "discord.py"),
]
for module, name in required_packages:
try:
__import__(module)
@@ -131,25 +133,25 @@ def run_doctor(args):
except ImportError:
check_fail(name, "(missing)")
issues.append(f"Install {name}: uv pip install {module}")
for module, name in optional_packages:
try:
__import__(module)
check_ok(name, "(optional)")
except ImportError:
check_warn(name, "(optional, not installed)")
# =========================================================================
# Check: Configuration files
# =========================================================================
print()
print(color("◆ Configuration Files", Colors.CYAN, Colors.BOLD))
# Check ~/.hermes/.env (primary location for user config)
env_path = HERMES_HOME / '.env'
env_path = HERMES_HOME / ".env"
if env_path.exists():
check_ok("~/.hermes/.env file exists")
# Check for common issues
content = env_path.read_text()
if _has_provider_env_config(content):
@@ -159,7 +161,7 @@ def run_doctor(args):
issues.append("Run 'hermes setup' to configure API keys")
else:
# Also check project root as fallback
fallback_env = PROJECT_ROOT / '.env'
fallback_env = PROJECT_ROOT / ".env"
if fallback_env.exists():
check_ok(".env file exists (in project directory)")
else:
@@ -173,17 +175,17 @@ def run_doctor(args):
else:
check_info("Run 'hermes setup' to create one")
issues.append("Run 'hermes setup' to create .env")
# Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback)
config_path = HERMES_HOME / 'config.yaml'
config_path = HERMES_HOME / "config.yaml"
if config_path.exists():
check_ok("~/.hermes/config.yaml exists")
else:
fallback_config = PROJECT_ROOT / 'cli-config.yaml'
fallback_config = PROJECT_ROOT / "cli-config.yaml"
if fallback_config.exists():
check_ok("cli-config.yaml exists (in project directory)")
else:
example_config = PROJECT_ROOT / 'cli-config.yaml.example'
example_config = PROJECT_ROOT / "cli-config.yaml.example"
if should_fix and example_config.exists():
config_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(str(example_config), str(config_path))
@@ -194,7 +196,7 @@ def run_doctor(args):
manual_issues.append("Create ~/.hermes/config.yaml manually")
else:
check_warn("config.yaml not found", "(using defaults)")
# =========================================================================
# Check: Auth providers
# =========================================================================
@@ -202,7 +204,7 @@ def run_doctor(args):
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
try:
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
nous_status = get_nous_auth_status()
if nous_status.get("logged_in"):
@@ -230,7 +232,7 @@ def run_doctor(args):
# =========================================================================
print()
print(color("◆ Directory Structure", Colors.CYAN, Colors.BOLD))
hermes_home = HERMES_HOME
if hermes_home.exists():
check_ok("~/.hermes directory exists")
@@ -241,7 +243,7 @@ def run_doctor(args):
fixed_count += 1
else:
check_warn("~/.hermes not found", "(will be created on first use)")
# Check expected subdirectories
expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"]
for subdir_name in expected_subdirs:
@@ -255,7 +257,7 @@ def run_doctor(args):
fixed_count += 1
else:
check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)")
# Check for SOUL.md persona file
soul_path = hermes_home / "SOUL.md"
if soul_path.exists():
@@ -278,7 +280,7 @@ def run_doctor(args):
)
check_ok("Created ~/.hermes/SOUL.md with basic template")
fixed_count += 1
# Check memory directory
memories_dir = hermes_home / "memories"
if memories_dir.exists():
@@ -301,12 +303,13 @@ def run_doctor(args):
memories_dir.mkdir(parents=True, exist_ok=True)
check_ok("Created ~/.hermes/memories/")
fixed_count += 1
# Check SQLite session store
state_db_path = hermes_home / "state.db"
if state_db_path.exists():
try:
import sqlite3
conn = sqlite3.connect(str(state_db_path))
cursor = conn.execute("SELECT COUNT(*) FROM sessions")
count = cursor.fetchone()[0]
@@ -316,26 +319,26 @@ def run_doctor(args):
check_warn(f"~/.hermes/state.db exists but has issues: {e}")
else:
check_info("~/.hermes/state.db not created yet (will be created on first session)")
# =========================================================================
# Check: External tools
# =========================================================================
print()
print(color("◆ External Tools", Colors.CYAN, Colors.BOLD))
# Git
if shutil.which("git"):
check_ok("git")
else:
check_warn("git not found", "(optional)")
# ripgrep (optional, for faster file search)
if shutil.which("rg"):
check_ok("ripgrep (rg)", "(faster file search)")
else:
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
check_info("Install for faster search: sudo apt install ripgrep")
# Docker (optional)
terminal_env = os.getenv("TERMINAL_ENV", "local")
if terminal_env == "docker":
@@ -355,7 +358,7 @@ def run_doctor(args):
check_ok("docker", "(optional)")
else:
check_warn("docker not found", "(optional)")
# SSH (if using ssh backend)
if terminal_env == "ssh":
ssh_host = os.getenv("TERMINAL_SSH_HOST")
@@ -364,7 +367,7 @@ def run_doctor(args):
result = subprocess.run(
["ssh", "-o", "ConnectTimeout=5", "-o", "BatchMode=yes", ssh_host, "echo ok"],
capture_output=True,
text=True
text=True,
)
if result.returncode == 0:
check_ok(f"SSH connection to {ssh_host}")
@@ -374,7 +377,7 @@ def run_doctor(args):
else:
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
issues.append("Set TERMINAL_SSH_HOST in .env")
# Daytona (if using daytona backend)
if terminal_env == "daytona":
daytona_key = os.getenv("DAYTONA_API_KEY")
@@ -385,6 +388,7 @@ def run_doctor(args):
issues.append("Set DAYTONA_API_KEY environment variable")
try:
from daytona import Daytona
check_ok("daytona SDK", "(installed)")
except ImportError:
check_fail("daytona SDK not installed", "(pip install daytona)")
@@ -401,7 +405,7 @@ def run_doctor(args):
check_warn("agent-browser not installed", "(run: npm install)")
else:
check_warn("Node.js not found", "(optional, needed for browser tools)")
# npm audit for all Node.js packages
if shutil.which("npm"):
npm_dirs = [
@@ -415,9 +419,12 @@ def run_doctor(args):
audit_result = subprocess.run(
["npm", "audit", "--json"],
cwd=str(npm_dir),
capture_output=True, text=True, timeout=30,
capture_output=True,
text=True,
timeout=30,
)
import json as _json
audit_data = _json.loads(audit_result.stdout) if audit_result.stdout.strip() else {}
vuln_count = audit_data.get("metadata", {}).get("vulnerabilities", {})
critical = vuln_count.get("critical", 0)
@@ -429,7 +436,7 @@ def run_doctor(args):
elif critical > 0 or high > 0:
check_warn(
f"{label} deps",
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)"
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)",
)
issues.append(f"{label} has {total} npm vulnerability(ies)")
else:
@@ -442,47 +449,50 @@ def run_doctor(args):
# =========================================================================
print()
print(color("◆ API Connectivity", Colors.CYAN, Colors.BOLD))
openrouter_key = os.getenv("OPENROUTER_API_KEY")
if openrouter_key:
print(" Checking OpenRouter API...", end="", flush=True)
try:
import httpx
response = httpx.get(
OPENROUTER_MODELS_URL,
headers={"Authorization": f"Bearer {openrouter_key}"},
timeout=10
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
)
if response.status_code == 200:
print(f"\r {color('', Colors.GREEN)} OpenRouter API ")
elif response.status_code == 401:
print(f"\r {color('', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} ")
print(
f"\r {color('', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} "
)
issues.append("Check OPENROUTER_API_KEY in .env")
else:
print(f"\r {color('', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} ")
print(
f"\r {color('', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} "
)
except Exception as e:
print(f"\r {color('', Colors.RED)} OpenRouter API {color(f'({e})', Colors.DIM)} ")
issues.append("Check network connectivity")
else:
check_warn("OpenRouter API", "(not configured)")
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
if anthropic_key:
print(" Checking Anthropic API...", end="", flush=True)
try:
import httpx
response = httpx.get(
"https://api.anthropic.com/v1/models",
headers={
"x-api-key": anthropic_key,
"anthropic-version": "2023-06-01"
},
timeout=10
headers={"x-api-key": anthropic_key, "anthropic-version": "2023-06-01"},
timeout=10,
)
if response.status_code == 200:
print(f"\r {color('', Colors.GREEN)} Anthropic API ")
elif response.status_code == 401:
print(f"\r {color('', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} ")
print(
f"\r {color('', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} "
)
else:
msg = "(couldn't verify)"
print(f"\r {color('', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
@@ -491,10 +501,15 @@ def run_doctor(args):
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
_apikey_providers = [
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
(
"Z.AI / GLM",
("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
"https://api.z.ai/api/paas/v4/models",
"GLM_BASE_URL",
),
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
]
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
_key = ""
@@ -507,6 +522,7 @@ def run_doctor(args):
print(f" Checking {_pname} API...", end="", flush=True)
try:
import httpx
_base = os.getenv(_base_env, "")
# Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
if not _base and _key.startswith("sk-kimi-"):
@@ -526,7 +542,9 @@ def run_doctor(args):
print(f"\r {color('', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
issues.append(f"Check {_env_vars[0]} in .env")
else:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
print(
f"\r {color('', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} "
)
except Exception as _e:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
@@ -535,7 +553,7 @@ def run_doctor(args):
# =========================================================================
print()
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
# mini-swe-agent (terminal tool backend)
mini_swe_dir = PROJECT_ROOT / "mini-swe-agent"
if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists():
@@ -547,7 +565,7 @@ def run_doctor(args):
issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent")
else:
check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)")
# tinker-atropos (RL training backend)
tinker_dir = PROJECT_ROOT / "tinker-atropos"
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
@@ -562,24 +580,24 @@ def run_doctor(args):
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
else:
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
# =========================================================================
# Check: Tool Availability
# =========================================================================
print()
print(color("◆ Tool Availability", Colors.CYAN, Colors.BOLD))
try:
# Add project root to path for imports
sys.path.insert(0, str(PROJECT_ROOT))
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
from model_tools import TOOLSET_REQUIREMENTS, check_tool_availability
available, unavailable = check_tool_availability()
for tid in available:
info = TOOLSET_REQUIREMENTS.get(tid, {})
check_ok(info.get("name", tid))
for item in unavailable:
env_vars = item.get("missing_vars") or item.get("env_vars") or []
if env_vars:
@@ -594,7 +612,7 @@ def run_doctor(args):
issues.append("Run 'hermes setup' to configure missing API keys for full tool access")
except Exception as e:
check_warn("Could not check tool availability", f"({e})")
# =========================================================================
# Check: Skills Hub
# =========================================================================
@@ -608,6 +626,7 @@ def run_doctor(args):
if lock_file.exists():
try:
import json
lock_data = json.loads(lock_file.read_text())
count = len(lock_data.get("installed", {}))
check_ok(f"Lock file OK ({count} hub-installed skill(s))")
@@ -621,6 +640,7 @@ def run_doctor(args):
check_warn("Skills Hub directory not initialized", "(run: hermes skills list)")
from hermes_cli.config import get_env_value
github_token = get_env_value("GITHUB_TOKEN") or get_env_value("GH_TOKEN")
if github_token:
check_ok("GitHub token configured (authenticated API access)")
@@ -656,5 +676,5 @@ def run_doctor(args):
else:
print(color("" * 60, Colors.GREEN))
print(color(" All checks passed! 🎉", Colors.GREEN, Colors.BOLD))
print()

View File

@@ -13,18 +13,24 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from hermes_cli.colors import Colors, color
from hermes_cli.config import get_env_value, save_env_value
from hermes_cli.setup import (
print_header, print_info, print_success, print_warning, print_error,
prompt, prompt_choice, prompt_yes_no,
print_error,
print_header,
print_info,
print_success,
print_warning,
prompt,
prompt_choice,
prompt_yes_no,
)
from hermes_cli.colors import Colors, color
# =============================================================================
# Process Management (for manual gateway runs)
# =============================================================================
def find_gateway_pids() -> list:
"""Find PIDs of running gateway processes."""
pids = []
@@ -38,17 +44,16 @@ def find_gateway_pids() -> list:
if is_windows():
# Windows: use wmic to search command lines
result = subprocess.run(
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
capture_output=True, text=True
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True
)
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
current_cmd = ""
for line in result.stdout.split('\n'):
for line in result.stdout.split("\n"):
line = line.strip()
if line.startswith("CommandLine="):
current_cmd = line[len("CommandLine="):]
current_cmd = line[len("CommandLine=") :]
elif line.startswith("ProcessId="):
pid_str = line[len("ProcessId="):]
pid_str = line[len("ProcessId=") :]
if any(p in current_cmd for p in patterns):
try:
pid = int(pid_str)
@@ -58,14 +63,10 @@ def find_gateway_pids() -> list:
pass
current_cmd = ""
else:
result = subprocess.run(
["ps", "aux"],
capture_output=True,
text=True
)
for line in result.stdout.split('\n'):
result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
for line in result.stdout.split("\n"):
# Skip grep and current process
if 'grep' in line or str(os.getpid()) in line:
if "grep" in line or str(os.getpid()) in line:
continue
for pattern in patterns:
if pattern in line:
@@ -88,7 +89,7 @@ def kill_gateway_processes(force: bool = False) -> int:
"""Kill any running gateway processes. Returns count killed."""
pids = find_gateway_pids()
killed = 0
for pid in pids:
try:
if force and not is_windows():
@@ -101,18 +102,20 @@ def kill_gateway_processes(force: bool = False) -> int:
pass
except PermissionError:
print(f"⚠ Permission denied to kill PID {pid}")
return killed
def is_linux() -> bool:
return sys.platform.startswith('linux')
return sys.platform.startswith("linux")
def is_macos() -> bool:
return sys.platform == 'darwin'
return sys.platform == "darwin"
def is_windows() -> bool:
return sys.platform == 'win32'
return sys.platform == "win32"
# =============================================================================
@@ -122,12 +125,15 @@ def is_windows() -> bool:
SERVICE_NAME = "hermes-gateway"
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
def get_systemd_unit_path() -> Path:
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
def get_launchd_plist_path() -> Path:
return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist"
def get_python_path() -> str:
if is_windows():
venv_python = PROJECT_ROOT / "venv" / "Scripts" / "python.exe"
@@ -137,14 +143,16 @@ def get_python_path() -> str:
return str(venv_python)
return sys.executable
def get_hermes_cli_path() -> str:
"""Get the path to the hermes CLI."""
# Check if installed via pip
import shutil
hermes_bin = shutil.which("hermes")
if hermes_bin:
return hermes_bin
# Fallback to direct module execution
return f"{get_python_path()} -m hermes_cli.main"
@@ -153,8 +161,10 @@ def get_hermes_cli_path() -> str:
# Systemd (Linux)
# =============================================================================
def generate_systemd_unit() -> str:
import shutil
python_path = get_python_path()
working_dir = str(PROJECT_ROOT)
venv_dir = str(PROJECT_ROOT / "venv")
@@ -163,7 +173,7 @@ def generate_systemd_unit() -> str:
# Build a PATH that includes the venv, node_modules, and standard system dirs
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
return f"""[Unit]
Description={SERVICE_DESCRIPTION}
@@ -188,56 +198,62 @@ StandardError=journal
WantedBy=default.target
"""
def systemd_install(force: bool = False):
unit_path = get_systemd_unit_path()
if unit_path.exists() and not force:
print(f"Service already installed at: {unit_path}")
print("Use --force to reinstall")
return
unit_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Installing systemd service to: {unit_path}")
unit_path.write_text(generate_systemd_unit())
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)
print()
print("✓ Service installed and enabled!")
print()
print("Next steps:")
print(f" hermes gateway start # Start the service")
print(f" hermes gateway status # Check status")
print(" hermes gateway start # Start the service")
print(" hermes gateway status # Check status")
print(f" journalctl --user -u {SERVICE_NAME} -f # View logs")
print()
print("To enable lingering (keeps running after logout):")
print(" sudo loginctl enable-linger $USER")
def systemd_uninstall():
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)
unit_path = get_systemd_unit_path()
if unit_path.exists():
unit_path.unlink()
print(f"✓ Removed {unit_path}")
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
print("✓ Service uninstalled")
def systemd_start():
subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
print("✓ Service started")
def systemd_stop():
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
print("✓ Service stopped")
def systemd_restart():
subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
print("✓ Service restarted")
def systemd_status(deep: bool = False):
# Check if service unit file exists
unit_path = get_systemd_unit_path()
@@ -245,54 +261,45 @@ def systemd_status(deep: bool = False):
print("✗ Gateway service is not installed")
print(" Run: hermes gateway install")
return
# Show detailed status first
subprocess.run(
["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"],
capture_output=False
)
subprocess.run(["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"], capture_output=False)
# Check if service is active
result = subprocess.run(
["systemctl", "--user", "is-active", SERVICE_NAME],
capture_output=True,
text=True
)
result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True)
status = result.stdout.strip()
if status == "active":
print("✓ Gateway service is running")
else:
print("✗ Gateway service is stopped")
print(" Run: hermes gateway start")
if deep:
print()
print("Recent logs:")
subprocess.run([
"journalctl", "--user", "-u", SERVICE_NAME,
"-n", "20", "--no-pager"
])
subprocess.run(["journalctl", "--user", "-u", SERVICE_NAME, "-n", "20", "--no-pager"])
# =============================================================================
# Launchd (macOS)
# =============================================================================
def generate_launchd_plist() -> str:
python_path = get_python_path()
working_dir = str(PROJECT_ROOT)
log_dir = Path.home() / ".hermes" / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
return f"""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>ai.hermes.gateway</string>
<key>ProgramArguments</key>
<array>
<string>{python_path}</string>
@@ -301,42 +308,43 @@ def generate_launchd_plist() -> str:
<string>gateway</string>
<string>run</string>
</array>
<key>WorkingDirectory</key>
<string>{working_dir}</string>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<dict>
<key>SuccessfulExit</key>
<false/>
</dict>
<key>StandardOutPath</key>
<string>{log_dir}/gateway.log</string>
<key>StandardErrorPath</key>
<string>{log_dir}/gateway.error.log</string>
</dict>
</plist>
"""
def launchd_install(force: bool = False):
plist_path = get_launchd_plist_path()
if plist_path.exists() and not force:
print(f"Service already installed at: {plist_path}")
print("Use --force to reinstall")
return
plist_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Installing launchd service to: {plist_path}")
plist_path.write_text(generate_launchd_plist())
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
print()
print("✓ Service installed and loaded!")
print()
@@ -344,41 +352,42 @@ def launchd_install(force: bool = False):
print(" hermes gateway status # Check status")
print(" tail -f ~/.hermes/logs/gateway.log # View logs")
def launchd_uninstall():
plist_path = get_launchd_plist_path()
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
if plist_path.exists():
plist_path.unlink()
print(f"✓ Removed {plist_path}")
print("✓ Service uninstalled")
def launchd_start():
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
print("✓ Service started")
def launchd_stop():
subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
print("✓ Service stopped")
def launchd_restart():
launchd_stop()
launchd_start()
def launchd_status(deep: bool = False):
result = subprocess.run(
["launchctl", "list", "ai.hermes.gateway"],
capture_output=True,
text=True
)
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
if result.returncode == 0:
print("✓ Gateway service is loaded")
print(result.stdout)
else:
print("✗ Gateway service is not loaded")
if deep:
log_file = Path.home() / ".hermes" / "logs" / "gateway.log"
if log_file.exists():
@@ -391,9 +400,10 @@ def launchd_status(deep: bool = False):
# Gateway Runner
# =============================================================================
def run_gateway(verbose: bool = False, replace: bool = False):
"""Run the gateway in foreground.
Args:
verbose: Enable verbose logging output.
replace: If True, kill any existing gateway instance before starting.
@@ -401,9 +411,9 @@ def run_gateway(verbose: bool = False, replace: bool = False):
hasn't fully exited yet.
"""
sys.path.insert(0, str(PROJECT_ROOT))
from gateway.run import start_gateway
print("┌─────────────────────────────────────────────────────────┐")
print("│ ⚕ Hermes Gateway Starting... │")
print("├─────────────────────────────────────────────────────────┤")
@@ -411,7 +421,7 @@ def run_gateway(verbose: bool = False, replace: bool = False):
print("│ Press Ctrl+C to stop │")
print("└─────────────────────────────────────────────────────────┘")
print()
# Exit with code 1 if gateway fails to connect any platform,
# so systemd Restart=on-failure will retry on transient errors
success = asyncio.run(start_gateway(replace=replace))
@@ -438,13 +448,25 @@ _PLATFORMS = [
"4. To find your user ID: message @userinfobot — it replies with your numeric ID",
],
"vars": [
{"name": "TELEGRAM_BOT_TOKEN", "prompt": "Bot token", "password": True,
"help": "Paste the token from @BotFather (step 3 above)."},
{"name": "TELEGRAM_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
"is_allowlist": True,
"help": "Paste your user ID from step 4 above."},
{"name": "TELEGRAM_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat."},
{
"name": "TELEGRAM_BOT_TOKEN",
"prompt": "Bot token",
"password": True,
"help": "Paste the token from @BotFather (step 3 above).",
},
{
"name": "TELEGRAM_ALLOWED_USERS",
"prompt": "Allowed user IDs (comma-separated)",
"password": False,
"is_allowlist": True,
"help": "Paste your user ID from step 4 above.",
},
{
"name": "TELEGRAM_HOME_CHANNEL",
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
"password": False,
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat.",
},
],
},
{
@@ -466,13 +488,25 @@ _PLATFORMS = [
" then right-click your name → Copy ID",
],
"vars": [
{"name": "DISCORD_BOT_TOKEN", "prompt": "Bot token", "password": True,
"help": "Paste the token from step 2 above."},
{"name": "DISCORD_ALLOWED_USERS", "prompt": "Allowed user IDs or usernames (comma-separated)", "password": False,
"is_allowlist": True,
"help": "Paste your user ID from step 5 above."},
{"name": "DISCORD_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode)."},
{
"name": "DISCORD_BOT_TOKEN",
"prompt": "Bot token",
"password": True,
"help": "Paste the token from step 2 above.",
},
{
"name": "DISCORD_ALLOWED_USERS",
"prompt": "Allowed user IDs or usernames (comma-separated)",
"password": False,
"is_allowlist": True,
"help": "Paste your user ID from step 5 above.",
},
{
"name": "DISCORD_HOME_CHANNEL",
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
"password": False,
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode).",
},
],
},
{
@@ -497,13 +531,25 @@ _PLATFORMS = [
"8. Invite the bot to channels: /invite @YourBot",
],
"vars": [
{"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True,
"help": "Paste the bot token from step 3 above."},
{"name": "SLACK_APP_TOKEN", "prompt": "App Token (xapp-...)", "password": True,
"help": "Paste the app-level token from step 4 above."},
{"name": "SLACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
"is_allowlist": True,
"help": "Paste your member ID from step 7 above."},
{
"name": "SLACK_BOT_TOKEN",
"prompt": "Bot Token (xoxb-...)",
"password": True,
"help": "Paste the bot token from step 3 above.",
},
{
"name": "SLACK_APP_TOKEN",
"prompt": "App Token (xapp-...)",
"password": True,
"help": "Paste the app-level token from step 4 above.",
},
{
"name": "SLACK_ALLOWED_USERS",
"prompt": "Allowed user IDs (comma-separated)",
"password": False,
"is_allowlist": True,
"help": "Paste your member ID from step 7 above.",
},
],
},
{
@@ -582,14 +628,14 @@ def _setup_standard_platform(platform: dict):
# Allowlist fields get special handling for the deny-by-default security model
if var.get("is_allowlist"):
print_info(f" The gateway DENIES all users by default for security.")
print_info(f" Enter user IDs to create an allowlist, or leave empty")
print_info(f" and you'll be asked about open access next.")
print_info(" The gateway DENIES all users by default for security.")
print_info(" Enter user IDs to create an allowlist, or leave empty")
print_info(" and you'll be asked about open access next.")
value = prompt(f" {var['prompt']}", password=False)
if value:
cleaned = value.replace(" ", "")
save_env_value(var["name"], cleaned)
print_success(f" Saved — only these users can interact with the bot.")
print_success(" Saved — only these users can interact with the bot.")
allowed_val_set = cleaned
else:
# No allowlist — ask about open access vs DM pairing
@@ -618,7 +664,7 @@ def _setup_standard_platform(platform: dict):
print_warning(f" Skipped — {label} won't work without this.")
return
else:
print_info(f" Skipped (can configure later)")
print_info(" Skipped (can configure later)")
# If an allowlist was set and home channel wasn't, offer to reuse
# the first user ID (common for Telegram DMs).
@@ -636,8 +682,10 @@ def _setup_standard_platform(platform: dict):
def _setup_whatsapp():
"""Delegate to the existing WhatsApp setup flow."""
from hermes_cli.main import cmd_whatsapp
import argparse
from hermes_cli.main import cmd_whatsapp
cmd_whatsapp(argparse.Namespace())
@@ -653,16 +701,10 @@ def _is_service_installed() -> bool:
def _is_service_running() -> bool:
"""Check if the gateway service is currently running."""
if is_linux() and get_systemd_unit_path().exists():
result = subprocess.run(
["systemctl", "--user", "is-active", SERVICE_NAME],
capture_output=True, text=True
)
result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True)
return result.stdout.strip() == "active"
elif is_macos() and get_launchd_plist_path().exists():
result = subprocess.run(
["launchctl", "list", "ai.hermes.gateway"],
capture_output=True, text=True
)
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
return result.returncode == 0
# Check for manual processes
return len(find_gateway_pids()) > 0
@@ -697,7 +739,7 @@ def _setup_signal():
print_info(" Docker: bbernhard/signal-cli-rest-api")
print()
print_info(" After installing, link your account and start the daemon:")
print_info(" signal-cli link -n \"HermesAgent\"")
print_info(' signal-cli link -n "HermesAgent"')
print_info(" signal-cli --account +YOURNUMBER daemon --http 127.0.0.1:8080")
print()
@@ -715,6 +757,7 @@ def _setup_signal():
print_info(" Testing connection...")
try:
import httpx
resp = httpx.get(f"{url.rstrip('/')}/api/v1/check", timeout=10.0)
if resp.status_code == 200:
print_success(" signal-cli daemon is reachable!")
@@ -779,7 +822,7 @@ def _setup_signal():
print_success("Signal configured!")
print_info(f" URL: {url}")
print_info(f" Account: {account}")
print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
print_info(" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}")
@@ -841,11 +884,10 @@ def gateway_setup():
_setup_standard_platform(platform)
# ── Post-setup: offer to install/restart gateway ──
any_configured = any(
bool(get_env_value(p["token_var"]))
for p in _PLATFORMS
if p["key"] != "whatsapp"
) or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
any_configured = (
any(bool(get_env_value(p["token_var"])) for p in _PLATFORMS if p["key"] != "whatsapp")
or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
)
if any_configured:
print()
@@ -878,7 +920,9 @@ def gateway_setup():
print()
if is_linux() or is_macos():
platform_name = "systemd" if is_linux() else "launchd"
if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True):
if prompt_yes_no(
f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True
):
try:
force = False
if is_linux():
@@ -914,14 +958,15 @@ def gateway_setup():
# Main Command Handler
# =============================================================================
def gateway_command(args):
"""Handle gateway subcommands."""
subcmd = getattr(args, 'gateway_command', None)
subcmd = getattr(args, "gateway_command", None)
# Default to run if no subcommand
if subcmd is None or subcmd == "run":
verbose = getattr(args, 'verbose', False)
replace = getattr(args, 'replace', False)
verbose = getattr(args, "verbose", False)
replace = getattr(args, "replace", False)
run_gateway(verbose, replace=replace)
return
@@ -931,7 +976,7 @@ def gateway_command(args):
# Service management commands
if subcmd == "install":
force = getattr(args, 'force', False)
force = getattr(args, "force", False)
if is_linux():
systemd_install(force)
elif is_macos():
@@ -940,7 +985,7 @@ def gateway_command(args):
print("Service installation not supported on this platform.")
print("Run manually: hermes gateway run")
sys.exit(1)
elif subcmd == "uninstall":
if is_linux():
systemd_uninstall()
@@ -949,7 +994,7 @@ def gateway_command(args):
else:
print("Not supported on this platform.")
sys.exit(1)
elif subcmd == "start":
if is_linux():
systemd_start()
@@ -958,11 +1003,11 @@ def gateway_command(args):
else:
print("Not supported on this platform.")
sys.exit(1)
elif subcmd == "stop":
# Try service first, fall back to killing processes directly
service_available = False
if is_linux() and get_systemd_unit_path().exists():
try:
systemd_stop()
@@ -975,7 +1020,7 @@ def gateway_command(args):
service_available = True
except subprocess.CalledProcessError:
pass
if not service_available:
# Kill gateway processes directly
killed = kill_gateway_processes()
@@ -983,11 +1028,11 @@ def gateway_command(args):
print(f"✓ Stopped {killed} gateway process(es)")
else:
print("✗ No gateway processes found")
elif subcmd == "restart":
# Try service first, fall back to killing and restarting
service_available = False
if is_linux() and get_systemd_unit_path().exists():
try:
systemd_restart()
@@ -1000,23 +1045,24 @@ def gateway_command(args):
service_available = True
except subprocess.CalledProcessError:
pass
if not service_available:
# Manual restart: kill existing processes
killed = kill_gateway_processes()
if killed:
print(f"✓ Stopped {killed} gateway process(es)")
import time
time.sleep(2)
# Start fresh
print("Starting gateway...")
run_gateway(verbose=False)
elif subcmd == "status":
deep = getattr(args, 'deep', False)
deep = getattr(args, "deep", False)
# Check for service first
if is_linux() and get_systemd_unit_path().exists():
systemd_status(deep)

File diff suppressed because it is too large Load Diff

View File

@@ -8,26 +8,26 @@ Add, remove, or reorder entries here — both `hermes setup` and
from __future__ import annotations
import json
import urllib.request
import urllib.error
import urllib.request
from difflib import get_close_matches
from typing import Any, Optional
from typing import Any
# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
("anthropic/claude-opus-4.6", "recommended"),
("anthropic/claude-sonnet-4.5", ""),
("openai/gpt-5.4-pro", ""),
("openai/gpt-5.4", ""),
("openai/gpt-5.3-codex", ""),
("google/gemini-3-pro-preview", ""),
("google/gemini-3-flash-preview", ""),
("qwen/qwen3.5-plus-02-15", ""),
("qwen/qwen3.5-35b-a3b", ""),
("stepfun/step-3.5-flash", ""),
("z-ai/glm-5", ""),
("moonshotai/kimi-k2.5", ""),
("minimax/minimax-m2.5", ""),
("anthropic/claude-opus-4.6", "recommended"),
("anthropic/claude-sonnet-4.5", ""),
("openai/gpt-5.4-pro", ""),
("openai/gpt-5.4", ""),
("openai/gpt-5.3-codex", ""),
("google/gemini-3-pro-preview", ""),
("google/gemini-3-flash-preview", ""),
("qwen/qwen3.5-plus-02-15", ""),
("qwen/qwen3.5-35b-a3b", ""),
("stepfun/step-3.5-flash", ""),
("z-ai/glm-5", ""),
("moonshotai/kimi-k2.5", ""),
("minimax/minimax-m2.5", ""),
]
_PROVIDER_MODELS: dict[str, list[str]] = {
@@ -93,9 +93,7 @@ def menu_labels() -> list[str]:
# All provider IDs and aliases that are valid for the provider:model syntax.
_KNOWN_PROVIDER_NAMES: set[str] = (
set(_PROVIDER_LABELS.keys())
| set(_PROVIDER_ALIASES.keys())
| {"openrouter", "custom"}
set(_PROVIDER_LABELS.keys()) | set(_PROVIDER_ALIASES.keys()) | {"openrouter", "custom"}
)
@@ -107,8 +105,13 @@ def list_available_providers() -> list[dict[str, str]]:
"""
# Canonical providers in display order
_PROVIDER_ORDER = [
"openrouter", "nous", "openai-codex",
"zai", "kimi-coding", "minimax", "minimax-cn",
"openrouter",
"nous",
"openai-codex",
"zai",
"kimi-coding",
"minimax",
"minimax-cn",
]
# Build reverse alias map
aliases_for: dict[str, list[str]] = {}
@@ -123,16 +126,19 @@ def list_available_providers() -> list[dict[str, str]]:
has_creds = False
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=pid)
has_creds = bool(runtime.get("api_key"))
except Exception:
pass
result.append({
"id": pid,
"label": label,
"aliases": alias_list,
"authenticated": has_creds,
})
result.append(
{
"id": pid,
"label": label,
"aliases": alias_list,
"authenticated": has_creds,
}
)
return result
@@ -157,13 +163,13 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
colon = stripped.find(":")
if colon > 0:
provider_part = stripped[:colon].strip().lower()
model_part = stripped[colon + 1:].strip()
model_part = stripped[colon + 1 :].strip()
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
return (normalize_provider(provider_part), model_part)
return (current_provider, stripped)
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
def curated_models_for_provider(provider: str | None) -> list[tuple[str, str]]:
"""Return ``(model_id, description)`` tuples for a provider's curated list."""
normalized = normalize_provider(provider)
if normalized == "openrouter":
@@ -172,7 +178,7 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]
return [(m, "") for m in models]
def normalize_provider(provider: Optional[str]) -> str:
def normalize_provider(provider: str | None) -> str:
"""Normalize provider aliases to Hermes' canonical provider ids.
Note: ``"auto"`` passes through unchanged — use
@@ -183,7 +189,7 @@ def normalize_provider(provider: Optional[str]) -> str:
return _PROVIDER_ALIASES.get(normalized, normalized)
def provider_model_ids(provider: Optional[str]) -> list[str]:
def provider_model_ids(provider: str | None) -> list[str]:
"""Return the best known model catalog for a provider."""
normalized = normalize_provider(provider)
if normalized == "openrouter":
@@ -196,10 +202,10 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
def fetch_api_models(
api_key: Optional[str],
base_url: Optional[str],
api_key: str | None,
base_url: str | None,
timeout: float = 5.0,
) -> Optional[list[str]]:
) -> list[str] | None:
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
Returns a list of model ID strings, or ``None`` if the endpoint could not
@@ -225,10 +231,10 @@ def fetch_api_models(
def validate_requested_model(
model_name: str,
provider: Optional[str],
provider: str | None,
*,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
api_key: str | None = None,
base_url: str | None = None,
) -> dict[str, Any]:
"""
Validate a ``/model`` value for the active provider.
@@ -286,10 +292,7 @@ def validate_requested_model(
"accepted": False,
"persist": False,
"recognized": False,
"message": (
f"Error: `{requested}` is not a valid model for this provider."
f"{suggestion_text}"
),
"message": (f"Error: `{requested}` is not a valid model for this provider.{suggestion_text}"),
}
# api_models is None — couldn't reach API, fall back to catalog check

View File

@@ -8,6 +8,7 @@ Usage:
hermes pairing clear-pending # Clear all expired/pending codes
"""
def pairing_command(args):
"""Handle hermes pairing subcommands."""
from gateway.pairing import PairingStore
@@ -72,10 +73,10 @@ def _cmd_approve(store, platform: str, code: str):
name = result.get("user_name", "")
display = f"{name} ({uid})" if name else uid
print(f"\n Approved! User {display} on {platform} can now use the bot~")
print(f" They'll be recognized automatically on their next message.\n")
print(" They'll be recognized automatically on their next message.\n")
else:
print(f"\n Code '{code}' not found or expired for platform '{platform}'.")
print(f" Run 'hermes pairing list' to see pending codes.\n")
print(" Run 'hermes pairing list' to see pending codes.\n")
def _cmd_revoke(store, platform: str, user_id: str):

View File

@@ -3,22 +3,22 @@
from __future__ import annotations
import os
from typing import Any, Dict, Optional
from typing import Any
from hermes_cli.auth import (
AuthError,
PROVIDER_REGISTRY,
AuthError,
format_auth_error,
resolve_provider,
resolve_nous_runtime_credentials,
resolve_codex_runtime_credentials,
resolve_api_key_provider_credentials,
resolve_codex_runtime_credentials,
resolve_nous_runtime_credentials,
resolve_provider,
)
from hermes_cli.config import load_config
from hermes_constants import OPENROUTER_BASE_URL
def _get_model_config() -> Dict[str, Any]:
def _get_model_config() -> dict[str, Any]:
config = load_config()
model_cfg = config.get("model")
if isinstance(model_cfg, dict):
@@ -28,7 +28,7 @@ def _get_model_config() -> Dict[str, Any]:
return {}
def resolve_requested_provider(requested: Optional[str] = None) -> str:
def resolve_requested_provider(requested: str | None = None) -> str:
"""Resolve provider request from explicit arg, env, then config."""
if requested and requested.strip():
return requested.strip().lower()
@@ -48,9 +48,9 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str:
def _resolve_openrouter_runtime(
*,
requested_provider: str,
explicit_api_key: Optional[str] = None,
explicit_base_url: Optional[str] = None,
) -> Dict[str, Any]:
explicit_api_key: str | None = None,
explicit_base_url: str | None = None,
) -> dict[str, Any]:
model_cfg = _get_model_config()
cfg_base_url = model_cfg.get("base_url") if isinstance(model_cfg.get("base_url"), str) else ""
cfg_provider = model_cfg.get("provider") if isinstance(model_cfg.get("provider"), str) else ""
@@ -81,19 +81,9 @@ def _resolve_openrouter_runtime(
# provider (issues #420, #560).
_is_openrouter_url = "openrouter.ai" in base_url
if _is_openrouter_url:
api_key = (
explicit_api_key
or os.getenv("OPENROUTER_API_KEY")
or os.getenv("OPENAI_API_KEY")
or ""
)
api_key = explicit_api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
else:
api_key = (
explicit_api_key
or os.getenv("OPENAI_API_KEY")
or os.getenv("OPENROUTER_API_KEY")
or ""
)
api_key = explicit_api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") or ""
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
@@ -108,10 +98,10 @@ def _resolve_openrouter_runtime(
def resolve_runtime_provider(
*,
requested: Optional[str] = None,
explicit_api_key: Optional[str] = None,
explicit_base_url: Optional[str] = None,
) -> Dict[str, Any]:
requested: str | None = None,
explicit_api_key: str | None = None,
explicit_base_url: str | None = None,
) -> dict[str, Any]:
"""Resolve runtime provider credentials for agent execution."""
requested_provider = resolve_requested_provider(requested)

File diff suppressed because it is too large Load Diff

View File

@@ -13,7 +13,6 @@ handler are thin wrappers that parse args and delegate.
import json
import shutil
from pathlib import Path
from typing import Optional
from rich.console import Console
from rich.panel import Panel
@@ -29,6 +28,7 @@ _console = Console()
# Shared do_* functions
# ---------------------------------------------------------------------------
def _resolve_short_name(name: str, sources, console: Console) -> str:
"""
Resolve a short skill name (e.g. 'pptx') to a full identifier by searching
@@ -57,7 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
table.add_column("Trust", style="dim")
table.add_column("Identifier", style="bold cyan")
for r in exact:
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(
r.trust_level, "dim"
)
trust_label = "official" if r.source == "official" else r.trust_level
table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
c.print(table)
@@ -76,8 +78,7 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
return ""
def do_search(query: str, source: str = "all", limit: int = 10,
console: Optional[Console] = None) -> None:
def do_search(query: str, source: str = "all", limit: int = 10, console: Console | None = None) -> None:
"""Search registries and display results as a Rich table."""
from tools.skills_hub import GitHubAuth, create_source_router, unified_search
@@ -111,18 +112,19 @@ def do_search(query: str, source: str = "all", limit: int = 10,
)
c.print(table)
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
"hermes skills install <identifier> to install[/]\n")
c.print(
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
)
def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
console: Optional[Console] = None) -> None:
def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console: Console | None = None) -> None:
"""Browse all available skills across registries, paginated.
Official skills are always shown first, regardless of source filter.
"""
from tools.skills_hub import (
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
GitHubAuth,
create_source_router,
)
# Clamp page_size to safe range
@@ -136,8 +138,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
# Collect results from all (or filtered) sources
# Use empty query to get everything; per-source limits prevent overload
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
"claude-marketplace": 50, "lobehub": 50}
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50, "claude-marketplace": 50, "lobehub": 50}
all_results: list = []
source_counts: dict = {}
@@ -168,11 +169,13 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
deduped = list(seen.values())
# Sort: official first, then by trust level (desc), then alphabetically
deduped.sort(key=lambda r: (
-_TRUST_RANK.get(r.trust_level, 0),
r.source != "official",
r.name.lower(),
))
deduped.sort(
key=lambda r: (
-_TRUST_RANK.get(r.trust_level, 0),
r.source != "official",
r.name.lower(),
)
)
# Paginate
total = len(deduped)
@@ -187,8 +190,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
# Build header
source_label = f"{source}" if source != "all" else "— all sources"
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
f" [dim]({total} skills, page {page}/{total_pages})[/]")
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/] [dim]({total} skills, page {page}/{total_pages})[/]")
if official_count > 0 and page == 1:
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
c.print()
@@ -202,8 +204,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
table.add_column("Trust", width=10)
for i, r in enumerate(page_items, start=start + 1):
trust_style = {"builtin": "bright_cyan", "trusted": "green",
"community": "yellow"}.get(r.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_label = "★ official" if r.source == "official" else r.trust_level
desc = r.description[:50]
@@ -235,18 +236,22 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
"hermes skills install <identifier> to install[/]\n")
def do_install(identifier: str, category: str = "", force: bool = False,
console: Optional[Console] = None) -> None:
"""Fetch, quarantine, scan, confirm, and install a skill."""
from tools.skills_hub import (
GitHubAuth, create_source_router, ensure_hub_dirs,
quarantine_bundle, install_from_quarantine, HubLockFile,
c.print(
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
)
def do_install(identifier: str, category: str = "", force: bool = False, console: Console | None = None) -> None:
"""Fetch, quarantine, scan, confirm, and install a skill."""
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
from tools.skills_hub import (
GitHubAuth,
HubLockFile,
create_source_router,
ensure_hub_dirs,
install_from_quarantine,
quarantine_bundle,
)
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
c = console or _console
ensure_hub_dirs()
@@ -304,33 +309,43 @@ def do_install(identifier: str, category: str = "", force: bool = False,
# Clean up quarantine
shutil.rmtree(q_path, ignore_errors=True)
from tools.skills_hub import append_audit_log
append_audit_log("BLOCKED", bundle.name, bundle.source,
bundle.trust_level, result.verdict,
f"{len(result.findings)}_findings")
append_audit_log(
"BLOCKED",
bundle.name,
bundle.source,
bundle.trust_level,
result.verdict,
f"{len(result.findings)}_findings",
)
return
# Confirm with user — show appropriate warning based on source
if not force:
c.print()
if bundle.source == "official":
c.print(Panel(
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
"It ships with hermes-agent but is not activated by default.\n"
"Installing will copy it to your skills directory where the agent can use it.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Official Skill",
border_style="bright_cyan",
))
c.print(
Panel(
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
"It ships with hermes-agent but is not activated by default.\n"
"Installing will copy it to your skills directory where the agent can use it.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Official Skill",
border_style="bright_cyan",
)
)
else:
c.print(Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
))
c.print(
Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
)
)
c.print(f"[bold]Install '{bundle.name}'?[/]")
try:
answer = input("Confirm [y/N]: ").strip().lower()
@@ -344,11 +359,12 @@ def do_install(identifier: str, category: str = "", force: bool = False,
# Install
install_dir = install_from_quarantine(q_path, bundle.name, category, bundle, result)
from tools.skills_hub import SKILLS_DIR
c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}")
c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n")
def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
def do_inspect(identifier: str, console: Console | None = None) -> None:
"""Preview a skill's SKILL.md content without installing."""
from tools.skills_hub import GitHubAuth, create_source_router
@@ -406,7 +422,7 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
c.print()
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
def do_list(source_filter: str = "all", console: Console | None = None) -> None:
"""List installed skills, distinguishing builtins from hub-installed."""
from tools.skills_hub import HubLockFile, ensure_hub_dirs
from tools.skills_tool import _find_all_skills
@@ -446,14 +462,13 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No
table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")
c.print(table)
c.print(f"[dim]{len(hub_installed)} hub-installed, "
f"{len(all_skills) - len(hub_installed)} builtin[/]\n")
c.print(f"[dim]{len(hub_installed)} hub-installed, {len(all_skills) - len(hub_installed)} builtin[/]\n")
def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> None:
def do_audit(name: str | None = None, console: Console | None = None) -> None:
"""Re-run security scan on installed hub skills."""
from tools.skills_hub import HubLockFile, SKILLS_DIR
from tools.skills_guard import scan_skill, format_scan_report
from tools.skills_guard import format_scan_report, scan_skill
from tools.skills_hub import SKILLS_DIR, HubLockFile
c = console or _console
lock = HubLockFile()
@@ -483,7 +498,7 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N
c.print()
def do_uninstall(name: str, console: Optional[Console] = None) -> None:
def do_uninstall(name: str, console: Console | None = None) -> None:
"""Remove a hub-installed skill with confirmation."""
from tools.skills_hub import uninstall_skill
@@ -505,7 +520,7 @@ def do_uninstall(name: str, console: Optional[Console] = None) -> None:
c.print(f"[bold red]Error:[/] {msg}\n")
def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> None:
def do_tap(action: str, repo: str = "", console: Console | None = None) -> None:
"""Manage taps (custom GitHub repo sources)."""
from tools.skills_hub import TapsManager
@@ -547,11 +562,10 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No
c.print(f"[bold red]Unknown tap action:[/] {action}. Use: list, add, remove\n")
def do_publish(skill_path: str, target: str = "github", repo: str = "",
console: Optional[Console] = None) -> None:
def do_publish(skill_path: str, target: str = "github", repo: str = "", console: Console | None = None) -> None:
"""Publish a local skill to a registry (GitHub PR or ClawHub submission)."""
from tools.skills_hub import GitHubAuth, SKILLS_DIR
from tools.skills_guard import scan_skill, format_scan_report
from tools.skills_guard import format_scan_report, scan_skill
from tools.skills_hub import SKILLS_DIR, GitHubAuth
c = console or _console
path = Path(skill_path)
@@ -565,14 +579,16 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
# Validate the skill
import yaml
skill_md = (path / "SKILL.md").read_text(encoding="utf-8")
fm = {}
if skill_md.startswith("---"):
import re
match = re.search(r'\n---\s*\n', skill_md[3:])
match = re.search(r"\n---\s*\n", skill_md[3:])
if match:
try:
fm = yaml.safe_load(skill_md[3:match.start() + 3]) or {}
fm = yaml.safe_load(skill_md[3 : match.start() + 3]) or {}
except yaml.YAMLError:
pass
@@ -592,14 +608,18 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
if target == "github":
if not repo:
c.print("[bold red]Error:[/] --repo required for GitHub publish.\n"
"Usage: hermes skills publish <path> --to github --repo owner/repo\n")
c.print(
"[bold red]Error:[/] --repo required for GitHub publish.\n"
"Usage: hermes skills publish <path> --to github --repo owner/repo\n"
)
return
auth = GitHubAuth()
if not auth.is_authenticated():
c.print("[bold red]Error:[/] GitHub authentication required.\n"
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n")
c.print(
"[bold red]Error:[/] GitHub authentication required.\n"
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n"
)
return
c.print(f"[bold]Publishing '{name}' to {repo}...[/]")
@@ -610,14 +630,12 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
c.print(f"[bold red]Error:[/] {msg}\n")
elif target == "clawhub":
c.print("[yellow]ClawHub publishing is not yet supported. "
"Submit manually at https://clawhub.ai/submit[/]\n")
c.print("[yellow]ClawHub publishing is not yet supported. Submit manually at https://clawhub.ai/submit[/]\n")
else:
c.print(f"[bold red]Unknown target:[/] {target}. Use 'github' or 'clawhub'.\n")
def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
auth) -> tuple:
def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -> tuple:
"""Create a PR to a GitHub repo with the skill. Returns (success, message)."""
import httpx
@@ -627,7 +645,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
try:
resp = httpx.post(
f"https://api.github.com/repos/{target_repo}/forks",
headers=headers, timeout=30,
headers=headers,
timeout=30,
)
if resp.status_code in (200, 202):
fork = resp.json()
@@ -643,7 +662,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
try:
resp = httpx.get(
f"https://api.github.com/repos/{target_repo}",
headers=headers, timeout=15,
headers=headers,
timeout=15,
)
default_branch = resp.json().get("default_branch", "main")
except Exception:
@@ -653,7 +673,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
try:
resp = httpx.get(
f"https://api.github.com/repos/{fork_repo}/git/refs/heads/{default_branch}",
headers=headers, timeout=15,
headers=headers,
timeout=15,
)
base_sha = resp.json()["object"]["sha"]
except Exception as e:
@@ -664,7 +685,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
try:
httpx.post(
f"https://api.github.com/repos/{fork_repo}/git/refs",
headers=headers, timeout=15,
headers=headers,
timeout=15,
json={"ref": f"refs/heads/{branch_name}", "sha": base_sha},
)
except Exception as e:
@@ -678,10 +700,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
upload_path = f"skills/{skill_name}/{rel}"
try:
import base64
content_b64 = base64.b64encode(f.read_bytes()).decode()
httpx.put(
f"https://api.github.com/repos/{fork_repo}/contents/{upload_path}",
headers=headers, timeout=15,
headers=headers,
timeout=15,
json={
"message": f"Add {skill_name} skill: {rel}",
"content": content_b64,
@@ -695,11 +719,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
try:
resp = httpx.post(
f"https://api.github.com/repos/{target_repo}/pulls",
headers=headers, timeout=15,
headers=headers,
timeout=15,
json={
"title": f"Add skill: {skill_name}",
"body": f"Submitting the `{skill_name}` skill via Hermes Skills Hub.\n\n"
f"This skill was scanned by the Hermes Skills Guard before submission.",
f"This skill was scanned by the Hermes Skills Guard before submission.",
"head": f"{fork_repo.split('/')[0]}:{branch_name}",
"base": default_branch,
},
@@ -713,7 +738,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
return False, f"Network error creating PR: {e}"
def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> None:
def do_snapshot_export(output_path: str, console: Console | None = None) -> None:
"""Export current hub skill configuration to a portable JSON file."""
from tools.skills_hub import HubLockFile, TapsManager
@@ -726,16 +751,15 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N
snapshot = {
"hermes_version": "0.1.0",
"exported_at": __import__("datetime").datetime.now(
__import__("datetime").timezone.utc
).isoformat(),
"exported_at": __import__("datetime").datetime.now(__import__("datetime").timezone.utc).isoformat(),
"skills": [
{
"name": entry["name"],
"source": entry.get("source", ""),
"identifier": entry.get("identifier", ""),
"category": str(Path(entry.get("install_path", "")).parent)
if "/" in entry.get("install_path", "") else "",
if "/" in entry.get("install_path", "")
else "",
}
for entry in installed
],
@@ -748,8 +772,7 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N
c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n")
def do_snapshot_import(input_path: str, force: bool = False,
console: Optional[Console] = None) -> None:
def do_snapshot_import(input_path: str, force: bool = False, console: Console | None = None) -> None:
"""Re-install skills from a snapshot file."""
from tools.skills_hub import TapsManager
@@ -799,6 +822,7 @@ def do_snapshot_import(input_path: str, force: bool = False,
# CLI argparse entry point
# ---------------------------------------------------------------------------
def skills_command(args) -> None:
"""Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
action = getattr(args, "skills_action", None)
@@ -839,7 +863,9 @@ def skills_command(args) -> None:
return
do_tap(tap_action, repo=repo)
else:
_console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
_console.print(
"Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n"
)
_console.print("Run 'hermes skills <command> --help' for details.\n")
@@ -847,7 +873,8 @@ def skills_command(args) -> None:
# Slash command entry point (/skills in chat)
# ---------------------------------------------------------------------------
def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
def handle_skills_slash(cmd: str, console: Console | None = None) -> None:
"""
Parse and dispatch `/skills <subcommand> [args]` from the chat interface.
@@ -1008,17 +1035,19 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
def _print_skills_help(console: Console) -> None:
"""Print help for the /skills slash command."""
console.print(Panel(
"[bold]Skills Hub Commands:[/]\n\n"
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
" [cyan]search[/] <query> Search registries for skills\n"
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
" [cyan]tap[/] list|add|remove Manage skill sources\n",
title="/skills",
))
console.print(
Panel(
"[bold]Skills Hub Commands:[/]\n\n"
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
" [cyan]search[/] <query> Search registries for skills\n"
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
" [cyan]tap[/] list|add|remove Manage skill sources\n",
title="/skills",
)
)

View File

@@ -5,21 +5,25 @@ Shows the status of all Hermes Agent components.
"""
import os
import sys
import subprocess
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
from datetime import UTC
from hermes_cli.colors import Colors, color
from hermes_cli.config import get_env_path, get_env_value
from hermes_constants import OPENROUTER_MODELS_URL
def check_mark(ok: bool) -> str:
if ok:
return color("", Colors.GREEN)
return color("", Colors.RED)
def redact_key(key: str) -> str:
"""Redact an API key for display."""
if not key:
@@ -33,7 +37,8 @@ def _format_iso_timestamp(value) -> str:
"""Format ISO timestamps for status output, converting to local timezone."""
if not value or not isinstance(value, str):
return "(unknown)"
from datetime import datetime, timezone
from datetime import datetime
text = value.strip()
if not text:
return "(unknown)"
@@ -42,7 +47,7 @@ def _format_iso_timestamp(value) -> str:
try:
parsed = datetime.fromisoformat(text)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
parsed = parsed.replace(tzinfo=UTC)
except Exception:
return value
return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
@@ -50,14 +55,14 @@ def _format_iso_timestamp(value) -> str:
def show_status(args):
"""Show status of all Hermes Agent components."""
show_all = getattr(args, 'all', False)
deep = getattr(args, 'deep', False)
show_all = getattr(args, "all", False)
deep = getattr(args, "deep", False)
print()
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
print(color("│ ⚕ Hermes Agent Status │", Colors.CYAN))
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
# =========================================================================
# Environment
# =========================================================================
@@ -65,19 +70,19 @@ def show_status(args):
print(color("◆ Environment", Colors.CYAN, Colors.BOLD))
print(f" Project: {PROJECT_ROOT}")
print(f" Python: {sys.version.split()[0]}")
env_path = get_env_path()
print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}")
# =========================================================================
# API Keys
# =========================================================================
print()
print(color("◆ API Keys", Colors.CYAN, Colors.BOLD))
keys = {
"OpenRouter": "OPENROUTER_API_KEY",
"Anthropic": "ANTHROPIC_API_KEY",
"Anthropic": "ANTHROPIC_API_KEY",
"OpenAI": "OPENAI_API_KEY",
"Z.AI/GLM": "GLM_API_KEY",
"Kimi": "KIMI_API_KEY",
@@ -91,7 +96,7 @@ def show_status(args):
"ElevenLabs": "ELEVENLABS_API_KEY",
"GitHub": "GITHUB_TOKEN",
}
for name, env_var in keys.items():
value = get_env_value(env_var) or ""
has_key = bool(value)
@@ -105,7 +110,8 @@ def show_status(args):
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
try:
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
nous_status = get_nous_auth_status()
codex_status = get_codex_auth_status()
except Exception:
@@ -148,10 +154,10 @@ def show_status(args):
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
apikey_providers = {
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
"Kimi / Moonshot": ("KIMI_API_KEY",),
"MiniMax": ("MINIMAX_API_KEY",),
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
"Kimi / Moonshot": ("KIMI_API_KEY",),
"MiniMax": ("MINIMAX_API_KEY",),
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
}
for pname, env_vars in apikey_providers.items():
key_val = ""
@@ -168,19 +174,20 @@ def show_status(args):
# =========================================================================
print()
print(color("◆ Terminal Backend", Colors.CYAN, Colors.BOLD))
terminal_env = os.getenv("TERMINAL_ENV", "")
if not terminal_env:
# Fall back to config file value when env var isn't set
# (hermes status doesn't go through cli.py's config loading)
try:
from hermes_cli.config import load_config
_cfg = load_config()
terminal_env = _cfg.get("terminal", {}).get("backend", "local")
except Exception:
terminal_env = "local"
print(f" Backend: {terminal_env}")
if terminal_env == "ssh":
ssh_host = os.getenv("TERMINAL_SSH_HOST", "")
ssh_user = os.getenv("TERMINAL_SSH_USER", "")
@@ -192,16 +199,16 @@ def show_status(args):
elif terminal_env == "daytona":
daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
print(f" Daytona Image: {daytona_image}")
sudo_password = os.getenv("SUDO_PASSWORD", "")
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")
# =========================================================================
# Messaging Platforms
# =========================================================================
print()
print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD))
platforms = {
"Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"),
"Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"),
@@ -209,59 +216,52 @@ def show_status(args):
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
"Slack": ("SLACK_BOT_TOKEN", None),
}
for name, (token_var, home_var) in platforms.items():
token = os.getenv(token_var, "")
has_token = bool(token)
home_channel = ""
if home_var:
home_channel = os.getenv(home_var, "")
status = "configured" if has_token else "not configured"
if home_channel:
status += f" (home: {home_channel})"
print(f" {name:<12} {check_mark(has_token)} {status}")
# =========================================================================
# Gateway Status
# =========================================================================
print()
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
if sys.platform.startswith('linux'):
result = subprocess.run(
["systemctl", "--user", "is-active", "hermes-gateway"],
capture_output=True,
text=True
)
if sys.platform.startswith("linux"):
result = subprocess.run(["systemctl", "--user", "is-active", "hermes-gateway"], capture_output=True, text=True)
is_active = result.stdout.strip() == "active"
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
print(f" Manager: systemd (user)")
elif sys.platform == 'darwin':
result = subprocess.run(
["launchctl", "list", "ai.hermes.gateway"],
capture_output=True,
text=True
)
print(" Manager: systemd (user)")
elif sys.platform == "darwin":
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
is_loaded = result.returncode == 0
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
print(f" Manager: launchd")
print(" Manager: launchd")
else:
print(f" Status: {color('N/A', Colors.DIM)}")
print(f" Manager: (not supported on this platform)")
print(" Manager: (not supported on this platform)")
# =========================================================================
# Cron Jobs
# =========================================================================
print()
print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD))
jobs_file = Path.home() / ".hermes" / "cron" / "jobs.json"
if jobs_file.exists():
import json
try:
with open(jobs_file) as f:
data = json.load(f)
@@ -269,56 +269,57 @@ def show_status(args):
enabled_jobs = [j for j in jobs if j.get("enabled", True)]
print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total")
except Exception:
print(f" Jobs: (error reading jobs file)")
print(" Jobs: (error reading jobs file)")
else:
print(f" Jobs: 0")
print(" Jobs: 0")
# =========================================================================
# Sessions
# =========================================================================
print()
print(color("◆ Sessions", Colors.CYAN, Colors.BOLD))
sessions_file = Path.home() / ".hermes" / "sessions" / "sessions.json"
if sessions_file.exists():
import json
try:
with open(sessions_file) as f:
data = json.load(f)
print(f" Active: {len(data)} session(s)")
except Exception:
print(f" Active: (error reading sessions file)")
print(" Active: (error reading sessions file)")
else:
print(f" Active: 0")
print(" Active: 0")
# =========================================================================
# Deep checks
# =========================================================================
if deep:
print()
print(color("◆ Deep Checks", Colors.CYAN, Colors.BOLD))
# Check OpenRouter connectivity
openrouter_key = os.getenv("OPENROUTER_API_KEY", "")
if openrouter_key:
try:
import httpx
response = httpx.get(
OPENROUTER_MODELS_URL,
headers={"Authorization": f"Bearer {openrouter_key}"},
timeout=10
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
)
ok = response.status_code == 200
print(f" OpenRouter: {check_mark(ok)} {'reachable' if ok else f'error ({response.status_code})'}")
except Exception as e:
print(f" OpenRouter: {check_mark(False)} error: {e}")
# Check gateway port
try:
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
result = sock.connect_ex(('127.0.0.1', 18789))
result = sock.connect_ex(("127.0.0.1", 18789))
sock.close()
# Port in use = gateway likely running
port_in_use = result == 0
@@ -326,7 +327,7 @@ def show_status(args):
print(f" Port 18789: {'in use' if port_in_use else 'available'}")
except OSError:
pass
print()
print(color("" * 60, Colors.DIM))
print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM))

View File

@@ -11,33 +11,37 @@ the `platform_toolsets` key.
import sys
from pathlib import Path
from typing import Dict, List, Set
import os
from hermes_cli.config import (
load_config, save_config, get_env_value, save_env_value,
get_hermes_home,
)
from hermes_cli.colors import Colors, color
from hermes_cli.config import (
get_env_value,
load_config,
save_config,
save_env_value,
)
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
def _print_info(text: str):
print(color(f" {text}", Colors.DIM))
def _print_success(text: str):
print(color(f"{text}", Colors.GREEN))
def _print_warning(text: str):
print(color(f"{text}", Colors.YELLOW))
def _print_error(text: str):
print(color(f"{text}", Colors.RED))
def _prompt(question: str, default: str = None, password: bool = False) -> str:
if default:
display = f"{question} [{default}]: "
@@ -46,6 +50,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
try:
if password:
import getpass
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
@@ -54,6 +59,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
print()
return default or ""
def _prompt_yes_no(question: str, default: bool = True) -> bool:
default_str = "Y/n" if default else "y/N"
while True:
@@ -64,9 +70,9 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
return default
if not value:
return default
if value in ('y', 'yes'):
if value in ("y", "yes"):
return True
if value in ('n', 'no'):
if value in ("n", "no"):
return False
@@ -76,24 +82,24 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
# Each entry: (toolset_name, label, description)
# These map to keys in toolsets.py TOOLSETS dict.
CONFIGURABLE_TOOLSETS = [
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
("terminal", "💻 Terminal & Processes", "terminal, process"),
("file", "📁 File Operations", "read, write, patch, search"),
("code_execution", "⚡ Code Execution", "execute_code"),
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
("image_gen", "🎨 Image Generation", "image_generate"),
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
("tts", "🔊 Text-to-Speech", "text_to_speech"),
("skills", "📚 Skills", "list, view, manage"),
("todo", "📋 Task Planning", "todo"),
("memory", "💾 Memory", "persistent memory across sessions"),
("session_search", "🔎 Session Search", "search past conversations"),
("clarify", "❓ Clarifying Questions", "clarify"),
("delegation", "👥 Task Delegation", "delegate_task"),
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
("homeassistant", "🏠 Home Assistant", "smart home device control"),
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
("terminal", "💻 Terminal & Processes", "terminal, process"),
("file", "📁 File Operations", "read, write, patch, search"),
("code_execution", "⚡ Code Execution", "execute_code"),
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
("image_gen", "🎨 Image Generation", "image_generate"),
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
("tts", "🔊 Text-to-Speech", "text_to_speech"),
("skills", "📚 Skills", "list, view, manage"),
("todo", "📋 Task Planning", "todo"),
("memory", "💾 Memory", "persistent memory across sessions"),
("session_search", "🔎 Session Search", "search past conversations"),
("clarify", "❓ Clarifying Questions", "clarify"),
("delegation", "👥 Task Delegation", "delegate_task"),
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
("homeassistant", "🏠 Home Assistant", "smart home device control"),
]
# Toolsets that are OFF by default for new installs.
@@ -103,11 +109,11 @@ _DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
# Platform display config
PLATFORMS = {
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
}
@@ -131,7 +137,11 @@ TOOL_CATEGORIES = {
"name": "OpenAI TTS",
"tag": "Premium - high quality voices",
"env_vars": [
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
{
"key": "VOICE_TOOLS_OPENAI_KEY",
"prompt": "OpenAI API key",
"url": "https://platform.openai.com/api-keys",
},
],
"tts_provider": "openai",
},
@@ -139,7 +149,11 @@ TOOL_CATEGORIES = {
"name": "ElevenLabs",
"tag": "Premium - most natural voices",
"env_vars": [
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
{
"key": "ELEVENLABS_API_KEY",
"prompt": "ElevenLabs API key",
"url": "https://elevenlabs.io/app/settings/api-keys",
},
],
"tts_provider": "elevenlabs",
},
@@ -224,7 +238,11 @@ TOOL_CATEGORIES = {
"name": "Tinker / Atropos",
"tag": "RL training platform",
"env_vars": [
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
{
"key": "TINKER_API_KEY",
"prompt": "Tinker API key",
"url": "https://tinker-console.thinkingmachines.ai/keys",
},
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
],
"post_setup": "rl_training",
@@ -236,24 +254,26 @@ TOOL_CATEGORIES = {
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
# Used as a fallback for tools like vision/moa that just need an API key.
TOOLSET_ENV_REQUIREMENTS = {
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
}
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
def _run_post_setup(post_setup_key: str):
"""Run post-setup hooks for tools that need extra installation steps."""
import shutil
if post_setup_key == "browserbase":
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
if not node_modules.exists() and shutil.which("npm"):
_print_info(" Installing Node.js dependencies for browser tools...")
import subprocess
result = subprocess.run(
["npm", "install", "--silent"],
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
["npm", "install", "--silent"], capture_output=True, text=True, cwd=str(PROJECT_ROOT)
)
if result.returncode == 0:
_print_success(" Node.js dependencies installed")
@@ -270,16 +290,17 @@ def _run_post_setup(post_setup_key: str):
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
_print_info(" Installing tinker-atropos submodule...")
import subprocess
uv_bin = shutil.which("uv")
if uv_bin:
result = subprocess.run(
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
capture_output=True, text=True
capture_output=True,
text=True,
)
else:
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
capture_output=True, text=True
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)], capture_output=True, text=True
)
if result.returncode == 0:
_print_success(" tinker-atropos installed")
@@ -294,7 +315,8 @@ def _run_post_setup(post_setup_key: str):
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
def _get_enabled_platforms() -> List[str]:
def _get_enabled_platforms() -> list[str]:
"""Return platform keys that are configured (have tokens or are CLI)."""
enabled = ["cli"]
if get_env_value("TELEGRAM_BOT_TOKEN"):
@@ -308,9 +330,9 @@ def _get_enabled_platforms() -> List[str]:
return enabled
def _get_platform_tools(config: dict, platform: str) -> Set[str]:
def _get_platform_tools(config: dict, platform: str) -> set[str]:
"""Resolve which individual toolset names are enabled for a platform."""
from toolsets import resolve_toolset, TOOLSETS
from toolsets import resolve_toolset
platform_toolsets = config.get("platform_toolsets", {})
toolset_names = platform_toolsets.get(platform)
@@ -335,7 +357,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
return enabled_toolsets
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: set[str]):
"""Save the selected toolset keys for a platform to config."""
config.setdefault("platform_toolsets", {})
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys)
@@ -364,6 +386,7 @@ def _toolset_has_keys(ts_key: str) -> bool:
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
rendering bugs in tmux, iTerm, and other non-standard terminals."""
@@ -371,6 +394,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
# Curses-based single-select — works in tmux, iTerm, and standard terminals
try:
import curses
result_holder = [default]
def _curses_menu(stdscr):
@@ -386,8 +410,9 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
try:
stdscr.addnstr(0, 0, question, max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
stdscr.addnstr(
0, 0, question, max_x - 1, curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
)
except curses.error:
pass
@@ -410,14 +435,14 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(choices)
elif key in (curses.KEY_DOWN, ord('j')):
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(choices)
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = cursor
return
elif key in (27, ord('q')):
elif key in (27, ord("q")):
return
curses.wrapper(_curses_menu)
@@ -431,7 +456,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
for i, c in enumerate(choices):
marker = "" if i == default else ""
style = Colors.GREEN if i == default else ""
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
print(color(f" {marker} {i + 1}. {c}", style) if style else f" {marker} {i + 1}. {c}")
while True:
try:
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
@@ -445,7 +470,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
return default
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str]:
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
labels = []
@@ -455,15 +480,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
suffix = " [no API key]"
labels.append(f"{ts_label} ({ts_desc}){suffix}")
pre_selected_indices = [
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
if ts_key in enabled
]
pre_selected_indices = [i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS) if ts_key in enabled]
# Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
# simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
try:
import curses
selected = set(pre_selected_indices)
result_holder = [None]
@@ -483,7 +506,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
max_y, max_x = stdscr.getmaxyx()
header = f"Tools for {platform_label} — ↑↓ navigate, SPACE toggle, ENTER confirm"
try:
stdscr.addnstr(0, 0, header, max_x - 1, curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD)
stdscr.addnstr(
0,
0,
header,
max_x - 1,
curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD,
)
except curses.error:
pass
@@ -514,11 +543,11 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(labels)
elif key in (curses.KEY_DOWN, ord('j')):
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(labels)
elif key == ord(' '):
elif key == ord(" "):
if cursor in selected:
selected.discard(cursor)
else:
@@ -526,7 +555,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
return
elif key in (27, ord('q')): # ESC or q
elif key in (27, ord("q")): # ESC or q
result_holder[0] = enabled
return
@@ -565,9 +594,10 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
# ─── Provider-Aware Configuration ────────────────────────────────────────────
def _configure_toolset(ts_key: str, config: dict):
"""Configure a toolset - provider selection + API keys.
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
env var prompts for toolsets not in TOOL_CATEGORIES.
"""
@@ -591,7 +621,9 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
req = cat["requires_python"]
if sys.version_info < req:
print()
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
_print_error(
f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})"
)
_print_info(" Upgrade Python and reinstall to enable this tool.")
return
@@ -610,7 +642,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
# Multiple providers - let user choose
print()
# Use custom title if provided (e.g. "Select Search Provider")
title = cat.get("setup_title", f"Choose a provider")
title = cat.get("setup_title", "Choose a provider")
print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN))
if cat.get("setup_note"):
_print_info(f" {cat['setup_note']}")
@@ -626,7 +658,11 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
configured = " [active]"
elif not env_vars:
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
configured = (
" [active]"
if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "")
else ""
)
else:
configured = " [configured]"
provider_choices.append(f"{p['name']}{tag}{configured}")
@@ -688,9 +724,9 @@ def _configure_provider(provider: dict, config: dict):
if value:
save_env_value(var["key"], value)
_print_success(f" Saved")
_print_success(" Saved")
else:
_print_warning(f" Skipped")
_print_warning(" Skipped")
all_configured = False
# Run post-setup hooks if needed
@@ -721,9 +757,9 @@ def _configure_simple_requirements(ts_key: str):
value = _prompt(f" {var}", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Saved")
_print_success(" Saved")
else:
_print_warning(f" Skipped")
_print_warning(" Skipped")
def _reconfigure_tool(config: dict):
@@ -827,9 +863,9 @@ def _reconfigure_provider(provider: dict, config: dict):
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
if value and value.strip():
save_env_value(var["key"], value.strip())
_print_success(f" Updated")
_print_success(" Updated")
else:
_print_info(f" Kept current")
_print_info(" Kept current")
def _reconfigure_simple_requirements(ts_key: str):
@@ -851,13 +887,14 @@ def _reconfigure_simple_requirements(ts_key: str):
value = _prompt(f" {var} (Enter to keep current)", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Updated")
_print_success(" Updated")
else:
_print_info(f" Kept current")
_print_info(" Kept current")
# ─── Main Entry Point ─────────────────────────────────────────────────────────
def tools_command(args=None, first_install: bool = False, config: dict = None):
"""Entry point for `hermes tools` and `hermes setup tools`.
@@ -907,7 +944,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
# TTS (Edge vs OpenAI vs ElevenLabs), etc. are shown even when
# a free provider exists.
to_configure = [
ts_key for ts_key in sorted(new_enabled)
ts_key
for ts_key in sorted(new_enabled)
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)
]
@@ -981,7 +1019,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
# Configure newly enabled toolsets that need API keys
for ts_key in sorted(added):
if (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
if not _toolset_has_keys(ts_key):
_configure_toolset(ts_key, config)

View File

@@ -7,23 +7,25 @@ Provides options for:
"""
import os
import sys
import shutil
import subprocess
from pathlib import Path
from typing import Optional
from hermes_cli.colors import Colors, color
def log_info(msg: str):
print(f"{color('', Colors.CYAN)} {msg}")
def log_success(msg: str):
print(f"{color('', Colors.GREEN)} {msg}")
def log_warn(msg: str):
print(f"{color('', Colors.YELLOW)} {msg}")
def log_error(msg: str):
print(f"{color('', Colors.RED)} {msg}")
@@ -42,7 +44,7 @@ def find_shell_configs() -> list:
"""Find shell configuration files that might have PATH entries."""
home = Path.home()
configs = []
candidates = [
home / ".bashrc",
home / ".bash_profile",
@@ -50,11 +52,11 @@ def find_shell_configs() -> list:
home / ".zshrc",
home / ".zprofile",
]
for config in candidates:
if config.exists():
configs.append(config)
return configs
@@ -62,45 +64,45 @@ def remove_path_from_shell_configs():
"""Remove Hermes PATH entries from shell configuration files."""
configs = find_shell_configs()
removed_from = []
for config_path in configs:
try:
content = config_path.read_text()
original_content = content
# Remove lines containing hermes-agent or hermes PATH entries
new_lines = []
skip_next = False
for line in content.split('\n'):
for line in content.split("\n"):
# Skip the "# Hermes Agent" comment and following line
if '# Hermes Agent' in line or '# hermes-agent' in line:
if "# Hermes Agent" in line or "# hermes-agent" in line:
skip_next = True
continue
if skip_next and ('hermes' in line.lower() and 'PATH' in line):
if skip_next and ("hermes" in line.lower() and "PATH" in line):
skip_next = False
continue
skip_next = False
# Remove any PATH line containing hermes
if 'hermes' in line.lower() and ('PATH=' in line or 'path=' in line.lower()):
if "hermes" in line.lower() and ("PATH=" in line or "path=" in line.lower()):
continue
new_lines.append(line)
new_content = '\n'.join(new_lines)
new_content = "\n".join(new_lines)
# Clean up multiple blank lines
while '\n\n\n' in new_content:
new_content = new_content.replace('\n\n\n', '\n\n')
while "\n\n\n" in new_content:
new_content = new_content.replace("\n\n\n", "\n\n")
if new_content != original_content:
config_path.write_text(new_content)
removed_from.append(config_path)
except Exception as e:
log_warn(f"Could not update {config_path}: {e}")
return removed_from
@@ -110,61 +112,49 @@ def remove_wrapper_script():
Path.home() / ".local" / "bin" / "hermes",
Path("/usr/local/bin/hermes"),
]
removed = []
for wrapper in wrapper_paths:
if wrapper.exists():
try:
# Check if it's our wrapper (contains hermes_cli reference)
content = wrapper.read_text()
if 'hermes_cli' in content or 'hermes-agent' in content:
if "hermes_cli" in content or "hermes-agent" in content:
wrapper.unlink()
removed.append(wrapper)
except Exception as e:
log_warn(f"Could not remove {wrapper}: {e}")
return removed
def uninstall_gateway_service():
"""Stop and uninstall the gateway service if running."""
import platform
if platform.system() != "Linux":
return False
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
if not service_file.exists():
return False
try:
# Stop the service
subprocess.run(
["systemctl", "--user", "stop", "hermes-gateway"],
capture_output=True,
check=False
)
subprocess.run(["systemctl", "--user", "stop", "hermes-gateway"], capture_output=True, check=False)
# Disable the service
subprocess.run(
["systemctl", "--user", "disable", "hermes-gateway"],
capture_output=True,
check=False
)
subprocess.run(["systemctl", "--user", "disable", "hermes-gateway"], capture_output=True, check=False)
# Remove service file
service_file.unlink()
# Reload systemd
subprocess.run(
["systemctl", "--user", "daemon-reload"],
capture_output=True,
check=False
)
subprocess.run(["systemctl", "--user", "daemon-reload"], capture_output=True, check=False)
return True
except Exception as e:
log_warn(f"Could not fully remove gateway service: {e}")
return False
@@ -173,20 +163,20 @@ def uninstall_gateway_service():
def run_uninstall(args):
"""
Run the uninstall process.
Options:
- Full uninstall: removes code + ~/.hermes/ (configs, data, logs)
- Keep data: removes code but keeps ~/.hermes/ for future reinstall
"""
project_root = get_project_root()
hermes_home = get_hermes_home()
print()
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA, Colors.BOLD))
print(color("│ ⚕ Hermes Agent Uninstaller │", Colors.MAGENTA, Colors.BOLD))
print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA, Colors.BOLD))
print()
# Show what will be affected
print(color("Current Installation:", Colors.CYAN, Colors.BOLD))
print(f" Code: {project_root}")
@@ -194,7 +184,7 @@ def run_uninstall(args):
print(f" Secrets: {hermes_home / '.env'}")
print(f" Data: {hermes_home / 'cron/'}, {hermes_home / 'sessions/'}, {hermes_home / 'logs/'}")
print()
# Ask for confirmation
print(color("Uninstall Options:", Colors.YELLOW, Colors.BOLD))
print()
@@ -206,21 +196,21 @@ def run_uninstall(args):
print()
print(" 3) " + color("Cancel", Colors.CYAN) + " - Don't uninstall")
print()
try:
choice = input(color("Select option [1/2/3]: ", Colors.BOLD)).strip()
except (KeyboardInterrupt, EOFError):
print()
print("Cancelled.")
return
if choice == "3" or choice.lower() in ("c", "cancel", "q", "quit", "n", "no"):
print()
print("Uninstall cancelled.")
return
full_uninstall = (choice == "2")
full_uninstall = choice == "2"
# Final confirmation
print()
if full_uninstall:
@@ -228,7 +218,7 @@ def run_uninstall(args):
print(color(" Including: configs, API keys, sessions, scheduled jobs, logs", Colors.RED))
else:
print("This will remove the Hermes code but keep your configuration and data.")
print()
try:
confirm = input(f"Type '{color('yes', Colors.YELLOW)}' to confirm: ").strip().lower()
@@ -236,23 +226,23 @@ def run_uninstall(args):
print()
print("Cancelled.")
return
if confirm != "yes":
print()
print("Uninstall cancelled.")
return
print()
print(color("Uninstalling...", Colors.CYAN, Colors.BOLD))
print()
# 1. Stop and uninstall gateway service
log_info("Checking for gateway service...")
if uninstall_gateway_service():
log_success("Gateway service stopped and removed")
else:
log_info("No gateway service found")
# 2. Remove PATH entries from shell configs
log_info("Removing PATH entries from shell configs...")
removed_configs = remove_path_from_shell_configs()
@@ -261,7 +251,7 @@ def run_uninstall(args):
log_success(f"Updated {config}")
else:
log_info("No PATH entries found to remove")
# 3. Remove wrapper script
log_info("Removing hermes command...")
removed_wrappers = remove_wrapper_script()
@@ -270,10 +260,10 @@ def run_uninstall(args):
log_success(f"Removed {wrapper}")
else:
log_info("No wrapper script found")
# 4. Remove installation directory (code)
log_info(f"Removing installation directory...")
log_info("Removing installation directory...")
# Check if we're running from within the install dir
# We need to be careful here
try:
@@ -289,7 +279,7 @@ def run_uninstall(args):
except Exception as e:
log_warn(f"Could not fully remove {project_root}: {e}")
log_info("You may need to manually remove it")
# 5. Optionally remove ~/.hermes/ data directory
if full_uninstall:
log_info("Removing configuration and data...")
@@ -302,22 +292,27 @@ def run_uninstall(args):
log_info("You may need to manually remove it")
else:
log_info(f"Keeping configuration and data in {hermes_home}")
# Done
print()
print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN, Colors.BOLD))
print(color("│ ✓ Uninstall Complete! │", Colors.GREEN, Colors.BOLD))
print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN, Colors.BOLD))
print()
if not full_uninstall:
print(color("Your configuration and data have been preserved:", Colors.CYAN))
print(f" {hermes_home}/")
print()
print("To reinstall later with your existing settings:")
print(color(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", Colors.DIM))
print(
color(
" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash",
Colors.DIM,
)
)
print()
print(color("Reload your shell to complete the process:", Colors.YELLOW))
print(" source ~/.bashrc # or ~/.zshrc")
print()

View File

@@ -19,8 +19,7 @@ import os
import sqlite3
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
from typing import Any
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
@@ -156,8 +155,7 @@ class SessionDB:
# since the title column is guaranteed to exist at this point)
try:
cursor.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
"ON sessions(title) WHERE title IS NOT NULL"
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique ON sessions(title) WHERE title IS NOT NULL"
)
except sqlite3.OperationalError:
pass # Index already exists
@@ -185,7 +183,7 @@ class SessionDB:
session_id: str,
source: str,
model: str = None,
model_config: Dict[str, Any] = None,
model_config: dict[str, Any] = None,
system_prompt: str = None,
user_id: str = None,
parent_session_id: str = None,
@@ -225,9 +223,7 @@ class SessionDB:
)
self._conn.commit()
def update_token_counts(
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0
) -> None:
def update_token_counts(self, session_id: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
"""Increment token counters on a session."""
self._conn.execute(
"""UPDATE sessions SET
@@ -238,11 +234,9 @@ class SessionDB:
)
self._conn.commit()
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
def get_session(self, session_id: str) -> dict[str, Any] | None:
"""Get a session by ID."""
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE id = ?", (session_id,)
)
cursor = self._conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
row = cursor.fetchone()
return dict(row) if row else None
@@ -250,7 +244,7 @@ class SessionDB:
MAX_TITLE_LENGTH = 100
@staticmethod
def sanitize_title(title: Optional[str]) -> Optional[str]:
def sanitize_title(title: str | None) -> str | None:
"""Validate and sanitize a session title.
- Strips leading/trailing whitespace
@@ -271,27 +265,26 @@ class SessionDB:
# Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
# whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
# normalized to spaces by the whitespace collapsing step below
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", title)
# Remove problematic Unicode control characters:
# - Zero-width chars (U+200B-U+200F, U+FEFF)
# - Directional overrides (U+202A-U+202E, U+2066-U+2069)
# - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
cleaned = re.sub(
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
'', cleaned,
r"[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]",
"",
cleaned,
)
# Collapse internal whitespace runs and strip
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
cleaned = re.sub(r"\s+", " ", cleaned).strip()
if not cleaned:
return None
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
raise ValueError(
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
)
raise ValueError(f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})")
return cleaned
@@ -312,9 +305,7 @@ class SessionDB:
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
raise ValueError(f"Title '{title}' is already in use by session {conflict['id']}")
cursor = self._conn.execute(
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
@@ -322,23 +313,19 @@ class SessionDB:
self._conn.commit()
return cursor.rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]:
def get_session_title(self, session_id: str) -> str | None:
"""Get the title for a session, or None."""
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,)
)
cursor = self._conn.execute("SELECT title FROM sessions WHERE id = ?", (session_id,))
row = cursor.fetchone()
return row["title"] if row else None
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
def get_session_by_title(self, title: str) -> dict[str, Any] | None:
"""Look up a session by exact title. Returns session dict or None."""
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,)
)
cursor = self._conn.execute("SELECT * FROM sessions WHERE title = ?", (title,))
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_by_title(self, title: str) -> Optional[str]:
def resolve_session_by_title(self, title: str) -> str | None:
"""Resolve a title to a session ID, preferring the latest in a lineage.
If the exact title exists, returns that session's ID.
@@ -353,8 +340,7 @@ class SessionDB:
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
"SELECT id, title, started_at FROM sessions WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
(f"{escaped} #%",),
)
numbered = cursor.fetchall()
@@ -373,8 +359,9 @@ class SessionDB:
the highest existing number and increments.
"""
import re
# Strip existing #N suffix to find the true base
match = re.match(r'^(.*?) #(\d+)$', base_title)
match = re.match(r"^(.*?) #(\d+)$", base_title)
if match:
base = match.group(1)
else:
@@ -395,7 +382,7 @@ class SessionDB:
# Find the highest number
max_num = 1 # The unnumbered original counts as #1
for t in existing:
m = re.match(r'^.* #(\d+)$', t)
m = re.match(r"^.* #(\d+)$", t)
if m:
max_num = max(max_num, int(m.group(1)))
@@ -406,7 +393,7 @@ class SessionDB:
source: str = None,
limit: int = 20,
offset: int = 0,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""List sessions with preview (first user message) and last active timestamp.
Returns dicts with keys: id, source, model, title, started_at, ended_at,
@@ -506,7 +493,7 @@ class SessionDB:
self._conn.commit()
return msg_id
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
def get_messages(self, session_id: str) -> list[dict[str, Any]]:
"""Load all messages for a session, ordered by timestamp."""
cursor = self._conn.execute(
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
@@ -524,7 +511,7 @@ class SessionDB:
result.append(msg)
return result
def get_messages_as_conversation(self, session_id: str) -> List[Dict[str, Any]]:
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
"""
Load messages in the OpenAI conversation format (role + content dicts).
Used by the gateway to restore conversation history.
@@ -556,11 +543,11 @@ class SessionDB:
def search_messages(
self,
query: str,
source_filter: List[str] = None,
role_filter: List[str] = None,
source_filter: list[str] = None,
role_filter: list[str] = None,
limit: int = 20,
offset: int = 0,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Full-text search across session messages using FTS5.
@@ -628,8 +615,7 @@ class SessionDB:
(match["session_id"], match["id"], match["id"]),
)
context_msgs = [
{"role": r["role"], "content": (r["content"] or "")[:200]}
for r in ctx_cursor.fetchall()
{"role": r["role"], "content": (r["content"] or "")[:200]} for r in ctx_cursor.fetchall()
]
match["context"] = context_msgs
except Exception:
@@ -645,7 +631,7 @@ class SessionDB:
source: str = None,
limit: int = 20,
offset: int = 0,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""List sessions, optionally filtered by source."""
if source:
cursor = self._conn.execute(
@@ -666,9 +652,7 @@ class SessionDB:
def session_count(self, source: str = None) -> int:
"""Count sessions, optionally filtered by source."""
if source:
cursor = self._conn.execute(
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
)
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE source = ?", (source,))
else:
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
return cursor.fetchone()[0]
@@ -676,9 +660,7 @@ class SessionDB:
def message_count(self, session_id: str = None) -> int:
"""Count messages, optionally for a specific session."""
if session_id:
cursor = self._conn.execute(
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
)
cursor = self._conn.execute("SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,))
else:
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
return cursor.fetchone()[0]
@@ -687,7 +669,7 @@ class SessionDB:
# Export and cleanup
# =========================================================================
def export_session(self, session_id: str) -> Optional[Dict[str, Any]]:
def export_session(self, session_id: str) -> dict[str, Any] | None:
"""Export a single session with all its messages as a dict."""
session = self.get_session(session_id)
if not session:
@@ -695,7 +677,7 @@ class SessionDB:
messages = self.get_messages(session_id)
return {**session, "messages": messages}
def export_all(self, source: str = None) -> List[Dict[str, Any]]:
def export_all(self, source: str = None) -> list[dict[str, Any]]:
"""
Export all sessions (with messages) as a list of dicts.
Suitable for writing to a JSONL file for backup/analysis.
@@ -709,9 +691,7 @@ class SessionDB:
def clear_messages(self, session_id: str) -> None:
"""Delete all messages for a session and reset its counters."""
self._conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
self._conn.execute(
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
(session_id,),
@@ -720,9 +700,7 @@ class SessionDB:
def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages. Returns True if found."""
cursor = self._conn.execute(
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
)
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,))
if cursor.fetchone()[0] == 0:
return False
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
@@ -736,6 +714,7 @@ class SessionDB:
Only prunes ended sessions (not active ones).
"""
import time as _time
cutoff = _time.time() - (older_than_days * 86400)
if source:

View File

@@ -20,11 +20,10 @@ Public API (signatures preserved from the original 2,400-line version):
check_tool_availability(quiet) -> tuple
"""
import json
import asyncio
import os
import json
import logging
from typing import Dict, Any, List, Optional, Tuple
from typing import Any
from tools.registry import registry
from toolsets import resolve_toolset, validate_toolset
@@ -36,6 +35,7 @@ logger = logging.getLogger(__name__)
# Async Bridging (single source of truth -- used by registry.dispatch too)
# =============================================================================
def _run_async(coro):
"""Run an async coroutine from a sync context.
@@ -56,6 +56,7 @@ def _run_async(coro):
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, coro)
return future.result(timeout=300)
@@ -66,6 +67,7 @@ def _run_async(coro):
# Tool Discovery (importing each module triggers its registry.register calls)
# =============================================================================
def _discover_tools():
"""Import all tool modules to trigger their registry.register() calls.
@@ -97,6 +99,7 @@ def _discover_tools():
"tools.homeassistant_tool",
]
import importlib
for mod_name in _modules:
try:
importlib.import_module(mod_name)
@@ -109,6 +112,7 @@ _discover_tools()
# MCP tool discovery (external MCP servers from config)
try:
from tools.mcp_tool import discover_mcp_tools
discover_mcp_tools()
except Exception as e:
logger.debug("MCP tool discovery failed: %s", e)
@@ -118,13 +122,13 @@ except Exception as e:
# Backward-compat constants (built once after discovery)
# =============================================================================
TOOL_TO_TOOLSET_MAP: Dict[str, str] = registry.get_tool_to_toolset_map()
TOOL_TO_TOOLSET_MAP: dict[str, str] = registry.get_tool_to_toolset_map()
TOOLSET_REQUIREMENTS: Dict[str, dict] = registry.get_toolset_requirements()
TOOLSET_REQUIREMENTS: dict[str, dict] = registry.get_toolset_requirements()
# Resolved tool names from the last get_tool_definitions() call.
# Used by code_execution_tool to know which tools are available in this session.
_last_resolved_tool_names: List[str] = []
_last_resolved_tool_names: list[str] = []
# =============================================================================
@@ -139,18 +143,29 @@ _LEGACY_TOOLSET_MAP = {
"image_tools": ["image_generate"],
"skills_tools": ["skills_list", "skill_view", "skill_manage"],
"browser_tools": [
"browser_navigate", "browser_snapshot", "browser_click",
"browser_type", "browser_scroll", "browser_back",
"browser_press", "browser_close", "browser_get_images",
"browser_vision"
"browser_navigate",
"browser_snapshot",
"browser_click",
"browser_type",
"browser_scroll",
"browser_back",
"browser_press",
"browser_close",
"browser_get_images",
"browser_vision",
],
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
"rl_tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_list_runs", "rl_test_inference"
"rl_list_environments",
"rl_select_environment",
"rl_get_current_config",
"rl_edit_config",
"rl_start_training",
"rl_check_status",
"rl_stop_training",
"rl_get_results",
"rl_list_runs",
"rl_test_inference",
],
"file_tools": ["read_file", "write_file", "patch", "search_files"],
"tts_tools": ["text_to_speech"],
@@ -161,11 +176,12 @@ _LEGACY_TOOLSET_MAP = {
# get_tool_definitions (the main schema provider)
# =============================================================================
def get_tool_definitions(
enabled_toolsets: List[str] = None,
disabled_toolsets: List[str] = None,
enabled_toolsets: list[str] = None,
disabled_toolsets: list[str] = None,
quiet_mode: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Get tool definitions for model API calls with toolset-based filtering.
@@ -200,6 +216,7 @@ def get_tool_definitions(
elif disabled_toolsets:
from toolsets import get_all_toolsets
for ts_name in get_all_toolsets():
tools_to_include.update(resolve_toolset(ts_name))
@@ -219,6 +236,7 @@ def get_tool_definitions(
print(f"⚠️ Unknown toolset: {toolset_name}")
else:
from toolsets import get_all_toolsets
for ts_name in get_all_toolsets():
tools_to_include.update(resolve_toolset(ts_name))
@@ -230,6 +248,7 @@ def get_tool_definitions(
# execute_code" even when the user disabled the web toolset (#560-discord).
if "execute_code" in tools_to_include:
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
dynamic_schema = build_execute_code_schema(sandbox_enabled)
for i, td in enumerate(filtered_tools):
@@ -263,9 +282,9 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
def handle_function_call(
function_name: str,
function_args: Dict[str, Any],
task_id: Optional[str] = None,
user_task: Optional[str] = None,
function_args: dict[str, Any],
task_id: str | None = None,
user_task: str | None = None,
) -> str:
"""
Main function call dispatcher that routes calls to the tool registry.
@@ -285,13 +304,15 @@ def handle_function_call(
if function_name == "execute_code":
return registry.dispatch(
function_name, function_args,
function_name,
function_args,
task_id=task_id,
enabled_tools=_last_resolved_tool_names,
)
return registry.dispatch(
function_name, function_args,
function_name,
function_args,
task_id=task_id,
user_task=user_task,
)
@@ -306,26 +327,27 @@ def handle_function_call(
# Backward-compat wrapper functions
# =============================================================================
def get_all_tool_names() -> List[str]:
def get_all_tool_names() -> list[str]:
"""Return all registered tool names."""
return registry.get_all_tool_names()
def get_toolset_for_tool(tool_name: str) -> Optional[str]:
def get_toolset_for_tool(tool_name: str) -> str | None:
"""Return the toolset a tool belongs to."""
return registry.get_toolset_for_tool(tool_name)
def get_available_toolsets() -> Dict[str, dict]:
def get_available_toolsets() -> dict[str, dict]:
"""Return toolset availability info for UI display."""
return registry.get_available_toolsets()
def check_toolset_requirements() -> Dict[str, bool]:
def check_toolset_requirements() -> dict[str, bool]:
"""Return {toolset: available_bool} for every registered toolset."""
return registry.check_toolset_requirements()
def check_tool_availability(quiet: bool = False) -> Tuple[List[str], List[dict]]:
def check_tool_availability(quiet: bool = False) -> tuple[list[str], list[dict]]:
"""Return (available_toolsets, unavailable_info)."""
return registry.check_tool_availability(quiet=quiet)

View File

@@ -40,7 +40,7 @@ dependencies = [
[project.optional-dependencies]
modal = ["swe-rex[modal]>=1.4.0"]
daytona = ["daytona>=0.148.0"]
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0"]
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0", "ruff", "pre-commit", "watchfiles"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
@@ -76,6 +76,46 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
[tool.setuptools.packages.find]
include = ["tools", "hermes_cli", "gateway", "cron", "honcho_integration"]
[tool.ruff]
target-version = "py311"
line-length = 120
[tool.ruff.lint]
select = ["E", "F", "W", "I", "UP", "B", "SIM"]
ignore = [
"E402", # late imports — intentional throughout codebase
"E501", # line too long — handled by formatter where it can
"E731", # lambda assignments — used in registry pattern
"E741", # ambiguous variable name — existing patterns
"F811", # redefined unused — intentional overrides
"F841", # unused variable — cleanup separately
"B007", # unused loop variable — cleanup separately
"B904", # raise from — too noisy to gate on
"B905", # zip strict — cleanup separately
"B027", # empty method without abstract decorator
"SIM102", # collapsible if — readability preference
"SIM103", # needless bool — readability preference
"SIM105", # suppressible exception — existing pattern
"SIM108", # ternary — readability preference
"SIM110", # reimplemented builtin
"SIM112", # uncapitalized env var
"SIM115", # open file with context handler
"SIM117", # multiple with statements
"SIM118", # in-dict-keys — cleanup separately
"SIM212", # if-expr twisted arms
]
[tool.ruff.lint.per-file-ignores]
"batch_runner.py" = ["F821"]
"tools/patch_parser.py" = ["F821"]
"gateway/run.py" = ["F821"]
"gateway/channel_directory.py" = ["F401"]
"hermes_cli/doctor.py" = ["F401"]
"tools/image_generation_tool.py" = ["F401"]
[tool.ruff.lint.isort]
known-first-party = ["tools", "hermes_cli", "gateway", "agent", "cron"]
[tool.pytest.ini_options]
testpaths = ["tests"]
markers = [

File diff suppressed because it is too large Load Diff

View File

@@ -16,249 +16,222 @@ for the AI agent to access all capabilities.
"""
# Export all tools for easy importing
from .web_tools import (
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key
)
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
from .terminal_tool import (
terminal_tool,
check_terminal_requirements,
cleanup_vm,
cleanup_all_environments,
get_active_environments_info,
register_task_env_overrides,
clear_task_env_overrides,
TERMINAL_TOOL_DESCRIPTION
)
from .vision_tools import (
vision_analyze_tool,
check_vision_requirements
)
from .mixture_of_agents_tool import (
mixture_of_agents_tool,
check_moa_requirements
)
from .image_generation_tool import (
image_generate_tool,
check_image_generation_requirements
)
from .skills_tool import (
skills_list,
skill_view,
check_skills_requirements,
SKILLS_TOOL_DESCRIPTION
)
from .skill_manager_tool import (
skill_manage,
check_skill_manage_requirements,
SKILL_MANAGE_SCHEMA
)
# Browser automation tools (agent-browser + Browserbase)
from .browser_tool import (
browser_navigate,
browser_snapshot,
browser_click,
browser_type,
browser_scroll,
BROWSER_TOOL_SCHEMAS,
browser_back,
browser_press,
browser_click,
browser_close,
browser_get_images,
browser_navigate,
browser_press,
browser_scroll,
browser_snapshot,
browser_type,
browser_vision,
cleanup_browser,
cleanup_all_browsers,
get_active_browser_sessions,
check_browser_requirements,
BROWSER_TOOL_SCHEMAS
)
# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
schedule_cronjob,
list_cronjobs,
remove_cronjob,
check_cronjob_requirements,
get_cronjob_tool_definitions,
SCHEDULE_CRONJOB_SCHEMA,
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA
)
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_list_runs,
rl_test_inference,
check_rl_api_keys,
get_missing_keys,
)
# File manipulation tools (read, write, patch, search)
from .file_tools import (
read_file_tool,
write_file_tool,
patch_tool,
search_tool,
get_file_tools,
clear_file_ops_cache,
)
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
from .tts_tool import (
text_to_speech_tool,
check_tts_requirements,
)
# Planning & task management tool
from .todo_tool import (
todo_tool,
check_todo_requirements,
TODO_SCHEMA,
TodoStore,
cleanup_all_browsers,
cleanup_browser,
get_active_browser_sessions,
)
# Clarifying questions tool (interactive Q&A with the user)
from .clarify_tool import (
clarify_tool,
check_clarify_requirements,
CLARIFY_SCHEMA,
check_clarify_requirements,
clarify_tool,
)
# Code execution sandbox (programmatic tool calling)
from .code_execution_tool import (
execute_code,
check_sandbox_requirements,
EXECUTE_CODE_SCHEMA,
check_sandbox_requirements,
execute_code,
)
# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA,
SCHEDULE_CRONJOB_SCHEMA,
check_cronjob_requirements,
get_cronjob_tool_definitions,
list_cronjobs,
remove_cronjob,
schedule_cronjob,
)
# Subagent delegation (spawn child agents with isolated context)
from .delegate_tool import (
delegate_task,
check_delegate_requirements,
DELEGATE_TASK_SCHEMA,
check_delegate_requirements,
delegate_task,
)
# File manipulation tools (read, write, patch, search)
from .file_tools import (
clear_file_ops_cache,
get_file_tools,
patch_tool,
read_file_tool,
search_tool,
write_file_tool,
)
from .image_generation_tool import check_image_generation_requirements, image_generate_tool
from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
check_rl_api_keys,
get_missing_keys,
rl_check_status,
rl_edit_config,
rl_get_current_config,
rl_get_results,
rl_list_environments,
rl_list_runs,
rl_select_environment,
rl_start_training,
rl_stop_training,
rl_test_inference,
)
from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage
from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
from .terminal_tool import (
TERMINAL_TOOL_DESCRIPTION,
check_terminal_requirements,
cleanup_all_environments,
cleanup_vm,
clear_task_env_overrides,
get_active_environments_info,
register_task_env_overrides,
terminal_tool,
)
# Planning & task management tool
from .todo_tool import (
TODO_SCHEMA,
TodoStore,
check_todo_requirements,
todo_tool,
)
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
from .tts_tool import (
check_tts_requirements,
text_to_speech_tool,
)
from .vision_tools import check_vision_requirements, vision_analyze_tool
from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
# File tools have no external requirements - they use the terminal backend
def check_file_requirements():
"""File tools only require terminal backend to be available."""
from .terminal_tool import check_terminal_requirements
return check_terminal_requirements()
__all__ = [
# Web tools
'web_search_tool',
'web_extract_tool',
'web_crawl_tool',
'check_firecrawl_api_key',
"web_search_tool",
"web_extract_tool",
"web_crawl_tool",
"check_firecrawl_api_key",
# Terminal tools (mini-swe-agent backend)
'terminal_tool',
'check_terminal_requirements',
'cleanup_vm',
'cleanup_all_environments',
'get_active_environments_info',
'register_task_env_overrides',
'clear_task_env_overrides',
'TERMINAL_TOOL_DESCRIPTION',
"terminal_tool",
"check_terminal_requirements",
"cleanup_vm",
"cleanup_all_environments",
"get_active_environments_info",
"register_task_env_overrides",
"clear_task_env_overrides",
"TERMINAL_TOOL_DESCRIPTION",
# Vision tools
'vision_analyze_tool',
'check_vision_requirements',
"vision_analyze_tool",
"check_vision_requirements",
# MoA tools
'mixture_of_agents_tool',
'check_moa_requirements',
"mixture_of_agents_tool",
"check_moa_requirements",
# Image generation tools
'image_generate_tool',
'check_image_generation_requirements',
"image_generate_tool",
"check_image_generation_requirements",
# Skills tools
'skills_list',
'skill_view',
'check_skills_requirements',
'SKILLS_TOOL_DESCRIPTION',
"skills_list",
"skill_view",
"check_skills_requirements",
"SKILLS_TOOL_DESCRIPTION",
# Skill management
'skill_manage',
'check_skill_manage_requirements',
'SKILL_MANAGE_SCHEMA',
"skill_manage",
"check_skill_manage_requirements",
"SKILL_MANAGE_SCHEMA",
# Browser automation tools
'browser_navigate',
'browser_snapshot',
'browser_click',
'browser_type',
'browser_scroll',
'browser_back',
'browser_press',
'browser_close',
'browser_get_images',
'browser_vision',
'cleanup_browser',
'cleanup_all_browsers',
'get_active_browser_sessions',
'check_browser_requirements',
'BROWSER_TOOL_SCHEMAS',
"browser_navigate",
"browser_snapshot",
"browser_click",
"browser_type",
"browser_scroll",
"browser_back",
"browser_press",
"browser_close",
"browser_get_images",
"browser_vision",
"cleanup_browser",
"cleanup_all_browsers",
"get_active_browser_sessions",
"check_browser_requirements",
"BROWSER_TOOL_SCHEMAS",
# Cronjob management tools (CLI-only)
'schedule_cronjob',
'list_cronjobs',
'remove_cronjob',
'check_cronjob_requirements',
'get_cronjob_tool_definitions',
'SCHEDULE_CRONJOB_SCHEMA',
'LIST_CRONJOBS_SCHEMA',
'REMOVE_CRONJOB_SCHEMA',
"schedule_cronjob",
"list_cronjobs",
"remove_cronjob",
"check_cronjob_requirements",
"get_cronjob_tool_definitions",
"SCHEDULE_CRONJOB_SCHEMA",
"LIST_CRONJOBS_SCHEMA",
"REMOVE_CRONJOB_SCHEMA",
# RL Training tools
'rl_list_environments',
'rl_select_environment',
'rl_get_current_config',
'rl_edit_config',
'rl_start_training',
'rl_check_status',
'rl_stop_training',
'rl_get_results',
'rl_list_runs',
'rl_test_inference',
'check_rl_api_keys',
'get_missing_keys',
"rl_list_environments",
"rl_select_environment",
"rl_get_current_config",
"rl_edit_config",
"rl_start_training",
"rl_check_status",
"rl_stop_training",
"rl_get_results",
"rl_list_runs",
"rl_test_inference",
"check_rl_api_keys",
"get_missing_keys",
# File manipulation tools
'read_file_tool',
'write_file_tool',
'patch_tool',
'search_tool',
'get_file_tools',
'clear_file_ops_cache',
'check_file_requirements',
"read_file_tool",
"write_file_tool",
"patch_tool",
"search_tool",
"get_file_tools",
"clear_file_ops_cache",
"check_file_requirements",
# Text-to-speech tools
'text_to_speech_tool',
'check_tts_requirements',
"text_to_speech_tool",
"check_tts_requirements",
# Planning & task management tool
'todo_tool',
'check_todo_requirements',
'TODO_SCHEMA',
'TodoStore',
"todo_tool",
"check_todo_requirements",
"TODO_SCHEMA",
"TodoStore",
# Clarifying questions tool
'clarify_tool',
'check_clarify_requirements',
'CLARIFY_SCHEMA',
"clarify_tool",
"check_clarify_requirements",
"CLARIFY_SCHEMA",
# Code execution sandbox
'execute_code',
'check_sandbox_requirements',
'EXECUTE_CODE_SCHEMA',
"execute_code",
"check_sandbox_requirements",
"EXECUTE_CODE_SCHEMA",
# Subagent delegation
'delegate_task',
'check_delegate_requirements',
'DELEGATE_TASK_SCHEMA',
"delegate_task",
"check_delegate_requirements",
"DELEGATE_TASK_SCHEMA",
]

View File

@@ -12,7 +12,6 @@ import os
import re
import sys
import threading
from typing import Optional
logger = logging.getLogger(__name__)
@@ -21,32 +20,32 @@ logger = logging.getLogger(__name__)
# =========================================================================
DANGEROUS_PATTERNS = [
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
(r'\brm\s+-[^\s]*r', "recursive delete"),
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
(r'\bmkfs\b', "format filesystem"),
(r'\bdd\s+.*if=', "disk copy"),
(r'>\s*/dev/sd', "write to block device"),
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
(r'>\s*/etc/', "overwrite system config"),
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
(r'\bpkill\s+-9\b', "force kill processes"),
(r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"),
(r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
(r'\bfind\b.*-delete\b', "find -delete"),
(r"\brm\s+(-[^\s]*\s+)*/", "delete in root path"),
(r"\brm\s+-[^\s]*r", "recursive delete"),
(r"\brm\s+--recursive\b", "recursive delete (long flag)"),
(r"\bchmod\s+(-[^\s]*\s+)*777\b", "world-writable permissions"),
(r"\bchmod\s+--recursive\b.*777", "recursive world-writable (long flag)"),
(r"\bchown\s+(-[^\s]*)?R\s+root", "recursive chown to root"),
(r"\bchown\s+--recursive\b.*root", "recursive chown to root (long flag)"),
(r"\bmkfs\b", "format filesystem"),
(r"\bdd\s+.*if=", "disk copy"),
(r">\s*/dev/sd", "write to block device"),
(r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"),
(r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"),
(r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"),
(r">\s*/etc/", "overwrite system config"),
(r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"),
(r"\bkill\s+-9\s+-1\b", "kill all processes"),
(r"\bpkill\s+-9\b", "force kill processes"),
(r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"),
(r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"),
(r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"),
(r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"),
(r"\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b", "execute remote script via process substitution"),
(r"\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)", "overwrite system file via tee"),
(r"\bxargs\s+.*\brm\b", "xargs with rm"),
(r"\bfind\b.*-exec\s+(/\S*/)?rm\b", "find -exec rm"),
(r"\bfind\b.*-delete\b", "find -delete"),
]
@@ -54,6 +53,7 @@ DANGEROUS_PATTERNS = [
# Detection
# =========================================================================
def detect_dangerous_command(command: str) -> tuple:
"""Check if a command matches any dangerous patterns.
@@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple:
command_lower = command.lower()
for pattern, description in DANGEROUS_PATTERNS:
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20]
return (True, pattern_key, description)
return (False, None, None)
@@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict):
_pending[session_key] = approval
def pop_pending(session_key: str) -> Optional[dict]:
def pop_pending(session_key: str) -> dict | None:
"""Retrieve and remove a pending approval for a session."""
with _lock:
return _pending.pop(session_key, None)
@@ -133,6 +133,7 @@ def clear_session(session_key: str):
# Config persistence for permanent allowlist
# =========================================================================
def load_permanent_allowlist() -> set:
"""Load permanently allowed command patterns from config.
@@ -141,6 +142,7 @@ def load_permanent_allowlist() -> set:
"""
try:
from hermes_cli.config import load_config
config = load_config()
patterns = set(config.get("command_allowlist", []) or [])
if patterns:
@@ -154,6 +156,7 @@ def save_permanent_allowlist(patterns: set):
"""Save permanently allowed command patterns to config."""
try:
from hermes_cli.config import load_config, save_config
config = load_config()
config["command_allowlist"] = list(patterns)
save_config(config)
@@ -165,9 +168,8 @@ def save_permanent_allowlist(patterns: set):
# Approval prompting + orchestration
# =========================================================================
def prompt_dangerous_approval(command: str, description: str,
timeout_seconds: int = 60,
approval_callback=None) -> str:
def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str:
"""Prompt the user to approve a dangerous command (CLI only).
Args:
@@ -188,7 +190,7 @@ def prompt_dangerous_approval(command: str, description: str,
print(f" ⚠️ DANGEROUS COMMAND: {description}")
print(f" {command[:80]}{'...' if len(command) > 80 else ''}")
print()
print(f" [o]nce | [s]ession | [a]lways | [d]eny")
print(" [o]nce | [s]ession | [a]lways | [d]eny")
print()
sys.stdout.flush()
@@ -209,13 +211,13 @@ def prompt_dangerous_approval(command: str, description: str,
return "deny"
choice = result["choice"]
if choice in ('o', 'once'):
if choice in ("o", "once"):
print(" ✓ Allowed once")
return "once"
elif choice in ('s', 'session'):
elif choice in ("s", "session"):
print(" ✓ Allowed for this session")
return "session"
elif choice in ('a', 'always'):
elif choice in ("a", "always"):
print(" ✓ Added to permanent allowlist")
return "always"
else:
@@ -232,8 +234,7 @@ def prompt_dangerous_approval(command: str, description: str,
sys.stdout.flush()
def check_dangerous_command(command: str, env_type: str,
approval_callback=None) -> dict:
def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict:
"""Check if a command is dangerous and handle approval.
This is the main entry point called by terminal_tool before executing
@@ -265,11 +266,14 @@ def check_dangerous_command(command: str, env_type: str,
return {"approved": True, "message": None}
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
submit_pending(session_key, {
"command": command,
"pattern_key": pattern_key,
"description": description,
})
submit_pending(
session_key,
{
"command": command,
"pattern_key": pattern_key,
"description": description,
},
)
return {
"approved": False,
"pattern_key": pattern_key,
@@ -279,8 +283,7 @@ def check_dangerous_command(command: str, env_type: str,
"message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...",
}
choice = prompt_dangerous_approval(command, description,
approval_callback=approval_callback)
choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback)
if choice == "deny":
return {

File diff suppressed because it is too large Load Diff

View File

@@ -12,8 +12,7 @@ a thin dispatcher that delegates to a platform-provided callback.
"""
import json
from typing import Dict, Any, List, Optional, Callable
from collections.abc import Callable
# Maximum number of predefined choices the agent can offer.
# A 5th "Other (type your answer)" option is always appended by the UI.
@@ -22,8 +21,8 @@ MAX_CHOICES = 4
def clarify_tool(
question: str,
choices: Optional[List[str]] = None,
callback: Optional[Callable] = None,
choices: list[str] | None = None,
callback: Callable | None = None,
) -> str:
"""
Ask the user a question, optionally with multiple-choice options.
@@ -68,11 +67,14 @@ def clarify_tool(
ensure_ascii=False,
)
return json.dumps({
"question": question,
"choices_offered": choices,
"user_response": str(user_response).strip(),
}, ensure_ascii=False)
return json.dumps(
{
"question": question,
"choices_offered": choices,
"user_response": str(user_response).strip(),
},
ensure_ascii=False,
)
def check_clarify_requirements() -> bool:
@@ -133,8 +135,7 @@ registry.register(
toolset="clarify",
schema=CLARIFY_SCHEMA,
handler=lambda args, **kw: clarify_tool(
question=args.get("question", ""),
choices=args.get("choices"),
callback=kw.get("callback")),
question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback")
),
check_fn=check_clarify_requirements,
)

View File

@@ -31,7 +31,7 @@ import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from typing import Any, Dict, List, Optional
from typing import Any
# Availability gate: UDS requires a POSIX OS
logger = logging.getLogger(__name__)
@@ -40,21 +40,23 @@ SANDBOX_AVAILABLE = sys.platform != "win32"
# The 7 tools allowed inside the sandbox. The intersection of this list
# and the session's enabled tools determines which stubs are generated.
SANDBOX_ALLOWED_TOOLS = frozenset([
"web_search",
"web_extract",
"read_file",
"write_file",
"search_files",
"patch",
"terminal",
])
SANDBOX_ALLOWED_TOOLS = frozenset(
[
"web_search",
"web_extract",
"read_file",
"write_file",
"search_files",
"patch",
"terminal",
]
)
# Resource limit defaults (overridable via config.yaml → code_execution.*)
DEFAULT_TIMEOUT = 300 # 5 minutes
DEFAULT_TIMEOUT = 300 # 5 minutes
DEFAULT_MAX_TOOL_CALLS = 50
MAX_STDOUT_BYTES = 50_000 # 50 KB
MAX_STDERR_BYTES = 10_000 # 10 KB
MAX_STDOUT_BYTES = 50_000 # 50 KB
MAX_STDERR_BYTES = 10_000 # 10 KB
def check_sandbox_requirements() -> bool:
@@ -114,7 +116,7 @@ _TOOL_STUBS = {
}
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
"""
Build the source code for the hermes_tools.py stub module.
@@ -128,11 +130,7 @@ def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
if tool_name not in _TOOL_STUBS:
continue
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
stub_functions.append(
f"def {func_name}({sig}):\n"
f" {doc}\n"
f" return _call({func_name!r}, {args_expr})\n"
)
stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n")
export_names.append(func_name)
header = '''\
@@ -223,7 +221,7 @@ def _rpc_server_loop(
server_sock: socket.socket,
task_id: str,
tool_call_log: list,
tool_call_counter: list, # mutable [int] so the thread can increment
tool_call_counter: list, # mutable [int] so the thread can increment
max_tool_calls: int,
allowed_tools: frozenset,
):
@@ -243,7 +241,7 @@ def _rpc_server_loop(
while True:
try:
chunk = conn.recv(65536)
except socket.timeout:
except TimeoutError:
break
if not chunk:
break
@@ -270,23 +268,22 @@ def _rpc_server_loop(
# Enforce the allow-list
if tool_name not in allowed_tools:
available = ", ".join(sorted(allowed_tools))
resp = json.dumps({
"error": (
f"Tool '{tool_name}' is not available in execute_code. "
f"Available: {available}"
)
})
resp = json.dumps(
{"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")}
)
conn.sendall((resp + "\n").encode())
continue
# Enforce tool call limit
if tool_call_counter[0] >= max_tool_calls:
resp = json.dumps({
"error": (
f"Tool call limit reached ({max_tool_calls}). "
"No more tool calls allowed in this execution."
)
})
resp = json.dumps(
{
"error": (
f"Tool call limit reached ({max_tool_calls}). "
"No more tool calls allowed in this execution."
)
}
)
conn.sendall((resp + "\n").encode())
continue
@@ -303,9 +300,7 @@ def _rpc_server_loop(
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
try:
result = handle_function_call(
tool_name, tool_args, task_id=task_id
)
result = handle_function_call(tool_name, tool_args, task_id=task_id)
finally:
sys.stdout.close()
sys.stderr.close()
@@ -318,15 +313,17 @@ def _rpc_server_loop(
# Log for observability
args_preview = str(tool_args)[:80]
tool_call_log.append({
"tool": tool_name,
"args_preview": args_preview,
"duration": round(call_duration, 2),
})
tool_call_log.append(
{
"tool": tool_name,
"args_preview": args_preview,
"duration": round(call_duration, 2),
}
)
conn.sendall((result + "\n").encode())
except socket.timeout:
except TimeoutError:
pass
except OSError:
pass
@@ -342,10 +339,11 @@ def _rpc_server_loop(
# Main entry point
# ---------------------------------------------------------------------------
def execute_code(
code: str,
task_id: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
task_id: str | None = None,
enabled_tools: list[str] | None = None,
) -> str:
"""
Run a Python script in a sandboxed child process with RPC access
@@ -361,9 +359,7 @@ def execute_code(
JSON string with execution results.
"""
if not SANDBOX_AVAILABLE:
return json.dumps({
"error": "execute_code is not available on Windows. Use normal tool calls instead."
})
return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."})
if not code or not code.strip():
return json.dumps({"error": "No code provided."})
@@ -397,9 +393,7 @@ def execute_code(
try:
# Write the auto-generated hermes_tools module
tools_src = generate_hermes_tools_module(
list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)
)
tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS))
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
f.write(tools_src)
@@ -415,8 +409,12 @@ def execute_code(
rpc_thread = threading.Thread(
target=_rpc_server_loop,
args=(
server_sock, task_id, tool_call_log,
tool_call_counter, max_tool_calls, sandbox_tools,
server_sock,
task_id,
tool_call_log,
tool_call_counter,
max_tool_calls,
sandbox_tools,
),
daemon=True,
)
@@ -426,11 +424,24 @@ def execute_code(
# Build a minimal environment for the child. We intentionally exclude
# API keys and tokens to prevent credential exfiltration from LLM-
# generated scripts. The child accesses tools via RPC, not direct API.
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
"PASSWD", "AUTH")
_SAFE_ENV_PREFIXES = (
"PATH",
"HOME",
"USER",
"LANG",
"LC_",
"TERM",
"TMPDIR",
"TMP",
"TEMP",
"SHELL",
"LOGNAME",
"XDG_",
"PYTHONPATH",
"VIRTUAL_ENV",
"CONDA",
)
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH")
child_env = {}
for k, v in os.environ.items():
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
@@ -515,7 +526,7 @@ def execute_code(
rpc_thread.join(timeout=3)
# Build response
result: Dict[str, Any] = {
result: dict[str, Any] = {
"status": status,
"output": stdout_text,
"tool_calls_made": tool_call_counter[0],
@@ -538,17 +549,21 @@ def execute_code(
except Exception as exc:
duration = round(time.monotonic() - exec_start, 2)
logging.exception("execute_code failed")
return json.dumps({
"status": "error",
"error": str(exc),
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
}, ensure_ascii=False)
return json.dumps(
{
"status": "error",
"error": str(exc),
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
},
ensure_ascii=False,
)
finally:
# Cleanup temp dir and socket
try:
import shutil
shutil.rmtree(tmpdir, ignore_errors=True)
except Exception as e:
logger.debug("Could not clean temp dir: %s", e)
@@ -592,6 +607,7 @@ def _load_config() -> dict:
"""Load code_execution config from CLI_CONFIG if available."""
try:
from cli import CLI_CONFIG
return CLI_CONFIG.get("code_execution", {})
except Exception:
return {}
@@ -604,27 +620,37 @@ def _load_config() -> dict:
# Per-tool documentation lines for the execute_code description.
# Ordered to match the canonical display order.
_TOOL_DOC_LINES = [
("web_search",
" web_search(query: str, limit: int = 5) -> dict\n"
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
("web_extract",
" web_extract(urls: list[str]) -> dict\n"
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
("read_file",
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
("write_file",
" write_file(path: str, content: str) -> dict\n"
" Always overwrites the entire file."),
("search_files",
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
("patch",
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
" Replaces old_string with new_string in the file."),
("terminal",
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
(
"web_search",
" web_search(query: str, limit: int = 5) -> dict\n"
' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}',
),
(
"web_extract",
" web_extract(urls: list[str]) -> dict\n"
' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown',
),
(
"read_file",
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}',
),
("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."),
(
"search_files",
' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n'
' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}',
),
(
"patch",
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
" Replaces old_string with new_string in the file.",
),
(
"terminal",
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}',
),
]
@@ -639,9 +665,7 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
# Build tool documentation lines for only the enabled tools
tool_lines = "\n".join(
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
)
tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools)
# Build example import list from enabled tools
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
@@ -702,8 +726,7 @@ registry.register(
toolset="code_execution",
schema=EXECUTE_CODE_SCHEMA,
handler=lambda args, **kw: execute_code(
code=args.get("code", ""),
task_id=kw.get("task_id"),
enabled_tools=kw.get("enabled_tools")),
code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools")
),
check_fn=check_sandbox_requirements,
)

View File

@@ -11,37 +11,44 @@ The prompt must contain ALL necessary information.
import json
import os
import re
from typing import Optional
# Import from cron module (will be available when properly installed)
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from cron.jobs import create_job, get_job, list_jobs, remove_job
# ---------------------------------------------------------------------------
# Cron prompt scanning — critical-severity patterns only, since cron prompts
# run in fresh sessions with full tool access.
# ---------------------------------------------------------------------------
_CRON_THREAT_PATTERNS = [
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
(r'authorized_keys', "ssh_backdoor"),
(r'/etc/sudoers|visudo', "sudoers_mod"),
(r'rm\s+-rf\s+/', "destructive_root_rm"),
(r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"),
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
(r"system\s+prompt\s+override", "sys_prompt_override"),
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
(r"authorized_keys", "ssh_backdoor"),
(r"/etc/sudoers|visudo", "sudoers_mod"),
(r"rm\s+-rf\s+/", "destructive_root_rm"),
]
_CRON_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
"\u200b",
"\u200c",
"\u200d",
"\u2060",
"\ufeff",
"\u202a",
"\u202b",
"\u202c",
"\u202d",
"\u202e",
}
@@ -60,17 +67,18 @@ def _scan_cron_prompt(prompt: str) -> str:
# Tool: schedule_cronjob
# =============================================================================
def schedule_cronjob(
prompt: str,
schedule: str,
name: Optional[str] = None,
repeat: Optional[int] = None,
deliver: Optional[str] = None,
task_id: str = None
name: str | None = None,
repeat: int | None = None,
deliver: str | None = None,
task_id: str = None,
) -> str:
"""
Schedule an automated task to run the agent on a schedule.
IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session.
The agent will have NO memory of this conversation or any prior context.
Therefore, the prompt MUST contain ALL necessary information:
@@ -78,12 +86,12 @@ def schedule_cronjob(
- Specific file paths, URLs, or identifiers
- Clear success criteria
- Any relevant background information
BAD prompt: "Check on that server issue"
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
is running with 'systemctl status nginx', and verify the site
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
is running with 'systemctl status nginx', and verify the site
https://example.com returns HTTP 200. Report any issues found."
Args:
prompt: Complete, self-contained instructions for the future agent.
Must include ALL context needed - the agent won't remember anything.
@@ -105,7 +113,7 @@ def schedule_cronjob(
- "signal": Send to Signal home channel
- "telegram:123456": Send to specific chat ID
- "signal:+15551234567": Send to specific Signal number
Returns:
JSON with job_id, next_run time, and confirmation
"""
@@ -124,17 +132,10 @@ def schedule_cronjob(
"chat_id": origin_chat_id,
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
}
try:
job = create_job(
prompt=prompt,
schedule=schedule,
name=name,
repeat=repeat,
deliver=deliver,
origin=origin
)
job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin)
# Format repeat info for display
times = job["repeat"].get("times")
if times is None:
@@ -143,23 +144,23 @@ def schedule_cronjob(
repeat_display = "once"
else:
repeat_display = f"{times} times"
return json.dumps({
"success": True,
"job_id": job["id"],
"name": job["name"],
"schedule": job["schedule_display"],
"repeat": repeat_display,
"deliver": job.get("deliver", "local"),
"next_run_at": job["next_run_at"],
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}."
}, indent=2)
return json.dumps(
{
"success": True,
"job_id": job["id"],
"name": job["name"],
"schedule": job["schedule_display"],
"repeat": repeat_display,
"deliver": job.get("deliver", "local"),
"next_run_at": job["next_run_at"],
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.",
},
indent=2,
)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
SCHEDULE_CRONJOB_SCHEMA = {
@@ -177,7 +178,7 @@ The future agent will NOT remember anything from the current conversation.
SCHEDULE FORMATS:
- One-shot: "30m", "2h", "1d" (runs once after delay)
- Interval: "every 30m", "every 2h" (recurring)
- Interval: "every 30m", "every 2h" (recurring)
- Cron: "0 9 * * *" (cron expression for precise scheduling)
- Timestamp: "2026-02-03T14:00:00" (specific date/time)
@@ -202,27 +203,24 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
"properties": {
"prompt": {
"type": "string",
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation."
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.",
},
"schedule": {
"type": "string",
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp"
},
"name": {
"type": "string",
"description": "Optional human-friendly name for the job"
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp",
},
"name": {"type": "string", "description": "Optional human-friendly name for the job"},
"repeat": {
"type": "integer",
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs."
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.",
},
"deliver": {
"type": "string",
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'"
}
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'",
},
},
"required": ["prompt", "schedule"]
}
"required": ["prompt", "schedule"],
},
}
@@ -230,10 +228,11 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
# Tool: list_cronjobs
# =============================================================================
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
"""
List all scheduled cronjobs.
Returns information about each job including:
- Job ID (needed for removal)
- Name
@@ -241,16 +240,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
- Repeat status (completed/total or 'forever')
- Next scheduled run time
- Last run time and status (if any)
Args:
include_disabled: Whether to include disabled/completed jobs
Returns:
JSON array of all scheduled jobs
"""
try:
jobs = list_jobs(include_disabled=include_disabled)
formatted_jobs = []
for job in jobs:
# Format repeat status
@@ -260,31 +259,26 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
repeat_status = "forever"
else:
repeat_status = f"{completed}/{times}"
formatted_jobs.append({
"job_id": job["id"],
"name": job["name"],
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
"schedule": job["schedule_display"],
"repeat": repeat_status,
"deliver": job.get("deliver", "local"),
"next_run_at": job.get("next_run_at"),
"last_run_at": job.get("last_run_at"),
"last_status": job.get("last_status"),
"enabled": job.get("enabled", True)
})
return json.dumps({
"success": True,
"count": len(formatted_jobs),
"jobs": formatted_jobs
}, indent=2)
formatted_jobs.append(
{
"job_id": job["id"],
"name": job["name"],
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
"schedule": job["schedule_display"],
"repeat": repeat_status,
"deliver": job.get("deliver", "local"),
"next_run_at": job.get("next_run_at"),
"last_run_at": job.get("last_run_at"),
"last_status": job.get("last_status"),
"enabled": job.get("enabled", True),
}
)
return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
LIST_CRONJOBS_SCHEMA = {
@@ -302,11 +296,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
"properties": {
"include_disabled": {
"type": "boolean",
"description": "Include disabled/completed jobs in the list (default: false)"
"description": "Include disabled/completed jobs in the list (default: false)",
}
},
"required": []
}
"required": [],
},
}
@@ -314,48 +308,45 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
# Tool: remove_cronjob
# =============================================================================
def remove_cronjob(job_id: str, task_id: str = None) -> str:
"""
Remove a scheduled cronjob by its ID.
Use list_cronjobs first to find the job_id of the job you want to remove.
Args:
job_id: The ID of the job to remove (from list_cronjobs output)
Returns:
JSON confirmation of removal
"""
try:
job = get_job(job_id)
if not job:
return json.dumps({
"success": False,
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs."
}, indent=2)
return json.dumps(
{
"success": False,
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.",
},
indent=2,
)
removed = remove_job(job_id)
if removed:
return json.dumps({
"success": True,
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
"removed_job": {
"id": job_id,
"name": job["name"],
"schedule": job["schedule_display"]
}
}, indent=2)
return json.dumps(
{
"success": True,
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
"removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]},
},
indent=2,
)
else:
return json.dumps({
"success": False,
"error": f"Failed to remove job '{job_id}'"
}, indent=2)
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
except Exception as e:
return json.dumps({
"success": False,
"error": str(e)
}, indent=2)
return json.dumps({"success": False, "error": str(e)}, indent=2)
REMOVE_CRONJOB_SCHEMA = {
@@ -368,13 +359,10 @@ use this to cancel a job before it completes.""",
"parameters": {
"type": "object",
"properties": {
"job_id": {
"type": "string",
"description": "The ID of the cronjob to remove (from list_cronjobs output)"
}
"job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"}
},
"required": ["job_id"]
}
"required": ["job_id"],
},
}
@@ -382,44 +370,34 @@ use this to cancel a job before it completes.""",
# Requirements check
# =============================================================================
def check_cronjob_requirements() -> bool:
"""
Check if cronjob tools can be used.
Available in interactive CLI mode and gateway/messaging platforms.
Cronjobs are server-side scheduled tasks so they work from any interface.
"""
return bool(
os.getenv("HERMES_INTERACTIVE")
or os.getenv("HERMES_GATEWAY_SESSION")
or os.getenv("HERMES_EXEC_ASK")
)
return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK"))
# =============================================================================
# Exports
# =============================================================================
def get_cronjob_tool_definitions():
"""Return tool definitions for cronjob management."""
return [
SCHEDULE_CRONJOB_SCHEMA,
LIST_CRONJOBS_SCHEMA,
REMOVE_CRONJOB_SCHEMA
]
return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA]
# For direct testing
if __name__ == "__main__":
# Test the tools
print("Testing schedule_cronjob:")
result = schedule_cronjob(
prompt="Test prompt for cron job",
schedule="5m",
name="Test Job"
)
result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job")
print(result)
print("\nTesting list_cronjobs:")
result = list_cronjobs()
print(result)
@@ -438,7 +416,8 @@ registry.register(
name=args.get("name"),
repeat=args.get("repeat"),
deliver=args.get("deliver"),
task_id=kw.get("task_id")),
task_id=kw.get("task_id"),
),
check_fn=check_cronjob_requirements,
)
registry.register(
@@ -446,16 +425,14 @@ registry.register(
toolset="cronjob",
schema=LIST_CRONJOBS_SCHEMA,
handler=lambda args, **kw: list_cronjobs(
include_disabled=args.get("include_disabled", False),
task_id=kw.get("task_id")),
include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id")
),
check_fn=check_cronjob_requirements,
)
registry.register(
name="remove_cronjob",
toolset="cronjob",
schema=REMOVE_CRONJOB_SCHEMA,
handler=lambda args, **kw: remove_cronjob(
job_id=args.get("job_id", ""),
task_id=kw.get("task_id")),
handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")),
check_fn=check_cronjob_requirements,
)

View File

@@ -27,7 +27,7 @@ import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict
from typing import Any
logger = logging.getLogger(__name__)
@@ -44,27 +44,28 @@ class DebugSession:
self.enabled = os.getenv(env_var, "false").lower() == "true"
self.session_id = str(uuid.uuid4()) if self.enabled else ""
self.log_dir = Path("./logs")
self._calls: list[Dict[str, Any]] = []
self._calls: list[dict[str, Any]] = []
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
if self.enabled:
self.log_dir.mkdir(exist_ok=True)
logger.debug("%s debug mode enabled - Session ID: %s",
tool_name, self.session_id)
logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id)
@property
def active(self) -> bool:
return self.enabled
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
def log_call(self, call_name: str, call_data: dict[str, Any]) -> None:
"""Append a tool-call entry to the in-memory log."""
if not self.enabled:
return
self._calls.append({
"timestamp": datetime.datetime.now().isoformat(),
"tool_name": call_name,
**call_data,
})
self._calls.append(
{
"timestamp": datetime.datetime.now().isoformat(),
"tool_name": call_name,
**call_data,
}
)
def save(self) -> None:
"""Flush the in-memory log to a JSON file in the logs directory."""
@@ -87,7 +88,7 @@ class DebugSession:
except Exception as e:
logger.error("Error saving %s debug log: %s", self.tool_name, e)
def get_session_info(self) -> Dict[str, Any]:
def get_session_info(self) -> dict[str, Any]:
"""Return a summary dict suitable for returning from get_debug_session_info()."""
if not self.enabled:
return {

View File

@@ -20,21 +20,22 @@ import contextlib
import io
import json
import logging
import os
import sys
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
from typing import Any
# Tools that children must never have access to
DELEGATE_BLOCKED_TOOLS = frozenset([
"delegate_task", # no recursive delegation
"clarify", # no user interaction
"memory", # no writes to shared MEMORY.md
"send_message", # no cross-platform side effects
"execute_code", # children should reason step-by-step, not write scripts
])
DELEGATE_BLOCKED_TOOLS = frozenset(
[
"delegate_task", # no recursive delegation
"clarify", # no user interaction
"memory", # no writes to shared MEMORY.md
"send_message", # no cross-platform side effects
"execute_code", # children should reason step-by-step, not write scripts
]
)
MAX_CONCURRENT_CHILDREN = 3
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
@@ -47,7 +48,7 @@ def check_delegate_requirements() -> bool:
return True
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
"""Build a focused system prompt for a child agent."""
parts = [
"You are a focused subagent working on a specific delegated task.",
@@ -69,15 +70,18 @@ def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
return "\n".join(parts)
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
def _strip_blocked_tools(toolsets: list[str]) -> list[str]:
"""Remove toolsets that contain only blocked tools."""
blocked_toolset_names = {
"delegation", "clarify", "memory", "code_execution",
"delegation",
"clarify",
"memory",
"code_execution",
}
return [t for t in toolsets if t not in blocked_toolset_names]
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None:
"""Build a callback that relays child agent tool calls to the parent display.
Two display paths:
@@ -87,8 +91,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
Returns None if no display mechanism is available, in which case the
child agent runs with no progress callback (identical to current behavior).
"""
spinner = getattr(parent_agent, '_delegate_spinner', None)
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
spinner = getattr(parent_agent, "_delegate_spinner", None)
parent_cb = getattr(parent_agent, "tool_progress_callback", None)
if not spinner and not parent_cb:
return None # No display → no callback → zero behavior change
@@ -98,7 +102,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
# Gateway: batch tool names, flush periodically
_BATCH_SIZE = 5
_batch: List[str] = []
_batch: list[str] = []
def _callback(tool_name: str, preview: str = None):
# Special "_thinking" event: model produced text content (reasoning)
@@ -106,7 +110,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
if spinner:
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
try:
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
spinner.print_above(f' {prefix}├─ 💭 "{short}"')
except Exception:
pass
# Don't relay thinking to gateway (too noisy for chat)
@@ -116,17 +120,25 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
if spinner:
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
tool_emojis = {
"terminal": "💻", "web_search": "🔍", "web_extract": "📄",
"read_file": "📖", "write_file": "✍️", "patch": "🔧",
"search_files": "🔎", "list_directory": "📂",
"browser_navigate": "🌐", "browser_click": "👆",
"text_to_speech": "🔊", "image_generate": "🎨",
"vision_analyze": "👁️", "process": "⚙️",
"terminal": "💻",
"web_search": "🔍",
"web_extract": "📄",
"read_file": "📖",
"write_file": "✍️",
"patch": "🔧",
"search_files": "🔎",
"list_directory": "📂",
"browser_navigate": "🌐",
"browser_click": "👆",
"text_to_speech": "🔊",
"image_generate": "🎨",
"vision_analyze": "👁️",
"process": "⚙️",
}
emoji = tool_emojis.get(tool_name, "")
line = f" {prefix}├─ {emoji} {tool_name}"
if short:
line += f" \"{short}\""
line += f' "{short}"'
try:
spinner.print_above(line)
except Exception:
@@ -159,13 +171,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
def _run_single_child(
task_index: int,
goal: str,
context: Optional[str],
toolsets: Optional[List[str]],
model: Optional[str],
context: str | None,
toolsets: list[str] | None,
model: str | None,
max_iterations: int,
parent_agent,
task_count: int = 1,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Spawn and run a single child agent. Called from within a thread.
Returns a structured result dict.
@@ -216,7 +228,7 @@ def _run_single_child(
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
session_db=getattr(parent_agent, "_session_db", None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
@@ -226,10 +238,10 @@ def _run_single_child(
)
# Set delegation depth so children can't spawn grandchildren
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1
# Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'):
if hasattr(parent_agent, "_active_children"):
parent_agent._active_children.append(child)
# Run with stdout/stderr suppressed to prevent interleaved output
@@ -238,7 +250,7 @@ def _run_single_child(
result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
if child_progress_cb and hasattr(child_progress_cb, "_flush"):
try:
child_progress_cb._flush()
except Exception:
@@ -258,7 +270,7 @@ def _run_single_child(
else:
status = "failed"
entry: Dict[str, Any] = {
entry: dict[str, Any] = {
"task_index": task_index,
"status": status,
"summary": summary,
@@ -284,7 +296,7 @@ def _run_single_child(
finally:
# Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'):
if hasattr(parent_agent, "_active_children"):
try:
parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError):
@@ -292,11 +304,11 @@ def _run_single_child(
def delegate_task(
goal: Optional[str] = None,
context: Optional[str] = None,
toolsets: Optional[List[str]] = None,
tasks: Optional[List[Dict[str, Any]]] = None,
max_iterations: Optional[int] = None,
goal: str | None = None,
context: str | None = None,
toolsets: list[str] | None = None,
tasks: list[dict[str, Any]] | None = None,
max_iterations: int | None = None,
parent_agent=None,
) -> str:
"""
@@ -312,14 +324,11 @@ def delegate_task(
return json.dumps({"error": "delegate_task requires a parent agent context."})
# Depth limit
depth = getattr(parent_agent, '_delegate_depth', 0)
depth = getattr(parent_agent, "_delegate_depth", 0)
if depth >= MAX_DEPTH:
return json.dumps({
"error": (
f"Delegation depth limit reached ({MAX_DEPTH}). "
"Subagents cannot spawn further subagents."
)
})
return json.dumps(
{"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")}
)
# Load config
cfg = _load_config()
@@ -366,7 +375,7 @@ def delegate_task(
else:
# Batch -- run in parallel with per-task progress lines
completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
spinner_ref = getattr(parent_agent, "_delegate_spinner", None)
# Save stdout/stderr before the executor — redirect_stdout in child
# threads races on sys.stdout and can leave it as devnull permanently.
@@ -412,7 +421,7 @@ def delegate_task(
status = entry.get("status", "?")
icon = "" if status == "completed" else ""
remaining = n_tasks - completed_count
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)"
if spinner_ref:
try:
spinner_ref.print_above(completion_line)
@@ -437,16 +446,20 @@ def delegate_task(
total_duration = round(time.monotonic() - overall_start, 2)
return json.dumps({
"results": results,
"total_duration_seconds": total_duration,
}, ensure_ascii=False)
return json.dumps(
{
"results": results,
"total_duration_seconds": total_duration,
},
ensure_ascii=False,
)
def _load_config() -> dict:
"""Load delegation config from CLI_CONFIG if available."""
try:
from cli import CLI_CONFIG
return CLI_CONFIG.get("delegation", {})
except Exception:
return {}
@@ -537,10 +550,7 @@ DELEGATE_TASK_SCHEMA = {
},
"max_iterations": {
"type": "integer",
"description": (
"Max tool-calling turns per subagent (default: 50). "
"Only set lower for simple tasks."
),
"description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."),
},
},
"required": [],
@@ -561,6 +571,7 @@ registry.register(
toolsets=args.get("toolsets"),
tasks=args.get("tasks"),
max_iterations=args.get("max_iterations"),
parent_agent=kw.get("parent_agent")),
parent_agent=kw.get("parent_agent"),
),
check_fn=check_delegate_requirements,
)

View File

@@ -1,8 +1,8 @@
"""Base class for all Hermes execution environment backends."""
from abc import ABC, abstractmethod
import os
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path
@@ -34,9 +34,9 @@ class BaseEnvironment(ABC):
self.env = env or {}
@abstractmethod
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
"""Execute a command, return {"output": str, "returncode": int}."""
...
@@ -62,10 +62,10 @@ class BaseEnvironment(ABC):
def _prepare_command(self, command: str) -> str:
"""Transform sudo commands if SUDO_PASSWORD is available."""
from tools.terminal_tool import _transform_sudo_command
return _transform_sudo_command(command)
def _build_run_kwargs(self, timeout: int | None,
stdin_data: str | None = None) -> dict:
def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict:
"""Build common subprocess.run kwargs for non-interactive execution."""
kw = {
"text": True,

View File

@@ -11,7 +11,6 @@ import shlex
import threading
import uuid
import warnings
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@@ -32,8 +31,8 @@ class DaytonaEnvironment(BaseEnvironment):
cwd: str = "/home/daytona",
timeout: int = 60,
cpu: int = 1,
memory: int = 5120, # MB (hermes convention)
disk: int = 10240, # MB (Daytona platform max is 10GB)
memory: int = 5120, # MB (hermes convention)
disk: int = 10240, # MB (Daytona platform max is 10GB)
persistent_filesystem: bool = True,
task_id: str = "default",
):
@@ -41,8 +40,8 @@ class DaytonaEnvironment(BaseEnvironment):
super().__init__(cwd=cwd, timeout=timeout)
from daytona import (
Daytona,
CreateSandboxFromImageParams,
Daytona,
DaytonaError,
Resources,
SandboxState,
@@ -73,13 +72,11 @@ class DaytonaEnvironment(BaseEnvironment):
try:
self._sandbox = self._daytona.find_one(labels=labels)
self._sandbox.start()
logger.info("Daytona: resumed sandbox %s for task %s",
self._sandbox.id, task_id)
logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id)
except DaytonaError:
self._sandbox = None
except Exception as e:
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
task_id, e)
logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e)
self._sandbox = None
# Create a fresh sandbox if we don't have one
@@ -92,8 +89,7 @@ class DaytonaEnvironment(BaseEnvironment):
resources=resources,
)
)
logger.info("Daytona: created sandbox %s for task %s",
self._sandbox.id, task_id)
logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id)
# Resolve cwd: detect actual home dir inside the sandbox
if self._requested_cwd in ("~", "/home/daytona"):
@@ -112,7 +108,7 @@ class DaytonaEnvironment(BaseEnvironment):
self._sandbox.start()
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict:
"""Run exec in a background thread with interrupt polling.
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
@@ -130,7 +126,8 @@ class DaytonaEnvironment(BaseEnvironment):
def _run():
try:
response = self._sandbox.process.exec(
timed_command, cwd=cwd,
timed_command,
cwd=cwd,
)
result_holder["value"] = {
"output": response.result or "",
@@ -169,9 +166,9 @@ class DaytonaEnvironment(BaseEnvironment):
return {"error": result_holder["error"]}
return result_holder["value"]
def execute(self, command: str, cwd: str = "", *,
timeout: Optional[int] = None,
stdin_data: Optional[str] = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
with self._lock:
self._ensure_sandbox_ready()
@@ -189,6 +186,7 @@ class DaytonaEnvironment(BaseEnvironment):
if "error" in result:
from daytona import DaytonaError
err = result["error"]
if isinstance(err, DaytonaError):
with self._lock:
@@ -210,8 +208,7 @@ class DaytonaEnvironment(BaseEnvironment):
try:
if self._persistent:
self._sandbox.stop()
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
self._sandbox.id)
logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id)
else:
self._daytona.delete(self._sandbox)
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)

View File

@@ -11,7 +11,6 @@ import subprocess
import sys
import threading
import time
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@@ -19,7 +18,6 @@ from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
# Security flags applied to every container.
# The container itself is the security boundary (isolated from host).
# We drop all capabilities then add back the minimum needed:
@@ -28,19 +26,28 @@ logger = logging.getLogger(__name__)
# Block privilege escalation and limit PIDs.
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
_SECURITY_ARGS = [
"--cap-drop", "ALL",
"--cap-add", "DAC_OVERRIDE",
"--cap-add", "CHOWN",
"--cap-add", "FOWNER",
"--security-opt", "no-new-privileges",
"--pids-limit", "256",
"--tmpfs", "/tmp:rw,nosuid,size=512m",
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
"--cap-drop",
"ALL",
"--cap-add",
"DAC_OVERRIDE",
"--cap-add",
"CHOWN",
"--cap-add",
"FOWNER",
"--security-opt",
"no-new-privileges",
"--pids-limit",
"256",
"--tmpfs",
"/tmp:rw,nosuid,size=512m",
"--tmpfs",
"/var/tmp:rw,noexec,nosuid,size=256m",
"--tmpfs",
"/run:rw,noexec,nosuid,size=64m",
]
_storage_opt_ok: Optional[bool] = None # cached result across instances
_storage_opt_ok: bool | None = None # cached result across instances
class DockerEnvironment(BaseEnvironment):
@@ -74,7 +81,7 @@ class DockerEnvironment(BaseEnvironment):
self._base_image = image
self._persistent = persistent_filesystem
self._task_id = task_id
self._container_id: Optional[str] = None
self._container_id: str | None = None
logger.info(f"DockerEnvironment volumes: {volumes}")
# Ensure volumes is a list (config.yaml could be malformed)
if volumes is not None and not isinstance(volumes, list):
@@ -105,8 +112,8 @@ class DockerEnvironment(BaseEnvironment):
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
from tools.environments.base import get_sandbox_dir
self._workspace_dir: Optional[str] = None
self._home_dir: Optional[str] = None
self._workspace_dir: str | None = None
self._home_dir: str | None = None
if self._persistent:
sandbox = get_sandbox_dir() / "docker" / task_id
self._workspace_dir = str(sandbox / "workspace")
@@ -114,14 +121,19 @@ class DockerEnvironment(BaseEnvironment):
os.makedirs(self._workspace_dir, exist_ok=True)
os.makedirs(self._home_dir, exist_ok=True)
writable_args = [
"-v", f"{self._workspace_dir}:/workspace",
"-v", f"{self._home_dir}:/root",
"-v",
f"{self._workspace_dir}:/workspace",
"-v",
f"{self._home_dir}:/root",
]
else:
writable_args = [
"--tmpfs", "/workspace:rw,exec,size=10g",
"--tmpfs", "/home:rw,exec,size=1g",
"--tmpfs", "/root:rw,exec,size=1g",
"--tmpfs",
"/workspace:rw,exec,size=10g",
"--tmpfs",
"/home:rw,exec,size=1g",
"--tmpfs",
"/root:rw,exec,size=1g",
]
# All containers get security hardening (capabilities dropped, no privilege
@@ -129,7 +141,7 @@ class DockerEnvironment(BaseEnvironment):
# can install packages as needed.
# User-configured volume mounts (from config.yaml docker_volumes)
volume_args = []
for vol in (volumes or []):
for vol in volumes or []:
if not isinstance(vol, str):
logger.warning(f"Docker volume entry is not a string: {vol!r}")
continue
@@ -146,7 +158,9 @@ class DockerEnvironment(BaseEnvironment):
logger.info(f"Docker run_args: {all_run_args}")
self._inner = _Docker(
image=image, cwd=cwd, timeout=timeout,
image=image,
cwd=cwd,
timeout=timeout,
run_args=all_run_args,
)
self._container_id = self._inner.container_id
@@ -154,7 +168,7 @@ class DockerEnvironment(BaseEnvironment):
@staticmethod
def _storage_opt_supported() -> bool:
"""Check if Docker's storage driver supports --storage-opt size=.
Only overlay2 on XFS with pquota supports per-container disk quotas.
Ubuntu (and most distros) default to ext4, where this flag errors out.
"""
@@ -164,7 +178,9 @@ class DockerEnvironment(BaseEnvironment):
try:
result = subprocess.run(
["docker", "info", "--format", "{{.Driver}}"],
capture_output=True, text=True, timeout=10,
capture_output=True,
text=True,
timeout=10,
)
driver = result.stdout.strip().lower()
if driver != "overlay2":
@@ -174,14 +190,15 @@ class DockerEnvironment(BaseEnvironment):
# Probe by attempting a dry-ish run — the fastest reliable check.
probe = subprocess.run(
["docker", "create", "--storage-opt", "size=1m", "hello-world"],
capture_output=True, text=True, timeout=15,
capture_output=True,
text=True,
timeout=15,
)
if probe.returncode == 0:
# Clean up the created container
container_id = probe.stdout.strip()
if container_id:
subprocess.run(["docker", "rm", container_id],
capture_output=True, timeout=5)
subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5)
_storage_opt_ok = True
else:
_storage_opt_ok = False
@@ -190,9 +207,9 @@ class DockerEnvironment(BaseEnvironment):
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
return _storage_opt_ok
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
exec_command = self._prepare_command(command)
work_dir = cwd or self.cwd
effective_timeout = timeout or self.timeout
@@ -218,7 +235,8 @@ class DockerEnvironment(BaseEnvironment):
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
text=True,
)
@@ -269,6 +287,7 @@ class DockerEnvironment(BaseEnvironment):
if not self._persistent:
import shutil
for d in (self._workspace_dir, self._home_dir):
if d:
shutil.rmtree(d, ignore_errors=True)

View File

@@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment):
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
from tools.terminal_tool import _interrupt_event
work_dir = cwd or self.cwd or os.getcwd()
@@ -172,11 +172,7 @@ class LocalEnvironment(BaseEnvironment):
# Wrap with output fences so we can later extract the real
# command output and discard shell init/exit noise.
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
f" __hermes_rc=$?;"
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc"
)
# Ensure PATH always includes standard dirs — systemd services
# and some terminal multiplexers inherit a minimal PATH.
@@ -200,12 +196,14 @@ class LocalEnvironment(BaseEnvironment):
)
if stdin_data is not None:
def _write_stdin():
try:
proc.stdin.write(stdin_data)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []

View File

@@ -8,10 +8,9 @@ project files, and config changes survive across sessions.
import json
import logging
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
def _load_snapshots() -> Dict[str, str]:
def _load_snapshots() -> dict[str, str]:
"""Load snapshot ID mapping from disk."""
if _SNAPSHOT_STORE.exists():
try:
@@ -31,7 +30,7 @@ def _load_snapshots() -> Dict[str, str]:
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
def _save_snapshots(data: dict[str, str]) -> None:
"""Persist snapshot ID mapping to disk."""
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
@@ -52,7 +51,7 @@ class ModalEnvironment(BaseEnvironment):
image: str,
cwd: str = "~",
timeout: int = 60,
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
modal_sandbox_kwargs: dict[str, Any] | None = None,
persistent_filesystem: bool = True,
task_id: str = "default",
):
@@ -61,6 +60,7 @@ class ModalEnvironment(BaseEnvironment):
if not ModalEnvironment._patches_applied:
try:
from environments.patches import apply_patches
apply_patches()
except ImportError:
pass
@@ -79,6 +79,7 @@ class ModalEnvironment(BaseEnvironment):
if snapshot_id:
try:
import modal
restored_image = modal.Image.from_id(snapshot_id)
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
except Exception as e:
@@ -88,6 +89,7 @@ class ModalEnvironment(BaseEnvironment):
effective_image = restored_image if restored_image else image
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
self._inner = SwerexModalEnvironment(
image=effective_image,
cwd=cwd,
@@ -97,9 +99,9 @@ class ModalEnvironment(BaseEnvironment):
modal_sandbox_kwargs=sandbox_kwargs,
)
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
if stdin_data is not None:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
while marker in stdin_data:
@@ -139,29 +141,29 @@ class ModalEnvironment(BaseEnvironment):
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
if self._persistent:
try:
sandbox = getattr(self._inner, 'deployment', None)
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
sandbox = getattr(self._inner, "deployment", None)
sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None
if sandbox:
import asyncio
async def _snapshot():
img = await sandbox.snapshot_filesystem.aio()
return img.object_id
try:
snapshot_id = asyncio.run(_snapshot())
except RuntimeError:
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
snapshot_id = pool.submit(
asyncio.run, _snapshot()
).result(timeout=60)
snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60)
snapshots = _load_snapshots()
snapshots[self._task_id] = snapshot_id
_save_snapshots(snapshots)
logger.info("Modal: saved filesystem snapshot %s for task %s",
snapshot_id[:20], self._task_id)
logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id)
except Exception as e:
logger.warning("Modal: filesystem snapshot failed: %s", e)
if hasattr(self._inner, 'stop'):
if hasattr(self._inner, "stop"):
self._inner.stop()

View File

@@ -10,11 +10,9 @@ import logging
import os
import shutil
import subprocess
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
@@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
def _load_snapshots() -> Dict[str, str]:
def _load_snapshots() -> dict[str, str]:
if _SNAPSHOT_STORE.exists():
try:
return json.loads(_SNAPSHOT_STORE.read_text())
@@ -33,7 +31,7 @@ def _load_snapshots() -> Dict[str, str]:
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
def _save_snapshots(data: dict[str, str]) -> None:
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
@@ -42,6 +40,7 @@ def _save_snapshots(data: Dict[str, str]) -> None:
# Singularity helpers (scratch dir, SIF cache, SIF building)
# -------------------------------------------------------------------------
def _get_scratch_dir() -> Path:
"""Get the best directory for Singularity sandboxes.
@@ -58,6 +57,7 @@ def _get_scratch_dir() -> Path:
return scratch_path
from tools.environments.base import get_sandbox_dir
sandbox = get_sandbox_dir() / "singularity"
scratch = Path("/scratch")
@@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
Returns the path unchanged if it's already a .sif file.
For docker:// URLs, checks the cache and builds if needed.
"""
if image.endswith('.sif') and Path(image).exists():
if image.endswith(".sif") and Path(image).exists():
return image
if not image.startswith('docker://'):
if not image.startswith("docker://"):
return image
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-")
cache_dir = _get_apptainer_cache_dir()
sif_path = cache_dir / f"{image_name}.sif"
@@ -123,7 +123,10 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
try:
result = subprocess.run(
[executable, "build", str(sif_path), image],
capture_output=True, text=True, timeout=600, env=env,
capture_output=True,
text=True,
timeout=600,
env=env,
)
if result.returncode != 0:
logger.warning("SIF build failed, falling back to docker:// URL")
@@ -145,6 +148,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
# SingularityEnvironment
# -------------------------------------------------------------------------
class SingularityEnvironment(BaseEnvironment):
"""Hardened Singularity/Apptainer container with resource limits and persistence.
@@ -174,7 +178,7 @@ class SingularityEnvironment(BaseEnvironment):
self._instance_started = False
self._persistent = persistent_filesystem
self._task_id = task_id
self._overlay_dir: Optional[Path] = None
self._overlay_dir: Path | None = None
# Resource limits
self._cpu = cpu
@@ -215,14 +219,13 @@ class SingularityEnvironment(BaseEnvironment):
if result.returncode != 0:
raise RuntimeError(f"Failed to start instance: {result.stderr}")
self._instance_started = True
logger.info("Singularity instance %s started (persistent=%s)",
self.instance_id, self._persistent)
logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent)
except subprocess.TimeoutExpired:
raise RuntimeError("Instance start timed out")
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
if not self._instance_started:
return {"output": "Instance not started", "returncode": -1}
@@ -235,16 +238,16 @@ class SingularityEnvironment(BaseEnvironment):
exec_command = f"cd {work_dir} && {exec_command}"
work_dir = "/tmp"
cmd = [self.executable, "exec", "--pwd", work_dir,
f"instance://{self.instance_id}",
"bash", "-c", exec_command]
cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command]
try:
import time as _time
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
text=True,
)
@@ -295,7 +298,9 @@ class SingularityEnvironment(BaseEnvironment):
try:
subprocess.run(
[self.executable, "instance", "stop", self.instance_id],
capture_output=True, text=True, timeout=30,
capture_output=True,
text=True,
timeout=30,
)
logger.info("Singularity instance %s stopped", self.instance_id)
except Exception as e:

View File

@@ -24,8 +24,7 @@ class SSHEnvironment(BaseEnvironment):
and a remote kill is attempted over the ControlMaster socket.
"""
def __init__(self, host: str, user: str, cwd: str = "~",
timeout: int = 60, port: int = 22, key_path: str = ""):
def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""):
super().__init__(cwd=cwd, timeout=timeout)
self.host = host
self.user = user
@@ -65,12 +64,12 @@ class SSHEnvironment(BaseEnvironment):
except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
def execute(
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
) -> dict:
work_dir = cwd or self.cwd
exec_command = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}'
wrapped = f"cd {work_dir} && {exec_command}"
effective_timeout = timeout or self.timeout
cmd = self._build_ssh_command()
@@ -136,8 +135,7 @@ class SSHEnvironment(BaseEnvironment):
def cleanup(self):
if self.control_socket.exists():
try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
"-O", "exit", f"{self.user}@{self.host}"]
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"]
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass

File diff suppressed because it is too large Load Diff

View File

@@ -3,11 +3,10 @@
import json
import logging
import os
import threading
from typing import Optional
from tools.file_operations import ShellFileOperations
from agent.redact import redact_sensitive_text
from tools.file_operations import ShellFileOperations
logger = logging.getLogger(__name__)
@@ -25,14 +24,19 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
Thread-safe: uses the same per-task creation locks as terminal_tool to
prevent duplicate sandbox creation from concurrent tool calls.
"""
from tools.terminal_tool import (
_active_environments, _env_lock, _create_environment,
_get_env_config, _last_activity, _start_cleanup_thread,
_check_disk_usage_warning,
_creation_locks, _creation_locks_lock,
)
import time
from tools.terminal_tool import (
_active_environments,
_create_environment,
_creation_locks,
_creation_locks_lock,
_env_lock,
_get_env_config,
_last_activity,
_start_cleanup_thread,
)
# Fast path: check cache -- but also verify the underlying environment
# is still alive (it may have been killed by the cleanup thread).
with _file_ops_lock:
@@ -143,17 +147,23 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
result = file_ops.write_file(path, content)
return json.dumps(result.to_dict(), ensure_ascii=False)
except Exception as e:
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
return json.dumps({"error": str(e)}, ensure_ascii=False)
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
new_string: str = None, replace_all: bool = False, patch: str = None,
task_id: str = "default") -> str:
def patch_tool(
mode: str = "replace",
path: str = None,
old_string: str = None,
new_string: str = None,
replace_all: bool = False,
patch: str = None,
task_id: str = "default",
) -> str:
"""Patch a file using replace mode or V4A patch format."""
try:
file_ops = _get_file_ops(task_id)
if mode == "replace":
if not path:
return json.dumps({"error": "path required"})
@@ -166,7 +176,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
result = file_ops.patch_v4a(patch)
else:
return json.dumps({"error": f"Unknown mode: {mode}"})
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when old_string not found — saves iterations where the agent
@@ -178,20 +188,33 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
return json.dumps({"error": str(e)}, ensure_ascii=False)
def search_tool(pattern: str, target: str = "content", path: str = ".",
file_glob: str = None, limit: int = 50, offset: int = 0,
output_mode: str = "content", context: int = 0,
task_id: str = "default") -> str:
def search_tool(
pattern: str,
target: str = "content",
path: str = ".",
file_glob: str = None,
limit: int = 50,
offset: int = 0,
output_mode: str = "content",
context: int = 0,
task_id: str = "default",
) -> str:
"""Search for content or files."""
try:
file_ops = _get_file_ops(task_id)
result = file_ops.search(
pattern=pattern, path=path, target=target, file_glob=file_glob,
limit=limit, offset=offset, output_mode=output_mode, context=context
pattern=pattern,
path=path,
target=target,
file_glob=file_glob,
limit=limit,
offset=offset,
output_mode=output_mode,
context=context,
)
if hasattr(result, 'matches'):
if hasattr(result, "matches"):
for m in result.matches:
if hasattr(m, 'content') and m.content:
if hasattr(m, "content") and m.content:
m.content = redact_sensitive_text(m.content)
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
@@ -209,7 +232,7 @@ FILE_TOOLS = [
{"name": "read_file", "function": read_file_tool},
{"name": "write_file", "function": write_file_tool},
{"name": "patch", "function": patch_tool},
{"name": "search_files", "function": search_tool}
{"name": "search_files", "function": search_tool},
]
@@ -227,8 +250,10 @@ from tools.registry import registry
def _check_file_reqs():
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
from tools import check_file_requirements
return check_file_requirements()
READ_FILE_SCHEMA = {
"name": "read_file",
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
@@ -236,11 +261,21 @@ READ_FILE_SCHEMA = {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, default: 1)",
"default": 1,
"minimum": 1,
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (default: 500, max: 2000)",
"default": 500,
"maximum": 2000,
},
},
"required": ["path"]
}
"required": ["path"],
},
}
WRITE_FILE_SCHEMA = {
@@ -249,11 +284,14 @@ WRITE_FILE_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
"content": {"type": "string", "description": "Complete content to write to the file"}
"path": {
"type": "string",
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)",
},
"content": {"type": "string", "description": "Complete content to write to the file"},
},
"required": ["path", "content"]
}
"required": ["path", "content"],
},
}
PATCH_SCHEMA = {
@@ -262,15 +300,33 @@ PATCH_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
"mode": {
"type": "string",
"enum": ["replace", "patch"],
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
"default": "replace",
},
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
"old_string": {
"type": "string",
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.",
},
"new_string": {
"type": "string",
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.",
},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
"default": False,
},
"patch": {
"type": "string",
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch",
},
},
"required": ["mode"]
}
"required": ["mode"],
},
}
SEARCH_FILES_SCHEMA = {
@@ -279,23 +335,57 @@ SEARCH_FILES_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
"pattern": {
"type": "string",
"description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search",
},
"target": {
"type": "string",
"enum": ["content", "files"],
"description": "'content' searches inside file contents, 'files' searches for files by name",
"default": "content",
},
"path": {
"type": "string",
"description": "Directory or file to search in (default: current working directory)",
"default": ".",
},
"file_glob": {
"type": "string",
"description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)",
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return (default: 50)",
"default": 50,
},
"offset": {
"type": "integer",
"description": "Skip first N results for pagination (default: 0)",
"default": 0,
},
"output_mode": {
"type": "string",
"enum": ["content", "files_only", "count"],
"description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
"default": "content",
},
"context": {
"type": "integer",
"description": "Number of context lines before and after each match (grep mode only)",
"default": 0,
},
},
"required": ["pattern"]
}
"required": ["pattern"],
},
}
def _handle_read_file(args, **kw):
tid = kw.get("task_id") or "default"
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
return read_file_tool(
path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid
)
def _handle_write_file(args, **kw):
@@ -306,9 +396,14 @@ def _handle_write_file(args, **kw):
def _handle_patch(args, **kw):
tid = kw.get("task_id") or "default"
return patch_tool(
mode=args.get("mode", "replace"), path=args.get("path"),
old_string=args.get("old_string"), new_string=args.get("new_string"),
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
mode=args.get("mode", "replace"),
path=args.get("path"),
old_string=args.get("old_string"),
new_string=args.get("new_string"),
replace_all=args.get("replace_all", False),
patch=args.get("patch"),
task_id=tid,
)
def _handle_search_files(args, **kw):
@@ -317,12 +412,29 @@ def _handle_search_files(args, **kw):
raw_target = args.get("target", "content")
target = target_map.get(raw_target, raw_target)
return search_tool(
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
pattern=args.get("pattern", ""),
target=target,
path=args.get("path", "."),
file_glob=args.get("file_glob"),
limit=args.get("limit", 50),
offset=args.get("offset", 0),
output_mode=args.get("output_mode", "content"),
context=args.get("context", 0),
task_id=tid,
)
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs)
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs)
registry.register(
name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs
)
registry.register(
name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs
)
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs)
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs)
registry.register(
name="search_files",
toolset="file",
schema=SEARCH_FILES_SCHEMA,
handler=_handle_search_files,
check_fn=_check_file_reqs,
)

View File

@@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode):
Usage:
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, match_count, error = fuzzy_find_and_replace(
content="def foo():\\n pass",
old_string="def foo():",
@@ -29,21 +29,22 @@ Usage:
"""
import re
from typing import Tuple, Optional, List, Callable
from collections.abc import Callable
from difflib import SequenceMatcher
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
def fuzzy_find_and_replace(
content: str, old_string: str, new_string: str, replace_all: bool = False
) -> tuple[str, int, str | None]:
"""
Find and replace text using a chain of increasingly fuzzy matching strategies.
Args:
content: The file content to search in
old_string: The text to find
new_string: The replacement text
replace_all: If True, replace all occurrences; if False, require uniqueness
Returns:
Tuple of (new_content, match_count, error_message)
- If successful: (modified_content, number_of_replacements, None)
@@ -51,12 +52,12 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
"""
if not old_string:
return content, 0, "old_string cannot be empty"
if old_string == new_string:
return content, 0, "old_string and new_string are identical"
# Try each matching strategy in order
strategies: List[Tuple[str, Callable]] = [
strategies: list[tuple[str, Callable]] = [
("exact", _strategy_exact),
("line_trimmed", _strategy_line_trimmed),
("whitespace_normalized", _strategy_whitespace_normalized),
@@ -66,46 +67,50 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
("block_anchor", _strategy_block_anchor),
("context_aware", _strategy_context_aware),
]
for strategy_name, strategy_fn in strategies:
matches = strategy_fn(content, old_string)
if matches:
# Found matches with this strategy
if len(matches) > 1 and not replace_all:
return content, 0, (
f"Found {len(matches)} matches for old_string. "
f"Provide more context to make it unique, or use replace_all=True."
return (
content,
0,
(
f"Found {len(matches)} matches for old_string. "
f"Provide more context to make it unique, or use replace_all=True."
),
)
# Perform replacement
new_content = _apply_replacements(content, matches, new_string)
return new_content, len(matches), None
# No strategy found a match
return content, 0, "Could not find a match for old_string in the file"
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str:
"""
Apply replacements at the given positions.
Args:
content: Original content
matches: List of (start, end) positions to replace
new_string: Replacement text
Returns:
Content with replacements applied
"""
# Sort matches by position (descending) to replace from end to start
# This preserves positions of earlier matches
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
result = content
for start, end in sorted_matches:
result = result[:start] + new_string + result[end:]
return result
@@ -113,7 +118,8 @@ def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string
# Matching Strategies
# =============================================================================
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
"""Strategy 1: Exact string match."""
matches = []
start = 0
@@ -126,206 +132,201 @@ def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
return matches
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 2: Match with line-by-line whitespace trimming.
Strips leading/trailing whitespace from each line before matching.
"""
# Normalize pattern and content by trimming each line
pattern_lines = [line.strip() for line in pattern.split('\n')]
pattern_normalized = '\n'.join(pattern_lines)
content_lines = content.split('\n')
pattern_lines = [line.strip() for line in pattern.split("\n")]
pattern_normalized = "\n".join(pattern_lines)
content_lines = content.split("\n")
content_normalized_lines = [line.strip() for line in content_lines]
# Build mapping from normalized positions back to original positions
return _find_normalized_matches(
content, content_lines, content_normalized_lines,
pattern, pattern_normalized
)
return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized)
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 3: Collapse multiple whitespace to single space.
"""
def normalize(s):
# Collapse multiple spaces/tabs to single space, preserve newlines
return re.sub(r'[ \t]+', ' ', s)
return re.sub(r"[ \t]+", " ", s)
pattern_normalized = normalize(pattern)
content_normalized = normalize(content)
# Find in normalized, map back to original
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
if not matches_in_normalized:
return []
# Map positions back to original content
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 4: Ignore indentation differences entirely.
Strips all leading whitespace from lines before matching.
"""
def strip_indent(s):
return '\n'.join(line.lstrip() for line in s.split('\n'))
return "\n".join(line.lstrip() for line in s.split("\n"))
pattern_stripped = strip_indent(pattern)
content_lines = content.split('\n')
content_lines = content.split("\n")
content_stripped_lines = [line.lstrip() for line in content_lines]
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
return _find_normalized_matches(
content, content_lines, content_stripped_lines,
pattern, '\n'.join(pattern_lines)
)
pattern_lines = [line.lstrip() for line in pattern.split("\n")]
return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines))
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 5: Convert escape sequences to actual characters.
Handles \\n -> newline, \\t -> tab, etc.
"""
def unescape(s):
# Convert common escape sequences
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
pattern_unescaped = unescape(pattern)
if pattern_unescaped == pattern:
# No escapes to convert, skip this strategy
return []
return _strategy_exact(content, pattern_unescaped)
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 6: Trim whitespace from first and last lines only.
Useful when the pattern boundaries have whitespace differences.
"""
pattern_lines = pattern.split('\n')
pattern_lines = pattern.split("\n")
if not pattern_lines:
return []
# Trim only first and last lines
pattern_lines[0] = pattern_lines[0].strip()
if len(pattern_lines) > 1:
pattern_lines[-1] = pattern_lines[-1].strip()
modified_pattern = '\n'.join(pattern_lines)
content_lines = content.split('\n')
modified_pattern = "\n".join(pattern_lines)
content_lines = content.split("\n")
# Search through content for matching block
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
block_lines = content_lines[i : i + pattern_line_count]
# Trim first and last of this block
check_lines = block_lines.copy()
check_lines[0] = check_lines[0].strip()
if len(check_lines) > 1:
check_lines[-1] = check_lines[-1].strip()
if '\n'.join(check_lines) == modified_pattern:
if "\n".join(check_lines) == modified_pattern:
# Found match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 7: Match by anchoring on first and last lines.
If first and last lines match exactly, accept middle with 70% similarity.
"""
pattern_lines = pattern.split('\n')
pattern_lines = pattern.split("\n")
if len(pattern_lines) < 2:
return [] # Need at least 2 lines for anchoring
first_line = pattern_lines[0].strip()
last_line = pattern_lines[-1].strip()
content_lines = content.split('\n')
content_lines = content.split("\n")
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
# Check if first and last lines match
if (content_lines[i].strip() == first_line and
content_lines[i + pattern_line_count - 1].strip() == last_line):
if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line:
# Check middle similarity
if pattern_line_count <= 2:
# Only first and last, they match
similarity = 1.0
else:
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
pattern_middle = '\n'.join(pattern_lines[1:-1])
content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1])
pattern_middle = "\n".join(pattern_lines[1:-1])
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
if similarity >= 0.70:
# Calculate positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]:
"""
Strategy 8: Line-by-line similarity with 50% threshold.
Finds blocks where at least 50% of lines have high similarity.
"""
pattern_lines = pattern.split('\n')
content_lines = content.split('\n')
pattern_lines = pattern.split("\n")
content_lines = content.split("\n")
if not pattern_lines:
return []
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
block_lines = content_lines[i : i + pattern_line_count]
# Calculate line-by-line similarity
high_similarity_count = 0
for p_line, c_line in zip(pattern_lines, block_lines):
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
if sim >= 0.80:
high_similarity_count += 1
# Need at least 50% of lines to have high similarity
if high_similarity_count >= len(pattern_lines) * 0.5:
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
@@ -333,74 +334,76 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
# Helper Functions
# =============================================================================
def _find_normalized_matches(content: str, content_lines: List[str],
content_normalized_lines: List[str],
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
def _find_normalized_matches(
content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str
) -> list[tuple[int, int]]:
"""
Find matches in normalized content and map back to original positions.
Args:
content: Original content string
content_lines: Original content split by lines
content_normalized_lines: Normalized content lines
pattern: Original pattern
pattern_normalized: Normalized pattern
Returns:
List of (start, end) positions in the original content
"""
pattern_norm_lines = pattern_normalized.split('\n')
pattern_norm_lines = pattern_normalized.split("\n")
num_pattern_lines = len(pattern_norm_lines)
matches = []
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
# Check if this block matches
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
block = "\n".join(content_normalized_lines[i : i + num_pattern_lines])
if block == pattern_normalized:
# Found a match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1
# Handle case where end is past content
if end_pos >= len(content):
end_pos = len(content)
matches.append((start_pos, end_pos))
return matches
def _map_normalized_positions(original: str, normalized: str,
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
def _map_normalized_positions(
original: str, normalized: str, normalized_matches: list[tuple[int, int]]
) -> list[tuple[int, int]]:
"""
Map positions from normalized string back to original.
This is a best-effort mapping that works for whitespace normalization.
"""
if not normalized_matches:
return []
# Build character mapping from normalized to original
orig_to_norm = [] # orig_to_norm[i] = position in normalized
orig_idx = 0
norm_idx = 0
while orig_idx < len(original) and norm_idx < len(normalized):
if original[orig_idx] == normalized[norm_idx]:
orig_to_norm.append(norm_idx)
orig_idx += 1
norm_idx += 1
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
elif original[orig_idx] in " \t" and normalized[norm_idx] == " ":
# Original has space/tab, normalized collapsed to space
orig_to_norm.append(norm_idx)
orig_idx += 1
# Don't advance norm_idx yet - wait until all whitespace consumed
if orig_idx < len(original) and original[orig_idx] not in ' \t':
if orig_idx < len(original) and original[orig_idx] not in " \t":
norm_idx += 1
elif original[orig_idx] in ' \t':
elif original[orig_idx] in " \t":
# Extra whitespace in original
orig_to_norm.append(norm_idx)
orig_idx += 1
@@ -408,21 +411,21 @@ def _map_normalized_positions(original: str, normalized: str,
# Mismatch - shouldn't happen with our normalization
orig_to_norm.append(norm_idx)
orig_idx += 1
# Fill remaining
while orig_idx < len(original):
orig_to_norm.append(len(normalized))
orig_idx += 1
# Reverse mapping: for each normalized position, find original range
norm_to_orig_start = {}
norm_to_orig_end = {}
for orig_pos, norm_pos in enumerate(orig_to_norm):
if norm_pos not in norm_to_orig_start:
norm_to_orig_start[norm_pos] = orig_pos
norm_to_orig_end[norm_pos] = orig_pos
# Map matches
original_matches = []
for norm_start, norm_end in normalized_matches:
@@ -432,17 +435,17 @@ def _map_normalized_positions(original: str, normalized: str,
else:
# Find nearest
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
# Find original end
if norm_end - 1 in norm_to_orig_end:
orig_end = norm_to_orig_end[norm_end - 1] + 1
else:
orig_end = orig_start + (norm_end - norm_start)
# Expand to include trailing whitespace that was normalized
while orig_end < len(original) and original[orig_end] in ' \t':
while orig_end < len(original) and original[orig_end] in " \t":
orig_end += 1
original_matches.append((orig_start, min(orig_end, len(original))))
return original_matches

View File

@@ -15,7 +15,7 @@ import json
import logging
import os
import re
from typing import Any, Dict, Optional
from typing import Any
logger = logging.getLogger(__name__)
@@ -35,23 +35,26 @@ def _get_config():
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
)
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
})
_BLOCKED_DOMAINS = frozenset(
{
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
}
)
def _get_headers(token: str = "") -> Dict[str, str]:
def _get_headers(token: str = "") -> dict[str, str]:
"""Return authorization headers for HA REST API."""
if not token:
_, token = _get_config()
@@ -65,11 +68,12 @@ def _get_headers(token: str = "") -> Dict[str, str]:
# Async helpers (called from sync handlers via run_until_complete)
# ---------------------------------------------------------------------------
def _filter_and_summarize(
states: list,
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
domain: str | None = None,
area: str | None = None,
) -> dict[str, Any]:
"""Filter raw HA states by domain/area and return a compact summary."""
if domain:
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
@@ -77,26 +81,29 @@ def _filter_and_summarize(
if area:
area_lower = area.lower()
states = [
s for s in states
s
for s in states
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
]
entities = []
for s in states:
entities.append({
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
})
entities.append(
{
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
}
)
return {"count": len(entities), "entities": entities}
async def _async_list_entities(
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
domain: str | None = None,
area: str | None = None,
) -> dict[str, Any]:
"""Fetch entity states from HA and optionally filter by domain/area."""
import aiohttp
@@ -110,7 +117,7 @@ async def _async_list_entities(
return _filter_and_summarize(states, domain, area)
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
async def _async_get_state(entity_id: str) -> dict[str, Any]:
"""Fetch detailed state of a single entity."""
import aiohttp
@@ -131,11 +138,11 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]:
def _build_service_payload(
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
entity_id: str | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build the JSON payload for a HA service call."""
payload: Dict[str, Any] = {}
payload: dict[str, Any] = {}
if data:
payload.update(data)
# entity_id parameter takes precedence over data["entity_id"]
@@ -148,15 +155,17 @@ def _parse_service_response(
domain: str,
service: str,
result: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Parse HA service call response into a structured result."""
affected = []
if isinstance(result, list):
for s in result:
affected.append({
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
})
affected.append(
{
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
}
)
return {
"success": True,
@@ -168,9 +177,9 @@ def _parse_service_response(
async def _async_call_service(
domain: str,
service: str,
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
entity_id: str | None = None,
data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Call a Home Assistant service."""
import aiohttp
@@ -178,15 +187,17 @@ async def _async_call_service(
url = f"{hass_url}/api/services/{domain}/{service}"
payload = _build_service_payload(entity_id, data)
async with aiohttp.ClientSession() as session:
async with session.post(
async with (
aiohttp.ClientSession() as session,
session.post(
url,
headers=_get_headers(hass_token),
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
resp.raise_for_status()
result = await resp.json()
) as resp,
):
resp.raise_for_status()
result = await resp.json()
return _parse_service_response(domain, service, result)
@@ -195,6 +206,7 @@ async def _async_call_service(
# Sync wrappers (handler signature: (args, **kw) -> str)
# ---------------------------------------------------------------------------
def _run_async(coro):
"""Run an async coroutine from a sync handler."""
try:
@@ -205,6 +217,7 @@ def _run_async(coro):
if loop and loop.is_running():
# Already inside an event loop -- create a new thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, coro)
return future.result(timeout=30)
@@ -247,10 +260,12 @@ def _handle_call_service(args: dict, **kw) -> str:
return json.dumps({"error": "Missing required parameters: domain and service"})
if domain in _BLOCKED_DOMAINS:
return json.dumps({
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
})
return json.dumps(
{
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
}
)
entity_id = args.get("entity_id")
if entity_id and not _ENTITY_ID_RE.match(entity_id):
@@ -269,7 +284,8 @@ def _handle_call_service(args: dict, **kw) -> str:
# List services
# ---------------------------------------------------------------------------
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
"""Fetch available services from HA and optionally filter by domain."""
import aiohttp
@@ -290,13 +306,10 @@ async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
d = svc_domain.get("domain", "")
domain_services = {}
for svc_name, svc_info in svc_domain.get("services", {}).items():
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")}
fields = svc_info.get("fields", {})
if fields:
svc_entry["fields"] = {
k: v.get("description", "") for k, v in fields.items()
if isinstance(v, dict)
}
svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)}
domain_services[svc_name] = svc_entry
result.append({"domain": d, "services": domain_services})
@@ -318,6 +331,7 @@ def _handle_list_services(args: dict, **kw) -> str:
# Availability check
# ---------------------------------------------------------------------------
def _check_ha_available() -> bool:
"""Tool is only available when HASS_TOKEN is set."""
return bool(os.getenv("HASS_TOKEN"))
@@ -369,8 +383,7 @@ HA_GET_STATE_SCHEMA = {
"entity_id": {
"type": "string",
"description": (
"The entity ID to query (e.g. 'light.living_room', "
"'climate.thermostat', 'sensor.temperature')."
"The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')."
),
},
},
@@ -392,8 +405,7 @@ HA_LIST_SERVICES_SCHEMA = {
"domain": {
"type": "string",
"description": (
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
"Omit to list services for all domains."
"Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains."
),
},
},
@@ -428,8 +440,7 @@ HA_CALL_SERVICE_SCHEMA = {
"entity_id": {
"type": "string",
"description": (
"Target entity ID (e.g. 'light.living_room'). "
"Some services (like scene.turn_on) may not need this."
"Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this."
),
},
"data": {

View File

@@ -65,6 +65,7 @@ HONCHO_TOOL_SCHEMA = {
# ── Tool handler ──
def _handle_query_user_context(args: dict, **kw) -> str:
"""Execute the Honcho context query."""
query = args.get("query", "")
@@ -84,6 +85,7 @@ def _handle_query_user_context(args: dict, **kw) -> str:
# ── Availability check ──
def _check_honcho_available() -> bool:
"""Tool is only available when Honcho is active."""
return _session_manager is not None and _session_key is not None

View File

@@ -2,7 +2,7 @@
"""
Image Generation Tools Module
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
Available tools:
@@ -19,7 +19,7 @@ Features:
Usage:
from image_generation_tool import image_generate_tool
import asyncio
# Generate and automatically upscale an image
result = await image_generate_tool(
prompt="A serene mountain landscape with cherry blossoms",
@@ -28,12 +28,14 @@ Usage:
)
"""
import datetime
import json
import logging
import os
import datetime
from typing import Dict, Any, Optional, Union
from typing import Any
import fal_client
from tools.debug_helpers import DebugSession
logger = logging.getLogger(__name__)
@@ -51,11 +53,7 @@ ENABLE_SAFETY_CHECKER = False
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
# Aspect ratio mapping - simplified choices for model to select
ASPECT_RATIO_MAP = {
"landscape": "landscape_16_9",
"square": "square_hd",
"portrait": "portrait_16_9"
}
ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"}
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
# Configuration for automatic upscaling
@@ -70,9 +68,7 @@ UPSCALER_GUIDANCE_SCALE = 4
UPSCALER_NUM_INFERENCE_STEPS = 18
# Valid parameter values for validation based on FLUX 2 Pro documentation
VALID_IMAGE_SIZES = [
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
]
VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
@@ -80,16 +76,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
def _validate_parameters(
image_size: Union[str, Dict[str, int]],
image_size: str | dict[str, int],
num_inference_steps: int,
guidance_scale: float,
num_images: int,
output_format: str,
acceleration: str = "none"
) -> Dict[str, Any]:
acceleration: str = "none",
) -> dict[str, Any]:
"""
Validate and normalize image generation parameters for FLUX 2 Pro model.
Args:
image_size: Either a preset string or custom size dict
num_inference_steps: Number of inference steps
@@ -97,15 +93,15 @@ def _validate_parameters(
num_images: Number of images to generate
output_format: Output format for images
acceleration: Acceleration mode for generation speed
Returns:
Dict[str, Any]: Validated and normalized parameters
Raises:
ValueError: If any parameter is invalid
"""
validated = {}
# Validate image_size
if isinstance(image_size, str):
if image_size not in VALID_IMAGE_SIZES:
@@ -123,52 +119,52 @@ def _validate_parameters(
validated["image_size"] = image_size
else:
raise ValueError("image_size must be either a preset string or a dict with width/height")
# Validate num_inference_steps
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
raise ValueError("num_inference_steps must be an integer between 1 and 100")
validated["num_inference_steps"] = num_inference_steps
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
validated["guidance_scale"] = float(guidance_scale)
# Validate num_images
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
raise ValueError("num_images must be an integer between 1 and 4")
validated["num_images"] = num_images
# Validate output_format
if output_format not in VALID_OUTPUT_FORMATS:
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
validated["output_format"] = output_format
# Validate acceleration
if acceleration not in VALID_ACCELERATION_MODES:
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
validated["acceleration"] = acceleration
return validated
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
"""
Upscale an image using FAL.ai's Clarity Upscaler.
Uses the synchronous fal_client API to avoid event loop lifecycle issues
when called from threaded contexts (e.g. gateway thread pool).
Args:
image_url (str): URL of the image to upscale
original_prompt (str): Original prompt used to generate the image
Returns:
Dict[str, Any]: Upscaled image data or None if upscaling fails
"""
try:
logger.info("Upscaling image with Clarity Upscaler...")
# Prepare arguments for upscaler
upscaler_arguments = {
"image_url": image_url,
@@ -179,35 +175,36 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
"resemblance": UPSCALER_RESEMBLANCE,
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
}
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
# The async API (submit_async) caches a global httpx.AsyncClient via
# @cached_property, which breaks when asyncio.run() destroys the loop
# between calls (gateway thread-pool pattern).
handler = fal_client.submit(
UPSCALER_MODEL,
arguments=upscaler_arguments
)
handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments)
# Get the upscaled result (sync — blocks until done)
result = handler.get()
if result and "image" in result:
upscaled_image = result["image"]
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
logger.info(
"Image upscaled successfully to %sx%s",
upscaled_image.get("width", "unknown"),
upscaled_image.get("height", "unknown"),
)
return {
"url": upscaled_image["url"],
"width": upscaled_image.get("width", 0),
"height": upscaled_image.get("height", 0),
"upscaled": True,
"upscale_factor": UPSCALER_FACTOR
"upscale_factor": UPSCALER_FACTOR,
}
else:
logger.error("Upscaler returned invalid response")
return None
except Exception as e:
logger.error("Error upscaling image: %s", e)
return None
@@ -220,16 +217,16 @@ def image_generate_tool(
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
num_images: int = DEFAULT_NUM_IMAGES,
output_format: str = DEFAULT_OUTPUT_FORMAT,
seed: Optional[int] = None
seed: int | None = None,
) -> str:
"""
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
when asyncio.run() destroys and recreates event loops between calls, which
happens in the gateway's thread-pool pattern.
Args:
prompt (str): The text prompt describing the desired image
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
@@ -238,7 +235,7 @@ def image_generate_tool(
num_images (int): Number of images to generate (1-4, default: 1)
output_format (str): Image format "jpeg" or "png" (default: "png")
seed (Optional[int]): Random seed for reproducible results (optional)
Returns:
str: JSON string containing minimal generation results:
{
@@ -252,7 +249,7 @@ def image_generate_tool(
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
debug_call_data = {
"parameters": {
"prompt": prompt,
@@ -262,32 +259,32 @@ def image_generate_tool(
"guidance_scale": guidance_scale,
"num_images": num_images,
"output_format": output_format,
"seed": seed
"seed": seed,
},
"error": None,
"success": False,
"images_generated": 0,
"generation_time": 0
"generation_time": 0,
}
start_time = datetime.datetime.now()
try:
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
# Validate prompt
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
raise ValueError("Prompt is required and must be a non-empty string")
# Check API key availability
if not os.getenv("FAL_KEY"):
raise ValueError("FAL_KEY environment variable not set")
# Validate other parameters
validated_params = _validate_parameters(
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
)
# Prepare arguments for FAL.ai FLUX 2 Pro API
arguments = {
"prompt": prompt.strip(),
@@ -298,51 +295,44 @@ def image_generate_tool(
"output_format": validated_params["output_format"],
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
"safety_tolerance": SAFETY_TOLERANCE,
"sync_mode": True # Use sync mode for immediate results
"sync_mode": True, # Use sync mode for immediate results
}
# Add seed if provided
if seed is not None and isinstance(seed, int):
arguments["seed"] = seed
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
logger.info(" Model: %s", DEFAULT_MODEL)
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
logger.info(" Steps: %s", validated_params['num_inference_steps'])
logger.info(" Guidance: %s", validated_params['guidance_scale'])
logger.info(" Steps: %s", validated_params["num_inference_steps"])
logger.info(" Guidance: %s", validated_params["guidance_scale"])
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
handler = fal_client.submit(
DEFAULT_MODEL,
arguments=arguments
)
handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments)
# Get the result (sync — blocks until done)
result = handler.get()
generation_time = (datetime.datetime.now() - start_time).total_seconds()
# Process the response
if not result or "images" not in result:
raise ValueError("Invalid response from FAL.ai API - no images returned")
images = result.get("images", [])
if not images:
raise ValueError("No images were generated")
# Format image data and upscale images
formatted_images = []
for img in images:
if isinstance(img, dict) and "url" in img:
original_image = {
"url": img["url"],
"width": img.get("width", 0),
"height": img.get("height", 0)
}
original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)}
# Attempt to upscale the image
upscaled_image = _upscale_image(img["url"], prompt.strip())
if upscaled_image:
# Use upscaled image if successful
formatted_images.append(upscaled_image)
@@ -351,52 +341,48 @@ def image_generate_tool(
logger.warning("Using original image as fallback")
original_image["upscaled"] = False
formatted_images.append(original_image)
if not formatted_images:
raise ValueError("No valid image URLs returned from API")
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
logger.info(
"Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count
)
# Prepare successful response - minimal format
response_data = {
"success": True,
"image": formatted_images[0]["url"] if formatted_images else None
}
response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None}
debug_call_data["success"] = True
debug_call_data["images_generated"] = len(formatted_images)
debug_call_data["generation_time"] = generation_time
# Log debug information
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
except Exception as e:
generation_time = (datetime.datetime.now() - start_time).total_seconds()
error_msg = f"Error generating image: {str(e)}"
logger.error("%s", error_msg)
# Prepare error response - minimal format
response_data = {
"success": False,
"image": None
}
response_data = {"success": False, "image": None}
debug_call_data["error"] = error_msg
debug_call_data["generation_time"] = generation_time
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
def check_fal_api_key() -> bool:
"""
Check if the FAL.ai API key is available in environment variables.
Returns:
bool: True if API key is set, False otherwise
"""
@@ -406,7 +392,7 @@ def check_fal_api_key() -> bool:
def check_image_generation_requirements() -> bool:
"""
Check if all requirements for image generation tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
@@ -414,19 +400,20 @@ def check_image_generation_requirements() -> bool:
# Check API key
if not check_fal_api_key():
return False
# Check if fal_client is available
import fal_client
return True
except ImportError:
return False
def get_debug_session_info() -> Dict[str, Any]:
def get_debug_session_info() -> dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
@@ -439,10 +426,10 @@ if __name__ == "__main__":
"""
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
print("=" * 60)
# Check if API key is available
api_available = check_fal_api_key()
if not api_available:
print("❌ FAL_KEY environment variable not set")
print("Please set your API key: export FAL_KEY='your-key-here'")
@@ -450,27 +437,28 @@ if __name__ == "__main__":
exit(1)
else:
print("✅ FAL.ai API key found")
# Check if fal_client is available
try:
import fal_client
print("✅ fal_client library available")
except ImportError:
print("❌ fal_client library not found")
print("Please install: pip install fal-client")
exit(1)
print("🛠️ Image generation tools ready for use!")
print(f"🤖 Using model: {DEFAULT_MODEL}")
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
# Show debug mode status
if _debug.active:
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
else:
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from image_generation_tool import image_generate_tool")
print(" import asyncio")
@@ -484,23 +472,23 @@ if __name__ == "__main__":
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nSupported image sizes:")
for size in VALID_IMAGE_SIZES:
print(f" - {size}")
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
print("\nAcceleration modes:")
for mode in VALID_ACCELERATION_MODES:
print(f" - {mode}")
print("\nExample prompts:")
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
print(" - 'Modern architecture building with glass facade, sunset lighting'")
print(" - 'Abstract art with vibrant colors and geometric patterns'")
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
print(" - 'Futuristic cityscape with flying cars and neon lights'")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export IMAGE_TOOLS_DEBUG=true")
@@ -521,17 +509,17 @@ IMAGE_GENERATE_SCHEMA = {
"properties": {
"prompt": {
"type": "string",
"description": "The text prompt describing the desired image. Be detailed and descriptive."
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
},
"aspect_ratio": {
"type": "string",
"enum": ["landscape", "square", "portrait"],
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
"default": "landscape"
}
"default": "landscape",
},
},
"required": ["prompt"]
}
"required": ["prompt"],
},
}

View File

@@ -77,7 +77,7 @@ import os
import re
import threading
import time
from typing import Any, Dict, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
_MCP_AVAILABLE = True
try:
from mcp.client.streamable_http import streamablehttp_client
_MCP_HTTP_AVAILABLE = True
except ImportError:
_MCP_HTTP_AVAILABLE = False
@@ -108,6 +110,7 @@ try:
TextContent,
ToolUseContent,
)
_MCP_SAMPLING_TYPES = True
except ImportError:
logger.debug("MCP sampling types not available -- sampling disabled")
@@ -118,27 +121,36 @@ except ImportError:
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
_MAX_RECONNECT_RETRIES = 5
_MAX_BACKOFF_SECONDS = 60
# Environment variables that are safe to pass to stdio subprocesses
_SAFE_ENV_KEYS = frozenset({
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
})
_SAFE_ENV_KEYS = frozenset(
{
"PATH",
"HOME",
"USER",
"LANG",
"LC_ALL",
"TERM",
"SHELL",
"TMPDIR",
}
)
# Regex for credential patterns to strip from error messages
_CREDENTIAL_PATTERN = re.compile(
r"(?:"
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
r"|Bearer\s+\S+" # Bearer token
r"|token=[^\s&,;\"']{1,255}" # token=...
r"|key=[^\s&,;\"']{1,255}" # key=...
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
r"|password=[^\s&,;\"']{1,255}" # password=...
r"|secret=[^\s&,;\"']{1,255}" # secret=...
r")",
re.IGNORECASE,
)
@@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile(
# Security helpers
# ---------------------------------------------------------------------------
def _build_safe_env(user_env: Optional[dict]) -> dict:
def _build_safe_env(user_env: dict | None) -> dict:
"""Build a filtered environment dict for stdio subprocesses.
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
@@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str:
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
def _safe_numeric(value, default, coerce=int, minimum=1):
"""Coerce a config value to a numeric type, returning *default* on failure.
@@ -216,18 +230,22 @@ class SamplingHandler:
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
self.max_tool_rounds = _safe_numeric(
config.get("max_tool_rounds", 5), 5, int, minimum=0,
config.get("max_tool_rounds", 5),
5,
int,
minimum=0,
)
self.model_override = config.get("model")
self.allowed_models = config.get("allowed_models", [])
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
self.audit_level = _log_levels.get(
str(config.get("log_level", "info")).lower(), logging.INFO,
str(config.get("log_level", "info")).lower(),
logging.INFO,
)
# Per-instance state
self._rate_timestamps: List[float] = []
self._rate_timestamps: list[float] = []
self._tool_loop_count = 0
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
@@ -245,7 +263,7 @@ class SamplingHandler:
# -- Model resolution ----------------------------------------------------
def _resolve_model(self, preferences) -> Optional[str]:
def _resolve_model(self, preferences) -> str | None:
"""Config override > server hint > None (use default)."""
if self.model_override:
return self.model_override
@@ -265,7 +283,7 @@ class SamplingHandler:
items = block.content if isinstance(block.content, list) else [block.content]
return "\n".join(item.text for item in items if hasattr(item, "text"))
def _convert_messages(self, params) -> List[dict]:
def _convert_messages(self, params) -> list[dict]:
"""Convert MCP SamplingMessages to OpenAI format.
Uses ``msg.content_as_list`` (SDK helper) so single-block and
@@ -273,37 +291,47 @@ class SamplingHandler:
with ``isinstance`` on real SDK types when available, falling back
to duck-typing via ``hasattr`` for compatibility.
"""
messages: List[dict] = []
messages: list[dict] = []
for msg in params.messages:
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
msg.content if isinstance(msg.content, list) else [msg.content]
blocks = (
msg.content_as_list
if hasattr(msg, "content_as_list")
else (msg.content if isinstance(msg.content, list) else [msg.content])
)
# Separate blocks by kind
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
tool_uses = [
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
]
content_blocks = [
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
]
# Emit tool result messages (role: tool)
for tr in tool_results:
messages.append({
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
})
messages.append(
{
"role": "tool",
"tool_call_id": tr.toolUseId,
"content": self._extract_tool_result_text(tr),
}
)
# Emit assistant tool_calls message
if tool_uses:
tc_list = []
for tu in tool_uses:
tc_list.append({
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
})
tc_list.append(
{
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
"type": "function",
"function": {
"name": tu.name,
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
},
}
)
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
# Include any accompanying text
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
@@ -320,10 +348,12 @@ class SamplingHandler:
if hasattr(block, "text"):
parts.append({"type": "text", "text": block.text})
elif hasattr(block, "data") and hasattr(block, "mimeType"):
parts.append({
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
})
parts.append(
{
"type": "image_url",
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
}
)
else:
logger.warning(
"Unsupported sampling content block type: %s (skipped)",
@@ -352,16 +382,13 @@ class SamplingHandler:
# Tool loop governance
if self.max_tool_rounds == 0:
self._tool_loop_count = 0
return self._error(
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
)
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
self._tool_loop_count += 1
if self._tool_loop_count > self.max_tool_rounds:
self._tool_loop_count = 0
return self._error(
f"Tool loop limit exceeded for server '{self.server_name}' "
f"(max {self.max_tool_rounds} rounds)"
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
)
content_blocks = []
@@ -372,25 +399,28 @@ class SamplingHandler:
parsed = json.loads(args)
except (json.JSONDecodeError, ValueError):
logger.warning(
"MCP server '%s': malformed tool_calls arguments "
"from LLM (wrapping as raw): %.100s",
self.server_name, args,
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
self.server_name,
args,
)
parsed = {"_raw": args}
else:
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
content_blocks.append(ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
))
content_blocks.append(
ToolUseContent(
type="tool_use",
id=tc.id,
name=tc.function.name,
input=parsed,
)
)
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
len(content_blocks),
)
@@ -410,7 +440,8 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling response: model=%s, tokens=%s",
self.server_name, response.model,
self.server_name,
response.model,
getattr(getattr(response, "usage", None), "total_tokens", "?"),
)
@@ -445,12 +476,12 @@ class SamplingHandler:
if not self._check_rate_limit():
logger.warning(
"MCP server '%s' sampling rate limit exceeded (%d/min)",
self.server_name, self.max_rpm,
self.server_name,
self.max_rpm,
)
self.metrics["errors"] += 1
return self._error(
f"Sampling rate limit exceeded for server '{self.server_name}' "
f"({self.max_rpm} requests/minute)"
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
)
# Resolve model
@@ -458,6 +489,7 @@ class SamplingHandler:
# Get auxiliary LLM client
from agent.auxiliary_client import get_text_auxiliary_client
client, default_model = get_text_auxiliary_client()
if client is None:
self.metrics["errors"] += 1
@@ -469,7 +501,8 @@ class SamplingHandler:
if self.allowed_models and resolved_model not in self.allowed_models:
logger.warning(
"MCP server '%s' requested model '%s' not in allowed_models",
self.server_name, resolved_model,
self.server_name,
resolved_model,
)
self.metrics["errors"] += 1
return self._error(
@@ -515,7 +548,10 @@ class SamplingHandler:
logger.log(
self.audit_level,
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
self.server_name, resolved_model, max_tokens, len(messages),
self.server_name,
resolved_model,
max_tokens,
len(messages),
)
# Offload sync LLM call to thread (non-blocking)
@@ -524,19 +560,15 @@ class SamplingHandler:
try:
response = await asyncio.wait_for(
asyncio.to_thread(_sync_call), timeout=self.timeout,
asyncio.to_thread(_sync_call),
timeout=self.timeout,
)
except asyncio.TimeoutError:
except TimeoutError:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call timed out after {self.timeout}s "
f"for server '{self.server_name}'"
)
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
except Exception as exc:
self.metrics["errors"] += 1
return self._error(
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
)
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
# Track metrics
choice = response.choices[0]
@@ -546,11 +578,7 @@ class SamplingHandler:
self.metrics["tokens_used"] += total_tokens
# Dispatch based on response type
if (
choice.finish_reason == "tool_calls"
and hasattr(choice.message, "tool_calls")
and choice.message.tool_calls
):
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
return self._build_tool_use_result(choice, response)
return self._build_text_result(choice, response)
@@ -560,6 +588,7 @@ class SamplingHandler:
# Server task -- each MCP server lives in one long-lived asyncio Task
# ---------------------------------------------------------------------------
class MCPServerTask:
"""Manages a single MCP server connection in a dedicated asyncio Task.
@@ -571,22 +600,29 @@ class MCPServerTask:
"""
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
"name",
"session",
"tool_timeout",
"_task",
"_ready",
"_shutdown_event",
"_tools",
"_error",
"_config",
"_sampling",
)
def __init__(self, name: str):
self.name = name
self.session: Optional[Any] = None
self.session: Any | None = None
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
self._task: Optional[asyncio.Task] = None
self._task: asyncio.Task | None = None
self._ready = asyncio.Event()
self._shutdown_event = asyncio.Event()
self._tools: list = []
self._error: Optional[Exception] = None
self._error: Exception | None = None
self._config: dict = {}
self._sampling: Optional[SamplingHandler] = None
self._sampling: SamplingHandler | None = None
def _is_http(self) -> bool:
"""Check if this server uses HTTP transport."""
@@ -599,9 +635,7 @@ class MCPServerTask:
user_env = config.get("env")
if not command:
raise ValueError(
f"MCP server '{self.name}' has no 'command' in config"
)
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
safe_env = _build_safe_env(user_env)
server_params = StdioServerParameters(
@@ -650,11 +684,7 @@ class MCPServerTask:
if self.session is None:
return
tools_result = await self.session.list_tools()
self._tools = (
tools_result.tools
if hasattr(tools_result, "tools")
else []
)
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
async def run(self, config: dict):
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
@@ -704,24 +734,28 @@ class MCPServerTask:
if self._shutdown_event.is_set():
logger.debug(
"MCP server '%s' disconnected during shutdown: %s",
self.name, exc,
self.name,
exc,
)
return
retries += 1
if retries > _MAX_RECONNECT_RETRIES:
logger.warning(
"MCP server '%s' failed after %d reconnection attempts, "
"giving up: %s",
self.name, _MAX_RECONNECT_RETRIES, exc,
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
self.name,
_MAX_RECONNECT_RETRIES,
exc,
)
return
logger.warning(
"MCP server '%s' connection lost (attempt %d/%d), "
"reconnecting in %.0fs: %s",
self.name, retries, _MAX_RECONNECT_RETRIES,
backoff, exc,
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
self.name,
retries,
_MAX_RECONNECT_RETRIES,
backoff,
exc,
)
await asyncio.sleep(backoff)
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
@@ -745,7 +779,7 @@ class MCPServerTask:
if self._task and not self._task.done():
try:
await asyncio.wait_for(self._task, timeout=10)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"MCP server '%s' shutdown timed out, cancelling task",
self.name,
@@ -762,11 +796,11 @@ class MCPServerTask:
# Module-level state
# ---------------------------------------------------------------------------
_servers: Dict[str, MCPServerTask] = {}
_servers: dict[str, MCPServerTask] = {}
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
_mcp_loop: asyncio.AbstractEventLoop | None = None
_mcp_thread: threading.Thread | None = None
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
_lock = threading.Lock()
@@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
# Config loading
# ---------------------------------------------------------------------------
def _load_mcp_config() -> Dict[str, dict]:
def _load_mcp_config() -> dict[str, dict]:
"""Read ``mcp_servers`` from the Hermes config file.
Returns a dict of ``{server_name: server_config}`` or empty dict.
@@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]:
"""
try:
from hermes_cli.config import load_config
config = load_config()
servers = config.get("mcp_servers")
if not servers or not isinstance(servers, dict):
@@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]:
# Server connection helper
# ---------------------------------------------------------------------------
async def _connect_server(name: str, config: dict) -> MCPServerTask:
"""Create an MCPServerTask, start it, and return when ready.
@@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
# Handler / check-fn factories
# ---------------------------------------------------------------------------
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
"""Return a sync handler that calls an MCP tool via the background loop.
@@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.call_tool(tool_name, arguments=args)
# MCP CallToolResult has .content (list of content blocks) and .isError
if result.isError:
error_text = ""
for block in (result.content or []):
for block in result.content or []:
if hasattr(block, "text"):
error_text += block.text
return json.dumps({
"error": _sanitize_error(
error_text or "MCP tool returned an error"
)
})
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
# Collect text from content blocks
parts: List[str] = []
for block in (result.content or []):
parts: list[str] = []
for block in result.content or []:
if hasattr(block, "text"):
parts.append(block.text)
return json.dumps({"result": "\n".join(parts) if parts else ""})
@@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
except Exception as exc:
logger.error(
"MCP tool %s/%s call failed: %s",
server_name, tool_name, exc,
server_name,
tool_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_resources()
resources = []
for r in (result.resources if hasattr(result, "resources") else []):
for r in result.resources if hasattr(result, "resources") else []:
entry = {}
if hasattr(r, "uri"):
entry["uri"] = str(r.uri)
@@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_resources failed: %s", server_name, exc,
"MCP %s/list_resources failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
uri = args.get("uri")
if not uri:
@@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
async def _call():
result = await server.session.read_resource(uri)
# read_resource returns ReadResourceResult with .contents list
parts: List[str] = []
parts: list[str] = []
contents = result.contents if hasattr(result, "contents") else []
for block in contents:
if hasattr(block, "text"):
@@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/read_resource failed: %s", server_name, exc,
"MCP %s/read_resource failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
async def _call():
result = await server.session.list_prompts()
prompts = []
for p in (result.prompts if hasattr(result, "prompts") else []):
for p in result.prompts if hasattr(result, "prompts") else []:
entry = {}
if hasattr(p, "name"):
entry["name"] = p.name
@@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/list_prompts failed: %s", server_name, exc,
"MCP %s/list_prompts failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
with _lock:
server = _servers.get(server_name)
if not server or not server.session:
return json.dumps({
"error": f"MCP server '{server_name}' is not connected"
})
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
name = args.get("name")
if not name:
@@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
result = await server.session.get_prompt(name, arguments=arguments)
# GetPromptResult has .messages list
messages = []
for msg in (result.messages if hasattr(result, "messages") else []):
for msg in result.messages if hasattr(result, "messages") else []:
entry = {}
if hasattr(msg, "role"):
entry["role"] = msg.role
@@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except Exception as exc:
logger.error(
"MCP %s/get_prompt failed: %s", server_name, exc,
"MCP %s/get_prompt failed: %s",
server_name,
exc,
)
return json.dumps({
"error": _sanitize_error(
f"MCP call failed: {type(exc).__name__}: {exc}"
)
})
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
return _handler
@@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str):
# Discovery & registration
# ---------------------------------------------------------------------------
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
"""Convert an MCP tool listing to the Hermes registry schema format.
@@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
return {
"name": prefixed_name,
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
"parameters": mcp_tool.inputSchema
if mcp_tool.inputSchema
else {
"type": "object",
"properties": {},
},
}
def _build_utility_schemas(server_name: str) -> List[dict]:
def _build_utility_schemas(server_name: str) -> list[dict]:
"""Build schemas for the MCP utility tools (resources & prompts).
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
@@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
]
def _existing_tool_names() -> List[str]:
def _existing_tool_names() -> list[str]:
"""Return tool names for all currently connected servers."""
names: List[str] = []
names: list[str] = []
for sname, server in _servers.items():
for mcp_tool in server._tools:
schema = _convert_mcp_schema(sname, mcp_tool)
@@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]:
return names
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
"""Connect to a single MCP server, discover tools, and register them.
Also registers utility tools for MCP Resources and Prompts support
@@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
with _lock:
_servers[name] = server
registered_names: List[str] = []
registered_names: list[str] = []
toolset_name = f"mcp-{name}"
for mcp_tool in server._tools:
@@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
transport_type = "HTTP" if "url" in config else "stdio"
logger.info(
"MCP server '%s' (%s): registered %d tool(s): %s",
name, transport_type, len(registered_names),
name,
transport_type,
len(registered_names),
", ".join(registered_names),
)
return registered_names
@@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
# Public API
# ---------------------------------------------------------------------------
def discover_mcp_tools() -> List[str]:
def discover_mcp_tools() -> list[str]:
"""Entry point: load config, connect to MCP servers, register tools.
Called from ``model_tools._discover_tools()``. Safe to call even when
@@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]:
# Start the background event loop for MCP connections
_ensure_mcp_loop()
all_tools: List[str] = []
all_tools: list[str] = []
failed_count = 0
async def _discover_one(name: str, cfg: dict) -> List[str]:
async def _discover_one(name: str, cfg: dict) -> list[str]:
"""Connect to a single server and return its registered tool names."""
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
try:
registered = await _discover_and_register_server(name, cfg)
transport_type = "HTTP" if "url" in cfg else "stdio"
@@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]:
except Exception as exc:
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, exc,
name,
exc,
)
return []
@@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]:
if all_tools:
# Dynamically inject into all hermes-* platform toolsets
from toolsets import TOOLSETS
for ts_name, ts in TOOLSETS.items():
if ts_name.startswith("hermes-"):
for tool_name in all_tools:
@@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]:
return _existing_tool_names()
def get_mcp_status() -> List[dict]:
def get_mcp_status() -> list[dict]:
"""Return status of all configured MCP servers for banner display.
Returns a list of dicts with keys: name, transport, tools, connected.
Includes both successfully connected servers and configured-but-failed ones.
"""
result: List[dict] = []
result: list[dict] = []
# Get configured servers from config
configured = _load_mcp_config()
@@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]:
entry["sampling"] = dict(server._sampling.metrics)
result.append(entry)
else:
result.append({
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
})
result.append(
{
"name": name,
"transport": transport,
"tools": 0,
"connected": False,
}
)
return result
@@ -1440,7 +1464,9 @@ def shutdown_mcp_servers():
for server, result in zip(servers_snapshot, results):
if isinstance(result, Exception):
logger.debug(
"Error closing MCP server '%s': %s", server.name, result,
"Error closing MCP server '%s': %s",
server.name,
result,
)
with _lock:
_servers.clear()

View File

@@ -29,7 +29,7 @@ import os
import re
import tempfile
from pathlib import Path
from typing import Dict, Any, List, Optional
from typing import Any
logger = logging.getLogger(__name__)
@@ -46,30 +46,38 @@ ENTRY_DELIMITER = "\n§\n"
_MEMORY_THREAT_PATTERNS = [
# Prompt injection
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
(r'you\s+are\s+now\s+', "role_hijack"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
(r"you\s+are\s+now\s+", "role_hijack"),
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
(r"system\s+prompt\s+override", "sys_prompt_override"),
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
# Exfiltration via curl/wget with secrets
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
# Persistence via shell rc
(r'authorized_keys', "ssh_backdoor"),
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
(r"authorized_keys", "ssh_backdoor"),
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
(r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"),
]
# Subset of invisible chars for injection detection
_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
"\u200b",
"\u200c",
"\u200d",
"\u2060",
"\ufeff",
"\u202a",
"\u202b",
"\u202c",
"\u202d",
"\u202e",
}
def _scan_memory_content(content: str) -> Optional[str]:
def _scan_memory_content(content: str) -> str | None:
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
# Check invisible unicode
for char in _INVISIBLE_CHARS:
@@ -96,12 +104,12 @@ class MemoryStore:
"""
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
self.memory_entries: List[str] = []
self.user_entries: List[str] = []
self.memory_entries: list[str] = []
self.user_entries: list[str] = []
self.memory_char_limit = memory_char_limit
self.user_char_limit = user_char_limit
# Frozen snapshot for system prompt -- set once at load_from_disk()
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
def load_from_disk(self):
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
@@ -129,12 +137,12 @@ class MemoryStore:
elif target == "user":
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
def _entries_for(self, target: str) -> List[str]:
def _entries_for(self, target: str) -> list[str]:
if target == "user":
return self.user_entries
return self.memory_entries
def _set_entries(self, target: str, entries: List[str]):
def _set_entries(self, target: str, entries: list[str]):
if target == "user":
self.user_entries = entries
else:
@@ -151,7 +159,7 @@ class MemoryStore:
return self.user_char_limit
return self.memory_char_limit
def add(self, target: str, content: str) -> Dict[str, Any]:
def add(self, target: str, content: str) -> dict[str, Any]:
"""Append a new entry. Returns error if it would exceed the char limit."""
content = content.strip()
if not content:
@@ -192,7 +200,7 @@ class MemoryStore:
return self._success_response(target, "Entry added.")
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
"""Find entry containing old_text substring, replace it with new_content."""
old_text = old_text.strip()
new_content = new_content.strip()
@@ -247,7 +255,7 @@ class MemoryStore:
return self._success_response(target, "Entry replaced.")
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
def remove(self, target: str, old_text: str) -> dict[str, Any]:
"""Remove the entry containing old_text substring."""
old_text = old_text.strip()
if not old_text:
@@ -278,7 +286,7 @@ class MemoryStore:
return self._success_response(target, "Entry removed.")
def format_for_system_prompt(self, target: str) -> Optional[str]:
def format_for_system_prompt(self, target: str) -> str | None:
"""
Return the frozen snapshot for system prompt injection.
@@ -293,7 +301,7 @@ class MemoryStore:
# -- Internal helpers --
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
def _success_response(self, target: str, message: str = None) -> dict[str, Any]:
entries = self._entries_for(target)
current = self._char_count(target)
limit = self._char_limit(target)
@@ -310,7 +318,7 @@ class MemoryStore:
resp["message"] = message
return resp
def _render_block(self, target: str, entries: List[str]) -> str:
def _render_block(self, target: str, entries: list[str]) -> str:
"""Render a system prompt block with header and usage indicator."""
if not entries:
return ""
@@ -329,7 +337,7 @@ class MemoryStore:
return f"{separator}\n{header}\n{separator}\n{content}"
@staticmethod
def _read_file(path: Path) -> List[str]:
def _read_file(path: Path) -> list[str]:
"""Read a memory file and split into entries.
No file locking needed: _write_file uses atomic rename, so readers
@@ -339,7 +347,7 @@ class MemoryStore:
return []
try:
raw = path.read_text(encoding="utf-8")
except (OSError, IOError):
except OSError:
return []
if not raw.strip():
@@ -351,7 +359,7 @@ class MemoryStore:
return [e for e in entries if e]
@staticmethod
def _write_file(path: Path, entries: List[str]):
def _write_file(path: Path, entries: list[str]):
"""Write entries to a memory file using atomic temp-file + rename.
Previous implementation used open("w") + flock, but "w" truncates the
@@ -362,9 +370,7 @@ class MemoryStore:
content = ENTRY_DELIMITER.join(entries) if entries else ""
try:
# Write to temp file in same directory (same filesystem for atomic rename)
fd, tmp_path = tempfile.mkstemp(
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
)
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(content)
@@ -378,7 +384,7 @@ class MemoryStore:
except OSError:
pass
raise
except (OSError, IOError) as e:
except OSError as e:
raise RuntimeError(f"Failed to write memory file {path}: {e}")
@@ -387,7 +393,7 @@ def memory_tool(
target: str = "memory",
content: str = None,
old_text: str = None,
store: Optional[MemoryStore] = None,
store: MemoryStore | None = None,
) -> str:
"""
Single entry point for the memory tool. Dispatches to MemoryStore methods.
@@ -395,10 +401,15 @@ def memory_tool(
Returns JSON string with results.
"""
if store is None:
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "Memory is not available. It may be disabled in config or this environment."},
ensure_ascii=False,
)
if target not in ("memory", "user"):
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False
)
if action == "add":
if not content:
@@ -407,18 +418,26 @@ def memory_tool(
elif action == "replace":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False
)
if not content:
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False
)
result = store.replace(target, old_text, content)
elif action == "remove":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False
)
result = store.remove(target, old_text)
else:
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False
)
return json.dumps(result, ensure_ascii=False)
@@ -457,23 +476,16 @@ MEMORY_SCHEMA = {
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["add", "replace", "remove"],
"description": "The action to perform."
},
"action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."},
"target": {
"type": "string",
"enum": ["memory", "user"],
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
},
"content": {
"type": "string",
"description": "The entry content. Required for 'add' and 'replace'."
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile.",
},
"content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."},
"old_text": {
"type": "string",
"description": "Short unique substring identifying the entry to replace or remove."
"description": "Short unique substring identifying the entry to replace or remove.",
},
},
"required": ["action", "target"],
@@ -493,10 +505,7 @@ registry.register(
target=args.get("target", "memory"),
content=args.get("content"),
old_text=args.get("old_text"),
store=kw.get("store")),
store=kw.get("store"),
),
check_fn=check_memory_requirements,
)

View File

@@ -38,21 +38,27 @@ Configuration:
Usage:
from mixture_of_agents_tool import mixture_of_agents_tool
import asyncio
# Process a complex query
result = await mixture_of_agents_tool(
user_prompt="Solve this complex mathematical proof..."
)
"""
import asyncio
import datetime
import json
import logging
import os
import asyncio
import datetime
from typing import Dict, Any, List, Optional
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
from typing import Any
from tools.debug_helpers import DebugSession
from tools.openrouter_client import (
check_api_key as check_openrouter_api_key,
)
from tools.openrouter_client import (
get_async_client as _get_openrouter_client,
)
logger = logging.getLogger(__name__)
@@ -60,9 +66,9 @@ logger = logging.getLogger(__name__)
# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
REFERENCE_MODELS = [
"anthropic/claude-opus-4.5",
"google/gemini-3-pro-preview",
"google/gemini-3-pro-preview",
"openai/gpt-5.2-pro",
"deepseek/deepseek-v3.2"
"deepseek/deepseek-v3.2",
]
# Aggregator model - synthesizes reference responses into final output
@@ -83,18 +89,18 @@ Responses from models:"""
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str:
"""
Construct the final system prompt for the aggregator including all model responses.
Args:
system_prompt (str): Base system prompt for aggregation
responses (List[str]): List of responses from reference models
Returns:
str: Complete system prompt with enumerated responses
"""
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)])
return f"{system_prompt}\n\n{response_text}"
@@ -103,48 +109,43 @@ async def _run_reference_model_safe(
user_prompt: str,
temperature: float = REFERENCE_TEMPERATURE,
max_tokens: int = 32000,
max_retries: int = 6
max_retries: int = 6,
) -> tuple[str, str, bool]:
"""
Run a single reference model with retry logic and graceful failure handling.
Args:
model (str): Model identifier to use
user_prompt (str): The user's query
temperature (float): Sampling temperature for response generation
max_tokens (int): Maximum tokens in response
max_retries (int): Maximum number of retry attempts
Returns:
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
"""
for attempt in range(max_retries):
try:
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
# Build parameters for the API call
api_params = {
"model": model,
"messages": [{"role": "user", "content": user_prompt}],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not model.lower().startswith('gpt-'):
if not model.lower().startswith("gpt-"):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("%s responded (%s characters)", model, len(content))
return model, content, True
except Exception as e:
error_str = str(e)
# Log more detailed error information for debugging
@@ -154,7 +155,7 @@ async def _run_reference_model_safe(
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
else:
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
if attempt < max_retries - 1:
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
sleep_time = min(2 ** (attempt + 1), 60)
@@ -167,60 +168,47 @@ async def _run_reference_model_safe(
async def _run_aggregator_model(
system_prompt: str,
user_prompt: str,
temperature: float = AGGREGATOR_TEMPERATURE,
max_tokens: int = None
system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None
) -> str:
"""
Run the aggregator model to synthesize the final response.
Args:
system_prompt (str): System prompt with all reference responses
user_prompt (str): Original user query
temperature (float): Focused temperature for consistent aggregation
max_tokens (int): Maximum tokens in final response
Returns:
str: Synthesized final response
"""
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
# Build parameters for the API call
api_params = {
"model": AGGREGATOR_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
if not AGGREGATOR_MODEL.lower().startswith("gpt-"):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("Aggregation complete (%s characters)", len(content))
return content
async def mixture_of_agents_tool(
user_prompt: str,
reference_models: Optional[List[str]] = None,
aggregator_model: Optional[str] = None
user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None
) -> str:
"""
Process a complex query using the Mixture-of-Agents methodology.
This tool leverages multiple frontier language models to collaboratively solve
extremely difficult problems requiring intense reasoning. It's particularly
effective for:
@@ -229,16 +217,16 @@ async def mixture_of_agents_tool(
- Multi-step analytical reasoning tasks
- Problems requiring diverse domain expertise
- Tasks where single models show limitations
The MoA approach uses a fixed 2-layer architecture:
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
Args:
user_prompt (str): The complex query or problem to solve
reference_models (Optional[List[str]]): Custom reference models to use
aggregator_model (Optional[str]): Custom aggregator model to use
Returns:
str: JSON string containing the MoA results with the following structure:
{
@@ -250,12 +238,12 @@ async def mixture_of_agents_tool(
},
"processing_time": float
}
Raises:
Exception: If MoA processing fails or API key is not set
"""
start_time = datetime.datetime.now()
debug_call_data = {
"parameters": {
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
@@ -263,7 +251,7 @@ async def mixture_of_agents_tool(
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
"reference_temperature": REFERENCE_TEMPERATURE,
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
},
"error": None,
"success": False,
@@ -272,161 +260,152 @@ async def mixture_of_agents_tool(
"failed_models": [],
"final_response_length": 0,
"processing_time_seconds": 0,
"models_used": {}
"models_used": {},
}
try:
logger.info("Starting Mixture-of-Agents processing...")
logger.info("Query: %s", user_prompt[:100])
# Validate API key availability
if not os.getenv("OPENROUTER_API_KEY"):
raise ValueError("OPENROUTER_API_KEY environment variable not set")
# Use provided models or defaults
ref_models = reference_models or REFERENCE_MODELS
agg_model = aggregator_model or AGGREGATOR_MODEL
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
# Layer 1: Generate diverse responses from reference models (with failure handling)
logger.info("Layer 1: Generating reference responses...")
model_results = await asyncio.gather(*[
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
for model in ref_models
])
model_results = await asyncio.gather(
*[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models]
)
# Separate successful and failed responses
successful_responses = []
failed_models = []
for model_name, content, success in model_results:
if success:
successful_responses.append(content)
else:
failed_models.append(model_name)
successful_count = len(successful_responses)
failed_count = len(failed_models)
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
if failed_models:
logger.warning("Failed models: %s", ', '.join(failed_models))
logger.warning("Failed models: %s", ", ".join(failed_models))
# Check if we have enough successful responses to proceed
if successful_count < MIN_SUCCESSFUL_REFERENCES:
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
raise ValueError(
f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses."
)
debug_call_data["reference_responses_count"] = successful_count
debug_call_data["failed_models_count"] = failed_count
debug_call_data["failed_models"] = failed_models
# Layer 2: Aggregate responses using the aggregator model
logger.info("Layer 2: Synthesizing final response...")
aggregator_system_prompt = _construct_aggregator_prompt(
AGGREGATOR_SYSTEM_PROMPT,
successful_responses
)
final_response = await _run_aggregator_model(
aggregator_system_prompt,
user_prompt,
AGGREGATOR_TEMPERATURE
)
aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses)
final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE)
# Calculate processing time
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
logger.info("MoA processing completed in %.2f seconds", processing_time)
# Prepare successful response (only final aggregated result, minimal fields)
result = {
"success": True,
"response": final_response,
"models_used": {
"reference_models": ref_models,
"aggregator_model": agg_model
}
"models_used": {"reference_models": ref_models, "aggregator_model": agg_model},
}
debug_call_data["success"] = True
debug_call_data["final_response_length"] = len(final_response)
debug_call_data["processing_time_seconds"] = processing_time
debug_call_data["models_used"] = result["models_used"]
# Log debug information
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"Error in MoA processing: {str(e)}"
logger.error("%s", error_msg)
# Calculate processing time even for errors
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
# Prepare error response (minimal fields)
result = {
"success": False,
"response": "MoA processing failed. Please try again or use a single model for this query.",
"models_used": {
"reference_models": reference_models or REFERENCE_MODELS,
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
},
"error": error_msg
"error": error_msg,
}
debug_call_data["error"] = error_msg
debug_call_data["processing_time_seconds"] = processing_time
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
def check_moa_requirements() -> bool:
"""
Check if all requirements for MoA tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
return check_openrouter_api_key()
def get_debug_session_info() -> Dict[str, Any]:
def get_debug_session_info() -> dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
return _debug.get_session_info()
def get_available_models() -> Dict[str, List[str]]:
def get_available_models() -> dict[str, list[str]]:
"""
Get information about available models for MoA processing.
Returns:
Dict[str, List[str]]: Dictionary with reference and aggregator models
"""
return {
"reference_models": REFERENCE_MODELS,
"aggregator_models": [AGGREGATOR_MODEL],
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL],
}
def get_moa_configuration() -> Dict[str, Any]:
def get_moa_configuration() -> dict[str, Any]:
"""
Get the current MoA configuration settings.
Returns:
Dict[str, Any]: Dictionary containing all configuration parameters
"""
@@ -437,7 +416,7 @@ def get_moa_configuration() -> Dict[str, Any]:
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
"total_reference_models": len(REFERENCE_MODELS),
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail",
}
@@ -447,10 +426,10 @@ if __name__ == "__main__":
"""
print("🤖 Mixture-of-Agents Tool Module")
print("=" * 50)
# Check if API key is available
api_available = check_openrouter_api_key()
if not api_available:
print("❌ OPENROUTER_API_KEY environment variable not set")
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
@@ -458,26 +437,26 @@ if __name__ == "__main__":
exit(1)
else:
print("✅ OpenRouter API key found")
print("🛠️ MoA tools ready for use!")
# Show current configuration
config = get_moa_configuration()
print(f"\n⚙️ Current Configuration:")
print("\n⚙️ Current Configuration:")
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
# Show debug mode status
if _debug.active:
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
else:
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
print(" import asyncio")
@@ -488,24 +467,26 @@ if __name__ == "__main__":
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nBest use cases:")
print(" - Complex mathematical proofs and calculations")
print(" - Advanced coding problems and algorithm design")
print(" - Multi-step analytical reasoning tasks")
print(" - Problems requiring diverse domain expertise")
print(" - Tasks where single models show limitations")
print("\nPerformance characteristics:")
print(" - Higher latency due to multiple model calls")
print(" - Significantly improved quality for complex tasks")
print(" - Parallel processing for efficiency")
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
print(
f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation"
)
print(" - Token-efficient: only returns final aggregated response")
print(" - Resilient: continues with partial model failures")
print(f" - Configurable: easy to modify models and settings at top of file")
print(" - Configurable: easy to modify models and settings at top of file")
print(" - State-of-the-art results on challenging benchmarks")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export MOA_TOOLS_DEBUG=true")
@@ -526,11 +507,11 @@ MOA_SCHEMA = {
"properties": {
"user_prompt": {
"type": "string",
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.",
}
},
"required": ["user_prompt"]
}
"required": ["user_prompt"],
},
}
registry.register(

View File

@@ -1,7 +1,7 @@
"""Shared OpenRouter API client for Hermes tools.
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
can share, eliminating the duplicated _get_openrouter_client() /
can share, eliminating the duplicated _get_openrouter_client() /
_get_summarizer_client() pattern previously copy-pasted across web_tools,
vision_tools, mixture_of_agents_tool, and session_search_tool.
"""
@@ -9,6 +9,7 @@ vision_tools, mixture_of_agents_tool, and session_search_tool.
import os
from openai import AsyncOpenAI
from hermes_constants import OPENROUTER_BASE_URL
_client: AsyncOpenAI | None = None

View File

@@ -20,7 +20,7 @@ V4A Format:
Usage:
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
operations, error = parse_v4a_patch(patch_content)
if error:
print(f"Parse error: {error}")
@@ -30,8 +30,8 @@ Usage:
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Any
from enum import Enum
from typing import Any
class OperationType(Enum):
@@ -44,6 +44,7 @@ class OperationType(Enum):
@dataclass
class HunkLine:
"""A single line in a patch hunk."""
prefix: str # ' ', '-', or '+'
content: str
@@ -51,182 +52,174 @@ class HunkLine:
@dataclass
class Hunk:
"""A group of changes within a file."""
context_hint: Optional[str] = None
lines: List[HunkLine] = field(default_factory=list)
context_hint: str | None = None
lines: list[HunkLine] = field(default_factory=list)
@dataclass
class PatchOperation:
"""A single operation in a V4A patch."""
operation: OperationType
file_path: str
new_path: Optional[str] = None # For move operations
hunks: List[Hunk] = field(default_factory=list)
content: Optional[str] = None # For add file operations
new_path: str | None = None # For move operations
hunks: list[Hunk] = field(default_factory=list)
content: str | None = None # For add file operations
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]:
"""
Parse a V4A format patch.
Args:
patch_content: The patch text in V4A format
Returns:
Tuple of (operations, error_message)
- If successful: (list_of_operations, None)
- If failed: ([], error_description)
"""
lines = patch_content.split('\n')
operations: List[PatchOperation] = []
lines = patch_content.split("\n")
operations: list[PatchOperation] = []
# Find patch boundaries
start_idx = None
end_idx = None
for i, line in enumerate(lines):
if '*** Begin Patch' in line or '***Begin Patch' in line:
if "*** Begin Patch" in line or "***Begin Patch" in line:
start_idx = i
elif '*** End Patch' in line or '***End Patch' in line:
elif "*** End Patch" in line or "***End Patch" in line:
end_idx = i
break
if start_idx is None:
# Try to parse without explicit begin marker
start_idx = -1
if end_idx is None:
end_idx = len(lines)
# Parse operations between boundaries
i = start_idx + 1
current_op: Optional[PatchOperation] = None
current_hunk: Optional[Hunk] = None
current_op: PatchOperation | None = None
current_hunk: Hunk | None = None
while i < end_idx:
line = lines[i]
# Check for file operation markers
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line)
add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line)
delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line)
move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line)
if update_match:
# Save previous operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.UPDATE,
file_path=update_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip())
current_hunk = None
elif add_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.ADD,
file_path=add_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip())
current_hunk = Hunk()
elif delete_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.DELETE,
file_path=delete_match.group(1).strip()
)
current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip())
operations.append(current_op)
current_op = None
current_hunk = None
elif move_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.MOVE,
file_path=move_match.group(1).strip(),
new_path=move_match.group(2).strip()
new_path=move_match.group(2).strip(),
)
operations.append(current_op)
current_op = None
current_hunk = None
elif line.startswith('@@'):
elif line.startswith("@@"):
# Context hint / hunk marker
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
# Extract context hint
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
hint_match = re.match(r"@@\s*(.+?)\s*@@", line)
hint = hint_match.group(1) if hint_match else None
current_hunk = Hunk(context_hint=hint)
elif current_op and line:
# Parse hunk line
if current_hunk is None:
current_hunk = Hunk()
if line.startswith('+'):
current_hunk.lines.append(HunkLine('+', line[1:]))
elif line.startswith('-'):
current_hunk.lines.append(HunkLine('-', line[1:]))
elif line.startswith(' '):
current_hunk.lines.append(HunkLine(' ', line[1:]))
elif line.startswith('\\'):
if line.startswith("+"):
current_hunk.lines.append(HunkLine("+", line[1:]))
elif line.startswith("-"):
current_hunk.lines.append(HunkLine("-", line[1:]))
elif line.startswith(" "):
current_hunk.lines.append(HunkLine(" ", line[1:]))
elif line.startswith("\\"):
# "\ No newline at end of file" marker - skip
pass
else:
# Treat as context line (implicit space prefix)
current_hunk.lines.append(HunkLine(' ', line))
current_hunk.lines.append(HunkLine(" ", line))
i += 1
# Don't forget the last operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
return operations, None
def apply_v4a_operations(operations: List[PatchOperation],
file_ops: Any) -> 'PatchResult':
def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult":
"""
Apply V4A patch operations using a file operations interface.
Args:
operations: List of PatchOperation from parse_v4a_patch
file_ops: Object with read_file, write_file methods
Returns:
PatchResult with results of all operations
"""
# Import here to avoid circular imports
from tools.file_operations import PatchResult
files_modified = []
files_created = []
files_deleted = []
all_diffs = []
errors = []
for op in operations:
try:
if op.operation == OperationType.ADD:
@@ -236,7 +229,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to add {op.file_path}: {result[1]}")
elif op.operation == OperationType.DELETE:
result = _apply_delete(op, file_ops)
if result[0]:
@@ -244,7 +237,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
elif op.operation == OperationType.MOVE:
result = _apply_move(op, file_ops)
if result[0]:
@@ -252,7 +245,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to move {op.file_path}: {result[1]}")
elif op.operation == OperationType.UPDATE:
result = _apply_update(op, file_ops)
if result[0]:
@@ -260,19 +253,19 @@ def apply_v4a_operations(operations: List[PatchOperation],
all_diffs.append(result[1])
else:
errors.append(f"Failed to update {op.file_path}: {result[1]}")
except Exception as e:
errors.append(f"Error processing {op.file_path}: {str(e)}")
# Run lint on all modified/created files
lint_results = {}
for f in files_modified + files_created:
if hasattr(file_ops, '_check_lint'):
if hasattr(file_ops, "_check_lint"):
lint_result = file_ops._check_lint(f)
lint_results[f] = lint_result.to_dict()
combined_diff = '\n'.join(all_diffs)
combined_diff = "\n".join(all_diffs)
if errors:
return PatchResult(
success=False,
@@ -281,123 +274,124 @@ def apply_v4a_operations(operations: List[PatchOperation],
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None,
error='; '.join(errors)
error="; ".join(errors),
)
return PatchResult(
success=True,
diff=combined_diff,
files_modified=files_modified,
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None
lint=lint_results if lint_results else None,
)
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply an add file operation."""
# Extract content from hunks (all + lines)
content_lines = []
for hunk in op.hunks:
for line in hunk.lines:
if line.prefix == '+':
if line.prefix == "+":
content_lines.append(line.content)
content = '\n'.join(content_lines)
content = "\n".join(content_lines)
result = file_ops.write_file(op.file_path, content)
if result.error:
return False, result.error
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
diff += '\n'.join(f"+{line}" for line in content_lines)
diff += "\n".join(f"+{line}" for line in content_lines)
return True, diff
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply a delete file operation."""
# Read file first for diff
read_result = file_ops.read_file(op.file_path)
if read_result.error and "not found" in read_result.error.lower():
# File doesn't exist, nothing to delete
return True, f"# {op.file_path} already deleted or doesn't exist"
# Delete directly via shell command using the underlying environment
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
if rm_result.exit_code != 0:
return False, rm_result.stdout
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
return True, diff
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply a move file operation."""
# Use shell mv command
mv_result = file_ops._exec(
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
)
if mv_result.exit_code != 0:
return False, mv_result.stdout
diff = f"# Moved: {op.file_path} -> {op.new_path}"
return True, diff
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
"""Apply an update file operation."""
# Read current content
read_result = file_ops.read_file(op.file_path, limit=10000)
if read_result.error:
return False, f"Cannot read file: {read_result.error}"
# Parse content (remove line numbers)
current_lines = []
for line in read_result.content.split('\n'):
if '|' in line:
for line in read_result.content.split("\n"):
if "|" in line:
# Line format: " 123|content"
parts = line.split('|', 1)
parts = line.split("|", 1)
if len(parts) == 2:
current_lines.append(parts[1])
else:
current_lines.append(line)
else:
current_lines.append(line)
current_content = '\n'.join(current_lines)
current_content = "\n".join(current_lines)
# Apply each hunk
new_content = current_content
for hunk in op.hunks:
# Build search pattern from context and removed lines
search_lines = []
replace_lines = []
for line in hunk.lines:
if line.prefix == ' ':
if line.prefix == " ":
search_lines.append(line.content)
replace_lines.append(line.content)
elif line.prefix == '-':
elif line.prefix == "-":
search_lines.append(line.content)
elif line.prefix == '+':
elif line.prefix == "+":
replace_lines.append(line.content)
if search_lines:
search_pattern = '\n'.join(search_lines)
replacement = '\n'.join(replace_lines)
search_pattern = "\n".join(search_lines)
replacement = "\n".join(replace_lines)
# Use fuzzy matching
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, count, error = fuzzy_find_and_replace(
new_content, search_pattern, replacement, replace_all=False
)
if error and count == 0:
# Try with context hint if available
if hunk.context_hint:
@@ -408,31 +402,32 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
window_start = max(0, hint_pos - 500)
window_end = min(len(new_content), hint_pos + 2000)
window = new_content[window_start:window_end]
window_new, count, error = fuzzy_find_and_replace(
window, search_pattern, replacement, replace_all=False
)
if count > 0:
new_content = new_content[:window_start] + window_new + new_content[window_end:]
error = None
if error:
return False, f"Could not apply hunk: {error}"
# Write new content
write_result = file_ops.write_file(op.file_path, new_content)
if write_result.error:
return False, write_result.error
# Generate diff
import difflib
diff_lines = difflib.unified_diff(
current_content.splitlines(keepends=True),
new_content.splitlines(keepends=True),
fromfile=f"a/{op.file_path}",
tofile=f"b/{op.file_path}"
tofile=f"b/{op.file_path}",
)
diff = ''.join(diff_lines)
diff = "".join(diff_lines)
return True, diff

View File

@@ -34,7 +34,6 @@ import logging
import os
import platform
import shlex
import shutil
import signal
import subprocess
import threading
@@ -42,10 +41,11 @@ import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.local import _find_shell
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from tools.environments.local import _find_shell
logger = logging.getLogger(__name__)
@@ -54,30 +54,31 @@ logger = logging.getLogger(__name__)
CHECKPOINT_PATH = Path(os.path.expanduser("~/.hermes/processes.json"))
# Limits
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
@dataclass
class ProcessSession:
"""A tracked background process with output buffering."""
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
command: str # Original command string
task_id: str = "" # Task/sandbox isolation key
session_key: str = "" # Gateway session key (for reset protection)
pid: Optional[int] = None # OS process ID
process: Optional[subprocess.Popen] = None # Popen handle (local only)
env_ref: Any = None # Reference to the environment object
cwd: Optional[str] = None # Working directory
started_at: float = 0.0 # time.time() of spawn
exited: bool = False # Whether the process has finished
exit_code: Optional[int] = None # Exit code (None if still running)
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
command: str # Original command string
task_id: str = "" # Task/sandbox isolation key
session_key: str = "" # Gateway session key (for reset protection)
pid: int | None = None # OS process ID
process: subprocess.Popen | None = None # Popen handle (local only)
env_ref: Any = None # Reference to the environment object
cwd: str | None = None # Working directory
started_at: float = 0.0 # time.time() of spawn
exited: bool = False # Whether the process has finished
exit_code: int | None = None # Exit code (None if still running)
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe)
detached: bool = False # True if recovered from crash (no pipe)
_lock: threading.Lock = field(default_factory=threading.Lock)
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
_reader_thread: threading.Thread | None = field(default=None, repr=False)
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
@@ -100,12 +101,12 @@ class ProcessRegistry:
)
def __init__(self):
self._running: Dict[str, ProcessSession] = {}
self._finished: Dict[str, ProcessSession] = {}
self._running: dict[str, ProcessSession] = {}
self._finished: dict[str, ProcessSession] = {}
self._lock = threading.Lock()
# Side-channel for check_interval watchers (gateway reads after agent run)
self.pending_watchers: List[Dict[str, Any]] = []
self.pending_watchers: list[dict[str, Any]] = []
@staticmethod
def _clean_shell_noise(text: str) -> str:
@@ -149,6 +150,7 @@ class ProcessRegistry:
# Try PTY mode for interactive CLI tools
try:
import ptyprocess
user_shell = _find_shell()
pty_env = os.environ | (env_vars or {})
pty_env["PYTHONUNBUFFERED"] = "1"
@@ -260,10 +262,7 @@ class ProcessRegistry:
log_path = f"/tmp/hermes_bg_{session.id}.log"
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
quoted_command = shlex.quote(command)
bg_command = (
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
f"echo $! > {pid_path} && cat {pid_path}"
)
bg_command = f"nohup bash -c {quoted_command} > {log_path} 2>&1 & echo $! > {pid_path} && cat {pid_path}"
try:
result = env.execute(bg_command, timeout=timeout)
@@ -313,7 +312,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer += chunk
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
except Exception as e:
logger.debug("Process stdout reader ended: %s", e)
@@ -326,9 +325,7 @@ class ProcessRegistry:
session.exit_code = session.process.returncode
self._move_to_finished(session)
def _env_poller_loop(
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
):
def _env_poller_loop(self, session: ProcessSession, env: Any, log_path: str, pid_path: str):
"""Background thread: poll a sandbox log file for non-local backends."""
while not session.exited:
time.sleep(2) # Poll every 2 seconds
@@ -340,7 +337,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer = new_output
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
# Check if process is still running
check = env.execute(
@@ -383,7 +380,7 @@ class ProcessRegistry:
with session._lock:
session.output_buffer += text
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
session.output_buffer = session.output_buffer[-session.max_output_chars :]
except EOFError:
break
except Exception:
@@ -397,7 +394,7 @@ class ProcessRegistry:
except Exception as e:
logger.debug("PTY wait timed out or failed: %s", e)
session.exited = True
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
session.exit_code = pty.exitstatus if hasattr(pty, "exitstatus") else -1
self._move_to_finished(session)
def _move_to_finished(self, session: ProcessSession):
@@ -409,7 +406,7 @@ class ProcessRegistry:
# ----- Query Methods -----
def get(self, session_id: str) -> Optional[ProcessSession]:
def get(self, session_id: str) -> ProcessSession | None:
"""Get a session by ID (running or finished)."""
with self._lock:
return self._running.get(session_id) or self._finished.get(session_id)
@@ -454,7 +451,7 @@ class ProcessRegistry:
if offset == 0 and limit > 0:
selected = lines[-limit:]
else:
selected = lines[offset:offset + limit]
selected = lines[offset : offset + limit]
return {
"session_id": session.id,
@@ -485,10 +482,7 @@ class ProcessRegistry:
if requested_timeout and requested_timeout > max_timeout:
effective_timeout = max_timeout
timeout_note = (
f"Requested wait of {requested_timeout}s was clamped "
f"to configured limit of {max_timeout}s"
)
timeout_note = f"Requested wait of {requested_timeout}s was clamped to configured limit of {max_timeout}s"
else:
effective_timeout = requested_timeout or max_timeout
@@ -581,7 +575,7 @@ class ProcessRegistry:
return {"status": "already_exited", "error": "Process has already finished"}
# PTY mode -- write through pty handle (expects bytes)
if hasattr(session, '_pty') and session._pty:
if hasattr(session, "_pty") and session._pty:
try:
pty_data = data.encode("utf-8") if isinstance(data, str) else data
session._pty.write(pty_data)
@@ -635,26 +629,17 @@ class ProcessRegistry:
def has_active_processes(self, task_id: str) -> bool:
"""Check if there are active (running) processes for a task_id."""
with self._lock:
return any(
s.task_id == task_id and not s.exited
for s in self._running.values()
)
return any(s.task_id == task_id and not s.exited for s in self._running.values())
def has_active_for_session(self, session_key: str) -> bool:
"""Check if there are active processes for a gateway session key."""
with self._lock:
return any(
s.session_key == session_key and not s.exited
for s in self._running.values()
)
return any(s.session_key == session_key and not s.exited for s in self._running.values())
def kill_all(self, task_id: str = None) -> int:
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
with self._lock:
targets = [
s for s in self._running.values()
if (task_id is None or s.task_id == task_id) and not s.exited
]
targets = [s for s in self._running.values() if (task_id is None or s.task_id == task_id) and not s.exited]
killed = 0
for session in targets:
@@ -669,10 +654,7 @@ class ProcessRegistry:
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
# First prune expired finished sessions
now = time.time()
expired = [
sid for sid, s in self._finished.items()
if (now - s.started_at) > FINISHED_TTL_SECONDS
]
expired = [sid for sid, s in self._finished.items() if (now - s.started_at) > FINISHED_TTL_SECONDS]
for sid in expired:
del self._finished[sid]
@@ -696,18 +678,21 @@ class ProcessRegistry:
entries = []
for s in self._running.values():
if not s.exited:
entries.append({
"session_id": s.id,
"command": s.command,
"pid": s.pid,
"cwd": s.cwd,
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
})
entries.append(
{
"session_id": s.id,
"command": s.command,
"pid": s.pid,
"cwd": s.cwd,
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
}
)
# Atomic write to avoid corruption on crash
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, entries)
except Exception as e:
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
@@ -759,6 +744,7 @@ class ProcessRegistry:
# Clear the checkpoint (will be rewritten as processes finish)
try:
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, [])
except Exception as e:
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
@@ -790,38 +776,32 @@ PROCESS_SCHEMA = {
"action": {
"type": "string",
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
"description": "Action to perform on background processes"
"description": "Action to perform on background processes",
},
"session_id": {
"type": "string",
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
"description": "Process session ID (from terminal background output). Required for all actions except 'list'.",
},
"data": {
"type": "string",
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
"description": "Text to send to process stdin (for 'write' and 'submit' actions)",
},
"timeout": {
"type": "integer",
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
"minimum": 1
"minimum": 1,
},
"offset": {
"type": "integer",
"description": "Line offset for 'log' action (default: last 200 lines)"
},
"limit": {
"type": "integer",
"description": "Max lines to return for 'log' action",
"minimum": 1
}
"offset": {"type": "integer", "description": "Line offset for 'log' action (default: last 200 lines)"},
"limit": {"type": "integer", "description": "Max lines to return for 'log' action", "minimum": 1},
},
"required": ["action"]
}
"required": ["action"],
},
}
def _handle_process(args, **kw):
import json as _json
task_id = kw.get("task_id")
action = args.get("action", "")
# Coerce to string — some models send session_id as an integer
@@ -835,8 +815,10 @@ def _handle_process(args, **kw):
if action == "poll":
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
elif action == "log":
return _json.dumps(process_registry.read_log(
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
return _json.dumps(
process_registry.read_log(session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)),
ensure_ascii=False,
)
elif action == "wait":
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
elif action == "kill":
@@ -845,7 +827,10 @@ def _handle_process(args, **kw):
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
elif action == "submit":
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
return _json.dumps(
{"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"},
ensure_ascii=False,
)
registry.register(

View File

@@ -16,7 +16,7 @@ Import chain (circular-import safe):
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from collections.abc import Callable
logger = logging.getLogger(__name__)
@@ -25,12 +25,17 @@ class ToolEntry:
"""Metadata for a single registered tool."""
__slots__ = (
"name", "toolset", "schema", "handler", "check_fn",
"requires_env", "is_async", "description",
"name",
"toolset",
"schema",
"handler",
"check_fn",
"requires_env",
"is_async",
"description",
)
def __init__(self, name, toolset, schema, handler, check_fn,
requires_env, is_async, description):
def __init__(self, name, toolset, schema, handler, check_fn, requires_env, is_async, description):
self.name = name
self.toolset = toolset
self.schema = schema
@@ -45,8 +50,8 @@ class ToolRegistry:
"""Singleton registry that collects tool schemas + handlers from tool files."""
def __init__(self):
self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {}
self._tools: dict[str, ToolEntry] = {}
self._toolset_checks: dict[str, Callable] = {}
# ------------------------------------------------------------------
# Registration
@@ -81,7 +86,7 @@ class ToolRegistry:
# Schema retrieval
# ------------------------------------------------------------------
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
def get_definitions(self, tool_names: set[str], quiet: bool = False) -> list[dict]:
"""Return OpenAI-format tool schemas for the requested tool names.
Only tools whose ``check_fn()`` returns True (or have no check_fn)
@@ -122,6 +127,7 @@ class ToolRegistry:
try:
if entry.is_async:
from model_tools import _run_async
return _run_async(entry.handler(args, **kwargs))
return entry.handler(args, **kwargs)
except Exception as e:
@@ -132,16 +138,16 @@ class ToolRegistry:
# Query helpers (replace redundant dicts in model_tools.py)
# ------------------------------------------------------------------
def get_all_tool_names(self) -> List[str]:
def get_all_tool_names(self) -> list[str]:
"""Return sorted list of all registered tool names."""
return sorted(self._tools.keys())
def get_toolset_for_tool(self, name: str) -> Optional[str]:
def get_toolset_for_tool(self, name: str) -> str | None:
"""Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name)
return entry.toolset if entry else None
def get_tool_to_toolset_map(self) -> Dict[str, str]:
def get_tool_to_toolset_map(self) -> dict[str, str]:
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
return {name: e.toolset for name, e in self._tools.items()}
@@ -160,14 +166,14 @@ class ToolRegistry:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def check_toolset_requirements(self) -> Dict[str, bool]:
def check_toolset_requirements(self) -> dict[str, bool]:
"""Return ``{toolset: available_bool}`` for every toolset."""
toolsets = set(e.toolset for e in self._tools.values())
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
def get_available_toolsets(self) -> Dict[str, dict]:
def get_available_toolsets(self) -> dict[str, dict]:
"""Return toolset metadata for UI display."""
toolsets: Dict[str, dict] = {}
toolsets: dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in toolsets:
@@ -184,9 +190,9 @@ class ToolRegistry:
toolsets[ts]["requirements"].append(env)
return toolsets
def get_toolset_requirements(self) -> Dict[str, dict]:
def get_toolset_requirements(self) -> dict[str, dict]:
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
result: Dict[str, dict] = {}
result: dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in result:
@@ -217,11 +223,13 @@ class ToolRegistry:
if self.is_toolset_available(ts):
available.append(ts)
else:
unavailable.append({
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
})
unavailable.append(
{
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
}
)
return available, unavailable

File diff suppressed because it is too large Load Diff

View File

@@ -29,19 +29,16 @@ SEND_MESSAGE_SCHEMA = {
"action": {
"type": "string",
"enum": ["send", "list"],
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms.",
},
"target": {
"type": "string",
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'",
},
"message": {
"type": "string",
"description": "The message text to send"
}
"message": {"type": "string", "description": "The message text to send"},
},
"required": []
}
"required": [],
},
}
@@ -59,6 +56,7 @@ def _handle_list():
"""Return formatted list of available messaging targets."""
try:
from gateway.channel_directory import format_directory_for_display
return json.dumps({"targets": format_directory_for_display()})
except Exception as e:
return json.dumps({"error": f"Failed to load channel directory: {e}"})
@@ -79,26 +77,30 @@ def _handle_send(args):
if chat_id and not chat_id.lstrip("-").isdigit():
try:
from gateway.channel_directory import resolve_channel_name
resolved = resolve_channel_name(platform_name, chat_id)
if resolved:
chat_id = resolved
else:
return json.dumps({
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Use send_message(action='list') to see available targets."
})
return json.dumps(
{
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Use send_message(action='list') to see available targets."
}
)
except Exception:
return json.dumps({
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
f"Try using a numeric channel ID instead."
})
return json.dumps(
{"error": f"Could not resolve '{chat_id}' on {platform_name}. Try using a numeric channel ID instead."}
)
from tools.interrupt import is_interrupted
if is_interrupted():
return json.dumps({"error": "Interrupted"})
try:
from gateway.config import load_gateway_config, Platform
from gateway.config import Platform, load_gateway_config
config = load_gateway_config()
except Exception as e:
return json.dumps({"error": f"Failed to load gateway config: {e}"})
@@ -117,7 +119,11 @@ def _handle_send(args):
pconfig = config.platforms.get(platform)
if not pconfig or not pconfig.enabled:
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."})
return json.dumps(
{
"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."
}
)
used_home_channel = False
if not chat_id:
@@ -126,14 +132,17 @@ def _handle_send(args):
chat_id = home.chat_id
used_home_channel = True
else:
return json.dumps({
"error": f"No home channel set for {platform_name} to determine where to send the message. "
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
})
return json.dumps(
{
"error": f"No home channel set for {platform_name} to determine where to send the message. "
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
}
)
try:
from model_tools import _run_async
result = _run_async(_send_to_platform(platform, pconfig, chat_id, message))
if used_home_channel and isinstance(result, dict) and result.get("success"):
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
@@ -142,6 +151,7 @@ def _handle_send(args):
if isinstance(result, dict) and result.get("success"):
try:
from gateway.mirror import mirror_to_session
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
if mirror_to_session(platform_name, chat_id, message, source_label=source_label):
result["mirrored"] = True
@@ -156,6 +166,7 @@ def _handle_send(args):
async def _send_to_platform(platform, pconfig, chat_id, message):
"""Route a message to the appropriate platform sender."""
from gateway.config import Platform
if platform == Platform.TELEGRAM:
return await _send_telegram(pconfig.token, chat_id, message)
elif platform == Platform.DISCORD:
@@ -171,6 +182,7 @@ async def _send_telegram(token, chat_id, message):
"""Send via Telegram Bot API (one-shot, no polling needed)."""
try:
from telegram import Bot
bot = Bot(token=token)
msg = await bot.send_message(chat_id=int(chat_id), text=message)
return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)}
@@ -189,7 +201,7 @@ async def _send_discord(token, chat_id, message):
try:
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
chunks = [message[i:i+2000] for i in range(0, len(message), 2000)]
chunks = [message[i : i + 2000] for i in range(0, len(message), 2000)]
message_ids = []
async with aiohttp.ClientSession() as session:
for chunk in chunks:
@@ -266,6 +278,7 @@ def _check_send_message():
return True
try:
from gateway.status import is_gateway_running
return is_gateway_running()
except Exception:
return False

View File

@@ -18,11 +18,8 @@ Flow:
import asyncio
import concurrent.futures
import json
import os
import logging
from typing import Dict, Any, List, Optional, Union
from openai import AsyncOpenAI, OpenAI
from typing import Any
from agent.auxiliary_client import get_async_text_auxiliary_client
@@ -33,7 +30,7 @@ MAX_SESSION_CHARS = 100_000
MAX_SUMMARY_TOKENS = 10000
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
def _format_timestamp(ts: int | float | str | None) -> str:
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
Returns "unknown" for None, str(ts) if conversion fails.
@@ -43,11 +40,13 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
try:
if isinstance(ts, (int, float)):
from datetime import datetime
dt = datetime.fromtimestamp(ts)
return dt.strftime("%B %d, %Y at %I:%M %p")
if isinstance(ts, str):
if ts.replace(".", "").replace("-", "").isdigit():
from datetime import datetime
dt = datetime.fromtimestamp(float(ts))
return dt.strftime("%B %d, %Y at %I:%M %p")
return ts
@@ -59,7 +58,7 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
return str(ts)
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
def _format_conversation(messages: list[dict[str, Any]]) -> str:
"""Format session messages into a readable transcript for summarization."""
parts = []
for msg in messages:
@@ -93,9 +92,7 @@ def _format_conversation(messages: List[Dict[str, Any]]) -> str:
return "\n\n".join(parts)
def _truncate_around_matches(
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
) -> str:
def _truncate_around_matches(full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS) -> str:
"""
Truncate a conversation transcript to max_chars, centered around
where the query terms appear. Keeps content near matches, trims the edges.
@@ -129,9 +126,7 @@ def _truncate_around_matches(
return prefix + truncated + suffix
async def _summarize_session(
conversation_text: str, query: str, session_meta: Dict[str, Any]
) -> Optional[str]:
async def _summarize_session(conversation_text: str, query: str, session_meta: dict[str, Any]) -> str | None:
"""Summarize a single session conversation focused on the search query."""
system_prompt = (
"You are reviewing a past conversation transcript to help recall what happened. "
@@ -163,7 +158,8 @@ async def _summarize_session(
max_retries = 3
for attempt in range(max_retries):
try:
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
_extra = get_auxiliary_extra_body()
response = await _async_aux_client.chat.completions.create(
model=_SUMMARIZER_MODEL,
@@ -221,13 +217,16 @@ def session_search(
)
if not raw_results:
return json.dumps({
"success": True,
"query": query,
"results": [],
"count": 0,
"message": "No matching sessions found.",
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"query": query,
"results": [],
"count": 0,
"message": "No matching sessions found.",
},
ensure_ascii=False,
)
# Resolve child sessions to their parent — delegation stores detailed
# content in child sessions, but the user's conversation is the parent.
@@ -283,12 +282,9 @@ def session_search(
logging.warning(f"Failed to prepare session {session_id}: {e}")
# Summarize all sessions in parallel
async def _summarize_all() -> List[Union[str, Exception]]:
async def _summarize_all() -> list[str | Exception]:
"""Summarize all sessions in parallel."""
coros = [
_summarize_session(text, query, meta)
for _, _, text, meta in tasks
]
coros = [_summarize_session(text, query, meta) for _, _, text, meta in tasks]
return await asyncio.gather(*coros, return_exceptions=True)
try:
@@ -300,10 +296,13 @@ def session_search(
results = asyncio.run(_summarize_all())
except concurrent.futures.TimeoutError:
logging.warning("Session summarization timed out after 60 seconds")
return json.dumps({
"success": False,
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
},
ensure_ascii=False,
)
summaries = []
for (session_id, match_info, _, _), result in zip(tasks, results):
@@ -311,21 +310,26 @@ def session_search(
logging.warning(f"Failed to summarize session {session_id}: {result}")
continue
if result:
summaries.append({
"session_id": session_id,
"when": _format_timestamp(match_info.get("session_started")),
"source": match_info.get("source", "unknown"),
"model": match_info.get("model"),
"summary": result,
})
summaries.append(
{
"session_id": session_id,
"when": _format_timestamp(match_info.get("session_started")),
"source": match_info.get("source", "unknown"),
"model": match_info.get("model"),
"summary": result,
}
)
return json.dumps({
"success": True,
"query": query,
"results": summaries,
"count": len(summaries),
"sessions_searched": len(seen_sessions),
}, ensure_ascii=False)
return json.dumps(
{
"success": True,
"query": query,
"results": summaries,
"count": len(summaries),
"sessions_searched": len(seen_sessions),
},
ensure_ascii=False,
)
except Exception as e:
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
@@ -337,6 +341,7 @@ def check_session_search_requirements() -> bool:
return False
try:
from hermes_state import DEFAULT_DB_PATH
return DEFAULT_DB_PATH.parent.exists()
except ImportError:
return False
@@ -356,7 +361,7 @@ SESSION_SEARCH_SCHEMA = {
"Don't hesitate to search -- it's fast and cheap. Better to search and confirm "
"than to guess or ask the user to repeat themselves.\n\n"
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
'phrases for exact match ("docker networking"), boolean (python NOT java), prefix (deploy*). '
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
"keyword searches in parallel. Returns summaries of the top matching sessions."
@@ -395,6 +400,7 @@ registry.register(
role_filter=args.get("role_filter"),
limit=args.get("limit", 3),
db=kw.get("db"),
current_session_id=kw.get("current_session_id")),
current_session_id=kw.get("current_session_id"),
),
check_fn=check_session_search_requirements,
)

View File

@@ -38,20 +38,21 @@ import os
import re
import shutil
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Any
logger = logging.getLogger(__name__)
# Import security scanner — agent-created skills get the same scrutiny as
# community hub installs.
try:
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
_GUARD_AVAILABLE = True
except ImportError:
_GUARD_AVAILABLE = False
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
def _security_scan_skill(skill_dir: Path) -> str | None:
"""Scan a skill directory after write. Returns error string if blocked, else None."""
if not _GUARD_AVAILABLE:
return None
@@ -65,8 +66,8 @@ def _security_scan_skill(skill_dir: Path) -> Optional[str]:
logger.warning("Security scan failed for %s: %s", skill_dir, e)
return None
import yaml
import yaml
# All skills live in ~/.hermes/skills/ (single source of truth)
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
@@ -76,7 +77,7 @@ MAX_NAME_LENGTH = 64
MAX_DESCRIPTION_LENGTH = 1024
# Characters allowed in skill names (filesystem-safe, URL-friendly)
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
VALID_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*$")
# Subdirectories allowed for write_file/remove_file
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
@@ -91,7 +92,8 @@ def check_skill_manage_requirements() -> bool:
# Validation helpers
# =============================================================================
def _validate_name(name: str) -> Optional[str]:
def _validate_name(name: str) -> str | None:
"""Validate a skill name. Returns error message or None if valid."""
if not name:
return "Skill name is required."
@@ -105,7 +107,7 @@ def _validate_name(name: str) -> Optional[str]:
return None
def _validate_frontmatter(content: str) -> Optional[str]:
def _validate_frontmatter(content: str) -> str | None:
"""
Validate that SKILL.md content has proper frontmatter with required fields.
Returns error message or None if valid.
@@ -116,11 +118,11 @@ def _validate_frontmatter(content: str) -> Optional[str]:
if not content.startswith("---"):
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
end_match = re.search(r'\n---\s*\n', content[3:])
end_match = re.search(r"\n---\s*\n", content[3:])
if not end_match:
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
yaml_content = content[3:end_match.start() + 3]
yaml_content = content[3 : end_match.start() + 3]
try:
parsed = yaml.safe_load(yaml_content)
@@ -137,7 +139,7 @@ def _validate_frontmatter(content: str) -> Optional[str]:
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
body = content[end_match.end() + 3:].strip()
body = content[end_match.end() + 3 :].strip()
if not body:
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
@@ -151,7 +153,7 @@ def _resolve_skill_dir(name: str, category: str = None) -> Path:
return SKILLS_DIR / name
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
def _find_skill(name: str) -> dict[str, Any] | None:
"""
Find a skill by name in ~/.hermes/skills/.
Returns {"path": Path} or None.
@@ -164,7 +166,7 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]:
return None
def _validate_file_path(file_path: str) -> Optional[str]:
def _validate_file_path(file_path: str) -> str | None:
"""
Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir.
@@ -194,7 +196,8 @@ def _validate_file_path(file_path: str) -> Optional[str]:
# Core actions
# =============================================================================
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
def _create_skill(name: str, content: str, category: str = None) -> dict[str, Any]:
"""Create a new user skill with SKILL.md content."""
# Validate name
err = _validate_name(name)
@@ -209,10 +212,7 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
# Check for name collisions across all directories
existing = _find_skill(name)
if existing:
return {
"success": False,
"error": f"A skill named '{name}' already exists at {existing['path']}."
}
return {"success": False, "error": f"A skill named '{name}' already exists at {existing['path']}."}
# Create the skill directory
skill_dir = _resolve_skill_dir(name, category)
@@ -238,12 +238,12 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
result["category"] = category
result["hint"] = (
"To add reference files, templates, or scripts, use "
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
f"skill_manage(action='write_file', name='{name}', file_path='references/example.md', file_content='...')"
)
return result
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
def _edit_skill(name: str, content: str) -> dict[str, Any]:
"""Replace the SKILL.md of any existing skill (full rewrite)."""
err = _validate_frontmatter(content)
if err:
@@ -278,7 +278,7 @@ def _patch_skill(
new_string: str,
file_path: str = None,
replace_all: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Targeted find-and-replace within a skill file.
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
@@ -287,7 +287,10 @@ def _patch_skill(
if not old_string:
return {"success": False, "error": "old_string is required for 'patch'."}
if new_string is None:
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
return {
"success": False,
"error": "new_string is required for 'patch'. Use an empty string to delete matched text.",
}
existing = _find_skill(name)
if not existing:
@@ -357,7 +360,7 @@ def _patch_skill(
}
def _delete_skill(name: str) -> Dict[str, Any]:
def _delete_skill(name: str) -> dict[str, Any]:
"""Delete a skill."""
existing = _find_skill(name)
if not existing:
@@ -377,7 +380,7 @@ def _delete_skill(name: str) -> Dict[str, Any]:
}
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
def _write_file(name: str, file_path: str, file_content: str) -> dict[str, Any]:
"""Add or overwrite a supporting file within any skill directory."""
err = _validate_file_path(file_path)
if err:
@@ -412,7 +415,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
}
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
def _remove_file(name: str, file_path: str) -> dict[str, Any]:
"""Remove a supporting file from any skill directory."""
err = _validate_file_path(file_path)
if err:
@@ -456,6 +459,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
# Main entry point
# =============================================================================
def skill_manage(
action: str,
name: str,
@@ -474,19 +478,37 @@ def skill_manage(
"""
if action == "create":
if not content:
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body).",
},
ensure_ascii=False,
)
result = _create_skill(name, content, category)
elif action == "edit":
if not content:
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."},
ensure_ascii=False,
)
result = _edit_skill(name, content)
elif action == "patch":
if not old_string:
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "old_string is required for 'patch'. Provide the text to find."},
ensure_ascii=False,
)
if new_string is None:
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "new_string is required for 'patch'. Use empty string to delete matched text.",
},
ensure_ascii=False,
)
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
elif action == "delete":
@@ -494,18 +516,31 @@ def skill_manage(
elif action == "write_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'",
},
ensure_ascii=False,
)
if file_content is None:
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False
)
result = _write_file(name, file_path, file_content)
elif action == "remove_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
return json.dumps(
{"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False
)
result = _remove_file(name, file_path)
else:
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
result = {
"success": False,
"error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file",
}
return json.dumps(result, ensure_ascii=False)
@@ -540,14 +575,14 @@ SKILL_MANAGE_SCHEMA = {
"action": {
"type": "string",
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
"description": "The action to perform."
"description": "The action to perform.",
},
"name": {
"type": "string",
"description": (
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
)
),
},
"content": {
"type": "string",
@@ -555,7 +590,7 @@ SKILL_MANAGE_SCHEMA = {
"Full SKILL.md content (YAML frontmatter + markdown body). "
"Required for 'create' and 'edit'. For 'edit', read the skill "
"first with skill_view() and provide the complete updated text."
)
),
},
"old_string": {
"type": "string",
@@ -563,18 +598,17 @@ SKILL_MANAGE_SCHEMA = {
"Text to find in the file (required for 'patch'). Must be unique "
"unless replace_all=true. Include enough surrounding context to "
"ensure uniqueness."
)
),
},
"new_string": {
"type": "string",
"description": (
"Replacement text (required for 'patch'). Can be empty string "
"to delete the matched text."
)
"Replacement text (required for 'patch'). Can be empty string to delete the matched text."
),
},
"replace_all": {
"type": "boolean",
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false).",
},
"category": {
"type": "string",
@@ -582,7 +616,7 @@ SKILL_MANAGE_SCHEMA = {
"Optional category/domain for organizing the skill (e.g., 'devops', "
"'data-science', 'mlops'). Creates a subdirectory grouping. "
"Only used with 'create'."
)
),
},
"file_path": {
"type": "string",
@@ -591,12 +625,9 @@ SKILL_MANAGE_SCHEMA = {
"For 'write_file'/'remove_file': required, must be under references/, "
"templates/, scripts/, or assets/. "
"For 'patch': optional, defaults to SKILL.md if omitted."
)
},
"file_content": {
"type": "string",
"description": "Content for the file. Required for 'write_file'."
),
},
"file_content": {"type": "string", "description": "Content for the file. Required for 'write_file'."},
},
"required": ["action", "name"],
},
@@ -619,5 +650,6 @@ registry.register(
file_content=args.get("file_content"),
old_string=args.get("old_string"),
new_string=args.get("new_string"),
replace_all=args.get("replace_all", False)),
replace_all=args.get("replace_all", False),
),
)

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More