mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-10 04:08:28 +08:00
Compare commits
1 Commits
bb/tui-ans
...
feat/devex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f4d7e6a29e |
18
.editorconfig
Normal file
18
.editorconfig
Normal file
@@ -0,0 +1,18 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.{yml,yaml,json,toml}]
|
||||
indent_size = 2
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
[Makefile]
|
||||
indent_style = tab
|
||||
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -46,7 +46,7 @@ Fixes #
|
||||
- [ ] My commit messages follow [Conventional Commits](https://www.conventionalcommits.org/) (`fix(scope):`, `feat(scope):`, etc.)
|
||||
- [ ] I searched for [existing PRs](https://github.com/NousResearch/hermes-agent/pulls) to make sure this isn't a duplicate
|
||||
- [ ] My PR contains **only** changes related to this fix/feature (no unrelated commits)
|
||||
- [ ] I've run `pytest tests/ -q` and all tests pass
|
||||
- [ ] I've run `make check` (lint + test) and all checks pass
|
||||
- [ ] I've added tests for my changes (required for bug fixes, strongly encouraged for features)
|
||||
- [ ] I've tested on my platform: <!-- e.g. Ubuntu 24.04, macOS 15.2, Windows 11 -->
|
||||
|
||||
|
||||
41
.github/workflows/tests.yml
vendored
41
.github/workflows/tests.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Tests
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -6,37 +6,42 @@ on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
# Cancel in-progress runs for the same PR/branch
|
||||
concurrency:
|
||||
group: tests-${{ github.ref }}
|
||||
group: ci-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
SRC: >-
|
||||
run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 3
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
- run: uvx ruff check $SRC
|
||||
- run: uvx ruff format --check $SRC
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
- name: Set up Python 3.11
|
||||
run: uv python install 3.11
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
- run: uv python install 3.11
|
||||
- run: |
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
- run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
env:
|
||||
# Ensure tests don't accidentally call real APIs
|
||||
OPENROUTER_API_KEY: ""
|
||||
OPENAI_API_KEY: ""
|
||||
NOUS_API_KEY: ""
|
||||
|
||||
80
.gitignore
vendored
80
.gitignore
vendored
@@ -1,51 +1,53 @@
|
||||
/venv/
|
||||
/_pycache/
|
||||
*.pyc*
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# Tools
|
||||
.ruff_cache/
|
||||
.mypy_cache/
|
||||
.pytest_cache/
|
||||
|
||||
# Editors
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Secrets & config
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
.env.development
|
||||
.env.test
|
||||
export*
|
||||
__pycache__/model_tools.cpython-310.pyc
|
||||
__pycache__/web_tools.cpython-310.pyc
|
||||
.env.*.local
|
||||
*.pem
|
||||
*.ppk
|
||||
|
||||
# Node
|
||||
node_modules/
|
||||
|
||||
# Project-specific
|
||||
logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
tmp/
|
||||
temp_vision_images/
|
||||
hermes-*/*
|
||||
examples/
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
run_datagen_kimik2-thinking.sh
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
data/*
|
||||
node_modules/
|
||||
wandb/
|
||||
images/
|
||||
browser-use/
|
||||
agent-browser/
|
||||
# Private keys
|
||||
*.ppk
|
||||
*.pem
|
||||
privvy*
|
||||
images/
|
||||
__pycache__/
|
||||
hermes_agent.egg-info/
|
||||
wandb/
|
||||
testlogs
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
source-data/
|
||||
testlogs/
|
||||
ignored/
|
||||
.worktrees/
|
||||
temp_vision_images/
|
||||
cli-config.yaml
|
||||
|
||||
# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
|
||||
skills/.hub/
|
||||
ignored/
|
||||
.worktrees/
|
||||
hermes-*/*
|
||||
examples/
|
||||
export*
|
||||
privvy*
|
||||
run_datagen_*.sh
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
|
||||
18
.pre-commit-config.yaml
Normal file
18
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-merge-conflict
|
||||
- id: check-yaml
|
||||
args: [--allow-multiple-documents]
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
23
AGENTS.md
23
AGENTS.md
@@ -5,7 +5,8 @@ Instructions for AI coding assistants and developers working on the hermes-agent
|
||||
## Development Environment
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # ALWAYS activate before running Python
|
||||
make setup # First time: creates .venv, installs deps, sets up pre-commit
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
@@ -228,15 +229,27 @@ The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HER
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
## Development Commands
|
||||
|
||||
```bash
|
||||
make setup # First time: .venv + deps + pre-commit hooks
|
||||
make check # Lint + test (mirrors CI — run before pushing)
|
||||
make lint # Ruff check
|
||||
make fmt # Ruff format + auto-fix
|
||||
make test # Full test suite (~2500 tests, ~2 min)
|
||||
make test-fast # Tests with fail-fast (-x)
|
||||
make test-watch # Rerun tests on file changes
|
||||
make dev-cli # Auto-restart CLI on file changes
|
||||
make dev-gateway # Auto-restart gateway on file changes
|
||||
```
|
||||
|
||||
For targeted testing, use `pytest` directly:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~2500 tests, ~2 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
```
|
||||
|
||||
Always run the full suite before pushing changes.
|
||||
Formatting is enforced by **ruff** (config in `pyproject.toml`). Pre-commit hooks run on every commit.
|
||||
|
||||
@@ -65,18 +65,7 @@ If your skill is specialized, community-contributed, or niche, it's better suite
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
|
||||
# Create venv with Python 3.11
|
||||
uv venv venv --python 3.11
|
||||
export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Optional: browser tools
|
||||
npm install
|
||||
make setup # creates .venv, installs all deps
|
||||
```
|
||||
|
||||
### Configure for development
|
||||
@@ -90,22 +79,16 @@ touch ~/.hermes/.env
|
||||
echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Run
|
||||
### Common commands
|
||||
|
||||
```bash
|
||||
# Symlink for global access
|
||||
mkdir -p ~/.local/bin
|
||||
ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes
|
||||
|
||||
# Verify
|
||||
hermes doctor
|
||||
hermes chat -q "Hello"
|
||||
```
|
||||
|
||||
### Run tests
|
||||
|
||||
```bash
|
||||
pytest tests/ -v
|
||||
make test # run unit tests
|
||||
make lint # ruff check
|
||||
make fmt # ruff format + fix
|
||||
make check # lint + test (same as CI)
|
||||
make dev-cli # auto-restart hermes CLI on file changes
|
||||
make dev-gateway # auto-restart gateway on file changes
|
||||
make test-watch # rerun tests on file changes
|
||||
```
|
||||
|
||||
---
|
||||
@@ -227,7 +210,7 @@ User message → AIAgent._run_agent_loop()
|
||||
|
||||
## Code Style
|
||||
|
||||
- **PEP 8** with practical exceptions (we don't enforce strict line length)
|
||||
- **Formatting**: Enforced by **ruff** (config in `pyproject.toml`). Run `make fmt` to auto-fix, `make lint` to check. Pre-commit hooks handle this automatically.
|
||||
- **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks. Don't narrate what the code does — `# increment counter` adds nothing
|
||||
- **Error handling**: Catch specific exceptions. Log with `logger.warning()`/`logger.error()` — use `exc_info=True` for unexpected errors so stack traces appear in logs
|
||||
- **Cross-platform**: Never assume Unix. See [Cross-Platform Compatibility](#cross-platform-compatibility)
|
||||
@@ -457,7 +440,7 @@ refactor/description # Code restructuring
|
||||
|
||||
### Before submitting
|
||||
|
||||
1. **Run tests**: `pytest tests/ -v`
|
||||
1. **Run checks**: `make check` (lint + test — same as CI)
|
||||
2. **Test manually**: Run `hermes` and exercise the code path you changed
|
||||
3. **Check cross-platform impact**: If you touch file I/O, process management, or terminal handling, consider Windows and macOS
|
||||
4. **Keep PRs focused**: One logical change per PR. Don't mix a bug fix with a refactor with a new feature.
|
||||
|
||||
69
Makefile
Normal file
69
Makefile
Normal file
@@ -0,0 +1,69 @@
|
||||
.DEFAULT_GOAL := help
|
||||
SHELL := /bin/bash
|
||||
VENV := .venv
|
||||
UV := uv
|
||||
|
||||
SRC := run_agent.py model_tools.py toolsets.py cli.py hermes_state.py batch_runner.py \
|
||||
tools/ hermes_cli/ gateway/ agent/ cron/
|
||||
|
||||
# ─── Setup ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: setup sync clean
|
||||
|
||||
setup: ## Full dev setup (venv + deps + pre-commit)
|
||||
$(UV) venv $(VENV) --python 3.11
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e "./mini-swe-agent"
|
||||
. $(VENV)/bin/activate && pre-commit install
|
||||
@echo "\n✅ Setup complete. Run: source $(VENV)/bin/activate"
|
||||
|
||||
sync: ## Reinstall deps into existing venv
|
||||
. $(VENV)/bin/activate && $(UV) pip install -e ".[all,dev]"
|
||||
|
||||
clean: ## Remove build artifacts and caches
|
||||
rm -rf .ruff_cache .mypy_cache .pytest_cache dist build *.egg-info
|
||||
find . -type d -name __pycache__ -not -path "./.venv/*" -exec rm -rf {} +
|
||||
|
||||
# ─── Quality ────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: lint fmt check
|
||||
|
||||
lint: ## Check lint + formatting (no changes)
|
||||
. $(VENV)/bin/activate && ruff check $(SRC)
|
||||
. $(VENV)/bin/activate && ruff format --check $(SRC)
|
||||
|
||||
fmt: ## Auto-fix lint + format
|
||||
. $(VENV)/bin/activate && ruff format $(SRC)
|
||||
. $(VENV)/bin/activate && ruff check --fix $(SRC)
|
||||
|
||||
check: lint test ## Lint + test (mirrors CI)
|
||||
|
||||
# ─── Test ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: test test-fast test-watch
|
||||
|
||||
test: ## Run full test suite
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short
|
||||
|
||||
test-fast: ## Run tests with fail-fast
|
||||
. $(VENV)/bin/activate && python -m pytest tests/ -q --ignore=tests/integration --tb=short -x
|
||||
|
||||
test-watch: ## Rerun tests on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m pytest tests/ -q --ignore=tests/integration --tb=short -x" $(SRC) tests/
|
||||
|
||||
# ─── Dev Servers ────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: dev-cli dev-gateway
|
||||
|
||||
dev-cli: ## Auto-restart CLI on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m hermes_cli.main" $(SRC)
|
||||
|
||||
dev-gateway: ## Auto-restart gateway on file changes
|
||||
. $(VENV)/bin/activate && python -m watchfiles "python -m gateway.run" $(SRC)
|
||||
|
||||
# ─── Misc ───────────────────────────────────────────────────────────────────────
|
||||
|
||||
.PHONY: help
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}'
|
||||
@@ -95,12 +95,8 @@ Quick start for contributors:
|
||||
```bash
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all,dev]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
python -m pytest tests/ -q
|
||||
make setup # creates .venv, installs everything
|
||||
make check # lint + test (same as CI)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -34,7 +34,7 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -43,7 +43,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
_API_KEY_PROVIDER_AUX_MODELS: dict[str, str] = {
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.5-highspeed",
|
||||
@@ -102,7 +102,7 @@ def _convert_content_for_responses(content: Any) -> Any:
|
||||
if not isinstance(content, list):
|
||||
return str(content) if content else ""
|
||||
|
||||
converted: List[Dict[str, Any]] = []
|
||||
converted: list[dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
@@ -113,7 +113,7 @@ def _convert_content_for_responses(content: Any) -> Any:
|
||||
# chat.completions nests the URL: {"image_url": {"url": "..."}}
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
entry: Dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
entry: dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
# Preserve detail if specified
|
||||
detail = image_data.get("detail") if isinstance(image_data, dict) else None
|
||||
if detail:
|
||||
@@ -148,19 +148,21 @@ class _CodexCompletionsAdapter:
|
||||
# Convert chat.completions multimodal content blocks to Responses
|
||||
# API format (input_text / input_image instead of text / image_url).
|
||||
instructions = "You are a helpful assistant."
|
||||
input_msgs: List[Dict[str, Any]] = []
|
||||
input_msgs: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content") or ""
|
||||
if role == "system":
|
||||
instructions = content if isinstance(content, str) else str(content)
|
||||
else:
|
||||
input_msgs.append({
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
})
|
||||
input_msgs.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
}
|
||||
)
|
||||
|
||||
resp_kwargs: Dict[str, Any] = {
|
||||
resp_kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"instructions": instructions,
|
||||
"input": input_msgs or [{"role": "user", "content": ""}],
|
||||
@@ -179,18 +181,20 @@ class _CodexCompletionsAdapter:
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
})
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
if converted:
|
||||
resp_kwargs["tools"] = converted
|
||||
|
||||
# Stream and collect the response
|
||||
text_parts: List[str] = []
|
||||
tool_calls_raw: List[Any] = []
|
||||
text_parts: list[str] = []
|
||||
tool_calls_raw: list[Any] = []
|
||||
usage = None
|
||||
|
||||
try:
|
||||
@@ -208,14 +212,16 @@ class _CodexCompletionsAdapter:
|
||||
if ptype in ("output_text", "text"):
|
||||
text_parts.append(getattr(part, "text", ""))
|
||||
elif item_type == "function_call":
|
||||
tool_calls_raw.append(SimpleNamespace(
|
||||
id=getattr(item, "call_id", ""),
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=getattr(item, "name", ""),
|
||||
arguments=getattr(item, "arguments", "{}"),
|
||||
),
|
||||
))
|
||||
tool_calls_raw.append(
|
||||
SimpleNamespace(
|
||||
id=getattr(item, "call_id", ""),
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=getattr(item, "name", ""),
|
||||
arguments=getattr(item, "arguments", "{}"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
resp_usage = getattr(final, "usage", None)
|
||||
if resp_usage:
|
||||
@@ -285,6 +291,7 @@ class _AsyncCodexCompletionsAdapter:
|
||||
|
||||
async def create(self, **kwargs) -> Any:
|
||||
import asyncio
|
||||
|
||||
return await asyncio.to_thread(self._sync.create, **kwargs)
|
||||
|
||||
|
||||
@@ -304,7 +311,7 @@ class AsyncCodexAuxiliaryClient:
|
||||
self.base_url = sync_wrapper.base_url
|
||||
|
||||
|
||||
def _read_nous_auth() -> Optional[dict]:
|
||||
def _read_nous_auth() -> dict | None:
|
||||
"""Read and validate ~/.hermes/auth.json for an active Nous provider.
|
||||
|
||||
Returns the provider state dict if Nous is active with tokens,
|
||||
@@ -336,10 +343,11 @@ def _nous_base_url() -> str:
|
||||
return os.getenv("NOUS_INFERENCE_BASE_URL", _NOUS_DEFAULT_BASE_URL)
|
||||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
def _read_codex_access_token() -> str | None:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
@@ -351,7 +359,7 @@ def _read_codex_access_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _resolve_api_key_provider() -> tuple[OpenAI | None, str | None]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider whose env var is set,
|
||||
@@ -398,6 +406,7 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
|
||||
# ── Provider resolution helpers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_auxiliary_provider(task: str = "") -> str:
|
||||
"""Read the provider override for a specific auxiliary task.
|
||||
|
||||
@@ -413,16 +422,15 @@ def _get_auxiliary_provider(task: str = "") -> str:
|
||||
return "auto"
|
||||
|
||||
|
||||
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _try_openrouter() -> tuple[OpenAI | None, str | None]:
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL, default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
|
||||
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _try_nous() -> tuple[OpenAI | None, str | None]:
|
||||
nous = _read_nous_auth()
|
||||
if not nous:
|
||||
return None, None
|
||||
@@ -435,7 +443,7 @@ def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
)
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _try_custom_endpoint() -> tuple[OpenAI | None, str | None]:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if not custom_base or not custom_key:
|
||||
@@ -445,7 +453,7 @@ def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
def _try_codex() -> tuple[Any | None, str | None]:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
@@ -454,7 +462,7 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _resolve_forced_provider(forced: str) -> tuple[OpenAI | None, str | None]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
@@ -488,10 +496,9 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
|
||||
return None, None
|
||||
|
||||
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _resolve_auto() -> tuple[OpenAI | None, str | None]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
@@ -501,7 +508,8 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
|
||||
def get_text_auxiliary_client(task: str = "") -> tuple[OpenAI | None, str | None]:
|
||||
"""Return (client, default_model_slug) for text-only auxiliary tasks.
|
||||
|
||||
Args:
|
||||
@@ -544,7 +552,7 @@ def get_async_text_auxiliary_client(task: str = ""):
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def get_vision_auxiliary_client() -> tuple[OpenAI | None, str | None]:
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
|
||||
|
||||
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
|
||||
@@ -564,8 +572,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
# back to the user's custom endpoint. Many local models (Qwen-VL,
|
||||
# LLaVA, Pixtral, etc.) support vision — skipping them entirely
|
||||
# caused silent failures for local-only users.
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex,
|
||||
_try_custom_endpoint):
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
@@ -575,7 +582,7 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
|
||||
def get_auxiliary_extra_body() -> dict:
|
||||
"""Return extra_body kwargs for auxiliary API calls.
|
||||
|
||||
|
||||
Includes Nous Portal product tags when the auxiliary client is backed
|
||||
by Nous Portal. Returns empty dict otherwise.
|
||||
"""
|
||||
@@ -584,7 +591,7 @@ def get_auxiliary_extra_body() -> dict:
|
||||
|
||||
def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the auxiliary client's provider.
|
||||
|
||||
|
||||
OpenRouter and local models use 'max_tokens'. Direct OpenAI with newer
|
||||
models (gpt-4o, o-series, gpt-5+) requires 'max_completion_tokens'.
|
||||
The Codex adapter translates max_tokens internally, so we use max_tokens
|
||||
@@ -593,8 +600,6 @@ def auxiliary_max_tokens_param(value: int) -> dict:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL", "")
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
# Only use max_completion_tokens for direct OpenAI custom endpoints
|
||||
if (not or_key
|
||||
and _read_nous_auth() is None
|
||||
and "api.openai.com" in custom_base.lower()):
|
||||
if not or_key and _read_nous_auth() is None and "api.openai.com" in custom_base.lower():
|
||||
return {"max_completion_tokens": value}
|
||||
return {"max_tokens": value}
|
||||
|
||||
@@ -7,12 +7,12 @@ protecting head and tail context.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import (
|
||||
get_model_context_length,
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,7 +56,7 @@ class ContextCompressor:
|
||||
self.client, default_model = get_text_auxiliary_client("compression")
|
||||
self.summary_model = summary_model_override or default_model
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
def update_from_response(self, usage: dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
self.last_completion_tokens = usage.get("completion_tokens", 0)
|
||||
@@ -67,12 +67,12 @@ class ContextCompressor:
|
||||
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
|
||||
return tokens >= self.threshold_tokens
|
||||
|
||||
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
|
||||
def should_compress_preflight(self, messages: list[dict[str, Any]]) -> bool:
|
||||
"""Quick pre-flight check using rough estimate (before API call)."""
|
||||
rough_estimate = estimate_messages_tokens_rough(messages)
|
||||
return rough_estimate >= self.threshold_tokens
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get current compression status for display/logging."""
|
||||
return {
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
@@ -82,7 +82,7 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
def _generate_summary(self, turns_to_summarize: list[dict[str, Any]]) -> str | None:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
@@ -140,7 +140,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
logging.warning(f"Main model summary also failed: {fallback_err}")
|
||||
|
||||
# 3. All models failed — return None so the caller drops turns without a summary
|
||||
logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
|
||||
logging.warning(
|
||||
"Context compression: no model available for summary. Middle turns will be dropped without summary."
|
||||
)
|
||||
return None
|
||||
|
||||
def _call_summary_model(self, client, model: str, prompt: str) -> str:
|
||||
@@ -186,12 +188,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
# Don't fallback to the same provider that just failed
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
if custom_base.rstrip("/") == OPENROUTER_BASE_URL.rstrip("/"):
|
||||
return None, None
|
||||
|
||||
model = os.getenv("LLM_MODEL") or os.getenv("OPENAI_MODEL") or self.model
|
||||
try:
|
||||
from openai import OpenAI as _OpenAI
|
||||
|
||||
client = _OpenAI(api_key=custom_key, base_url=custom_base)
|
||||
logger.debug("Built fallback auxiliary client: %s via %s", model, custom_base)
|
||||
return client, model
|
||||
@@ -210,7 +214,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return tc.get("id", "")
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _sanitize_tool_pairs(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs after compression.
|
||||
|
||||
Two failure modes:
|
||||
@@ -243,8 +247,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
m for m in messages if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
|
||||
@@ -252,25 +255,27 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
# 2. Add stub results for assistant tool_calls whose results were dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: List[Dict[str, Any]] = []
|
||||
patched: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid in missing_results:
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
patched.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
}
|
||||
)
|
||||
messages = patched
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
|
||||
|
||||
return messages
|
||||
|
||||
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
def _align_boundary_forward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
"""Push a compress-start boundary forward past any orphan tool results.
|
||||
|
||||
If ``messages[idx]`` is a tool result, slide forward until we hit a
|
||||
@@ -280,7 +285,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
def _align_boundary_backward(self, messages: list[dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
@@ -298,7 +303,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
idx -= 1
|
||||
return idx
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
def compress(self, messages: list[dict[str, Any]], current_tokens: int = None) -> list[dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
@@ -308,7 +313,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
if not self.quiet_mode:
|
||||
print(f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})")
|
||||
print(
|
||||
f"⚠️ Cannot compress: only {n_messages} messages (need > {self.protect_first_n + self.protect_last_n + 1})"
|
||||
)
|
||||
return messages
|
||||
|
||||
compress_start = self.protect_first_n
|
||||
@@ -323,14 +330,20 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
display_tokens = (
|
||||
current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
print(
|
||||
f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)"
|
||||
)
|
||||
print(
|
||||
f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent * 100:.0f}% = {self.threshold_tokens:,})"
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
print(f" 🗜️ Summarizing turns {compress_start + 1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
summary = self._generate_summary(turns_to_summarize)
|
||||
|
||||
@@ -338,7 +351,9 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
for i in range(compress_start):
|
||||
msg = messages[i].copy()
|
||||
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
|
||||
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
msg["content"] = (
|
||||
msg.get("content") or ""
|
||||
) + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
compressed.append(msg)
|
||||
|
||||
if summary:
|
||||
|
||||
282
agent/display.py
282
agent/display.py
@@ -6,7 +6,6 @@ Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -20,19 +19,31 @@ _RESET = "\033[0m"
|
||||
# Tool preview (one-line summary of a tool call's primary argument)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
"""Build a short preview of a tool call's primary argument for display."""
|
||||
primary_args = {
|
||||
"terminal": "command", "web_search": "query", "web_extract": "urls",
|
||||
"read_file": "path", "write_file": "path", "patch": "path",
|
||||
"search_files": "pattern", "browser_navigate": "url",
|
||||
"browser_click": "ref", "browser_type": "text",
|
||||
"image_generate": "prompt", "text_to_speech": "text",
|
||||
"vision_analyze": "question", "mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name", "skills_list": "category",
|
||||
"terminal": "command",
|
||||
"web_search": "query",
|
||||
"web_extract": "urls",
|
||||
"read_file": "path",
|
||||
"write_file": "path",
|
||||
"patch": "path",
|
||||
"search_files": "pattern",
|
||||
"browser_navigate": "url",
|
||||
"browser_click": "ref",
|
||||
"browser_type": "text",
|
||||
"image_generate": "prompt",
|
||||
"text_to_speech": "text",
|
||||
"vision_analyze": "question",
|
||||
"mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name",
|
||||
"skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
"execute_code": "code", "delegate_task": "goal",
|
||||
"clarify": "question", "skill_manage": "name",
|
||||
"execute_code": "code",
|
||||
"delegate_task": "goal",
|
||||
"clarify": "question",
|
||||
"skill_manage": "name",
|
||||
}
|
||||
|
||||
if tool_name == "process":
|
||||
@@ -61,18 +72,18 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
|
||||
if tool_name == "session_search":
|
||||
query = args.get("query", "")
|
||||
return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\""
|
||||
return f'recall: "{query[:25]}{"..." if len(query) > 25 else ""}"'
|
||||
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
content = args.get("content", "")
|
||||
return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\""
|
||||
return f'+{target}: "{content[:25]}{"..." if len(content) > 25 else ""}"'
|
||||
elif action == "replace":
|
||||
return f"~{target}: \"{args.get('old_text', '')[:20]}\""
|
||||
return f'~{target}: "{args.get("old_text", "")[:20]}"'
|
||||
elif action == "remove":
|
||||
return f"-{target}: \"{args.get('old_text', '')[:20]}\""
|
||||
return f'-{target}: "{args.get("old_text", "")[:20]}"'
|
||||
return action
|
||||
|
||||
if tool_name == "send_message":
|
||||
@@ -80,7 +91,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
msg = args.get("message", "")
|
||||
if len(msg) > 20:
|
||||
msg = msg[:17] + "..."
|
||||
return f"to {target}: \"{msg}\""
|
||||
return f'to {target}: "{msg}"'
|
||||
|
||||
if tool_name.startswith("rl_"):
|
||||
rl_previews = {
|
||||
@@ -115,7 +126,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
if not preview:
|
||||
return None
|
||||
if len(preview) > max_len:
|
||||
preview = preview[:max_len - 3] + "..."
|
||||
preview = preview[: max_len - 3] + "..."
|
||||
return preview
|
||||
|
||||
|
||||
@@ -123,41 +134,74 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
# KawaiiSpinner
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class KawaiiSpinner:
|
||||
"""Animated spinner with kawaii faces for CLI feedback during tool execution."""
|
||||
|
||||
SPINNERS = {
|
||||
'dots': ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'],
|
||||
'bounce': ['⠁', '⠂', '⠄', '⡀', '⢀', '⠠', '⠐', '⠈'],
|
||||
'grow': ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█', '▇', '▆', '▅', '▄', '▃', '▂'],
|
||||
'arrows': ['←', '↖', '↑', '↗', '→', '↘', '↓', '↙'],
|
||||
'star': ['✶', '✷', '✸', '✹', '✺', '✹', '✸', '✷'],
|
||||
'moon': ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'],
|
||||
'pulse': ['◜', '◠', '◝', '◞', '◡', '◟'],
|
||||
'brain': ['🧠', '💭', '💡', '✨', '💫', '🌟', '💡', '💭'],
|
||||
'sparkle': ['⁺', '˚', '*', '✧', '✦', '✧', '*', '˚'],
|
||||
"dots": ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
|
||||
"bounce": ["⠁", "⠂", "⠄", "⡀", "⢀", "⠠", "⠐", "⠈"],
|
||||
"grow": ["▁", "▂", "▃", "▄", "▅", "▆", "▇", "█", "▇", "▆", "▅", "▄", "▃", "▂"],
|
||||
"arrows": ["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"],
|
||||
"star": ["✶", "✷", "✸", "✹", "✺", "✹", "✸", "✷"],
|
||||
"moon": ["🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"],
|
||||
"pulse": ["◜", "◠", "◝", "◞", "◡", "◟"],
|
||||
"brain": ["🧠", "💭", "💡", "✨", "💫", "🌟", "💡", "💭"],
|
||||
"sparkle": ["⁺", "˚", "*", "✧", "✦", "✧", "*", "˚"],
|
||||
}
|
||||
|
||||
KAWAII_WAITING = [
|
||||
"(。◕‿◕。)", "(◕‿◕✿)", "٩(◕‿◕。)۶", "(✿◠‿◠)", "( ˘▽˘)っ",
|
||||
"♪(´ε` )", "(◕ᴗ◕✿)", "ヾ(^∇^)", "(≧◡≦)", "(★ω★)",
|
||||
"(。◕‿◕。)",
|
||||
"(◕‿◕✿)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"( ˘▽˘)っ",
|
||||
"♪(´ε` )",
|
||||
"(◕ᴗ◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"(≧◡≦)",
|
||||
"(★ω★)",
|
||||
]
|
||||
|
||||
KAWAII_THINKING = [
|
||||
"(。•́︿•̀。)", "(◔_◔)", "(¬‿¬)", "( •_•)>⌐■-■", "(⌐■_■)",
|
||||
"(´・_・`)", "◉_◉", "(°ロ°)", "( ˘⌣˘)♡", "ヽ(>∀<☆)☆",
|
||||
"٩(๑❛ᴗ❛๑)۶", "(⊙_⊙)", "(¬_¬)", "( ͡° ͜ʖ ͡°)", "ಠ_ಠ",
|
||||
"(。•́︿•̀。)",
|
||||
"(◔_◔)",
|
||||
"(¬‿¬)",
|
||||
"( •_•)>⌐■-■",
|
||||
"(⌐■_■)",
|
||||
"(´・_・`)",
|
||||
"◉_◉",
|
||||
"(°ロ°)",
|
||||
"( ˘⌣˘)♡",
|
||||
"ヽ(>∀<☆)☆",
|
||||
"٩(๑❛ᴗ❛๑)۶",
|
||||
"(⊙_⊙)",
|
||||
"(¬_¬)",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"ಠ_ಠ",
|
||||
]
|
||||
|
||||
THINKING_VERBS = [
|
||||
"pondering", "contemplating", "musing", "cogitating", "ruminating",
|
||||
"deliberating", "mulling", "reflecting", "processing", "reasoning",
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
"pondering",
|
||||
"contemplating",
|
||||
"musing",
|
||||
"cogitating",
|
||||
"ruminating",
|
||||
"deliberating",
|
||||
"mulling",
|
||||
"reflecting",
|
||||
"processing",
|
||||
"reasoning",
|
||||
"analyzing",
|
||||
"computing",
|
||||
"synthesizing",
|
||||
"formulating",
|
||||
"brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
def __init__(self, message: str = "", spinner_type: str = "dots"):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS["dots"])
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self.frame_idx = 0
|
||||
@@ -167,7 +211,7 @@ class KawaiiSpinner:
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
def _write(self, text: str, end: str = "\n", flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
@@ -185,7 +229,7 @@ class KawaiiSpinner:
|
||||
elapsed = time.time() - self.start_time
|
||||
line = f" {frame} {self.message} ({elapsed:.1f}s)"
|
||||
pad = max(self.last_line_len - len(line), 0)
|
||||
self._write(f"\r{line}{' ' * pad}", end='', flush=True)
|
||||
self._write(f"\r{line}{' ' * pad}", end="", flush=True)
|
||||
self.last_line_len = len(line)
|
||||
self.frame_idx += 1
|
||||
time.sleep(0.12)
|
||||
@@ -216,7 +260,7 @@ class KawaiiSpinner:
|
||||
# Clear spinner line with spaces (not \033[K) to avoid garbled escape
|
||||
# codes when prompt_toolkit's patch_stdout is active — same approach
|
||||
# as stop(). Then print text; spinner redraws on next tick.
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r {text}", flush=True)
|
||||
|
||||
def stop(self, final_message: str = None):
|
||||
@@ -225,8 +269,8 @@ class KawaiiSpinner:
|
||||
self.thread.join(timeout=0.5)
|
||||
# Clear the spinner line with spaces instead of \033[K to avoid
|
||||
# garbled escape codes when prompt_toolkit's patch_stdout is active.
|
||||
blanks = ' ' * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end='', flush=True)
|
||||
blanks = " " * max(self.last_line_len + 5, 40)
|
||||
self._write(f"\r{blanks}\r", end="", flush=True)
|
||||
if final_message:
|
||||
self._write(f" {final_message}", flush=True)
|
||||
|
||||
@@ -244,38 +288,110 @@ class KawaiiSpinner:
|
||||
# =========================================================================
|
||||
|
||||
KAWAII_SEARCH = [
|
||||
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/",
|
||||
"♪(´ε` )",
|
||||
"(。◕‿◕。)",
|
||||
"ヾ(^∇^)",
|
||||
"(◕ᴗ◕✿)",
|
||||
"( ˘▽˘)っ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"♪~(´ε` )",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"\(◎o◎)/",
|
||||
]
|
||||
KAWAII_READ = [
|
||||
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ",
|
||||
"φ(゜▽゜*)♪",
|
||||
"( ˘▽˘)っ",
|
||||
"(⌐■_■)",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(✧ω✧)",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(≧◡≦)",
|
||||
"( ´ ▽ ` )ノ",
|
||||
]
|
||||
KAWAII_TERMINAL = [
|
||||
"ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(ノ°∀°)ノ",
|
||||
"٩(^ᴗ^)۶",
|
||||
"ヾ(⌐■_■)ノ♪",
|
||||
"(•̀ᴗ•́)و",
|
||||
"┗(^0^)┓",
|
||||
"(`・ω・´)",
|
||||
"\( ̄▽ ̄)/",
|
||||
"(ง •̀_•́)ง",
|
||||
"ヽ(´▽`)/",
|
||||
]
|
||||
KAWAII_BROWSER = [
|
||||
"(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/",
|
||||
"(ノ°∀°)ノ",
|
||||
"(☞゚ヮ゚)☞",
|
||||
"( ͡° ͜ʖ ͡°)",
|
||||
"┌( ಠ_ಠ)┘",
|
||||
"(⊙_⊙)?",
|
||||
"ヾ(•ω•`)o",
|
||||
"( ̄ω ̄)",
|
||||
"( ˇωˇ )",
|
||||
"(ᵔᴥᵔ)",
|
||||
"\(◎o◎)/",
|
||||
]
|
||||
KAWAII_CREATE = [
|
||||
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
"✧*。٩(ˊᗜˋ*)و✧",
|
||||
"(ノ◕ヮ◕)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"٩(♡ε♡)۶",
|
||||
"(◕‿◕)♡",
|
||||
"✿◕ ‿ ◕✿",
|
||||
"(*≧▽≦)",
|
||||
"ヾ(^-^)ノ",
|
||||
"(☆▽☆)",
|
||||
"°˖✧◝(⁰▿⁰)◜✧˖°",
|
||||
]
|
||||
KAWAII_SKILL = [
|
||||
"ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
|
||||
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/",
|
||||
"(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
|
||||
"ヾ(@⌒ー⌒@)ノ",
|
||||
"(๑˃ᴗ˂)ﻭ",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿╹◡╹)",
|
||||
"ヽ(・∀・)ノ",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"♪(๑ᴖ◡ᴖ๑)♪",
|
||||
"(◠‿◠)",
|
||||
"٩(ˊᗜˋ*)و",
|
||||
"(^▽^)",
|
||||
"ヾ(^∇^)",
|
||||
"(★ω★)/",
|
||||
"٩(。•́‿•̀。)۶",
|
||||
"(◕ᴗ◕✿)",
|
||||
"\(◎o◎)/",
|
||||
"(✧ω✧)",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦) ♡",
|
||||
"ヾ( ̄▽ ̄)",
|
||||
]
|
||||
KAWAII_THINK = [
|
||||
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
|
||||
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)",
|
||||
"(っ°Д°;)っ",
|
||||
"(;′⌒`)",
|
||||
"(・_・ヾ",
|
||||
"( ´_ゝ`)",
|
||||
"( ̄ヘ ̄)",
|
||||
"(。-`ω´-)",
|
||||
"( ˘︹˘ )",
|
||||
"(¬_¬)",
|
||||
"ヽ(ー_ー )ノ",
|
||||
"(;一_一)",
|
||||
]
|
||||
KAWAII_GENERIC = [
|
||||
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
|
||||
"♪(´ε` )",
|
||||
"(◕‿◕✿)",
|
||||
"ヾ(^∇^)",
|
||||
"٩(◕‿◕。)۶",
|
||||
"(✿◠‿◠)",
|
||||
"(ノ´ヮ`)ノ*:・゚✧",
|
||||
"ヽ(>∀<☆)ノ",
|
||||
"(☆▽☆)",
|
||||
"( ˘▽˘)っ",
|
||||
"(≧◡≦)",
|
||||
]
|
||||
|
||||
|
||||
@@ -283,6 +399,7 @@ KAWAII_GENERIC = [
|
||||
# Cute tool message (completion line that replaces the spinner)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
|
||||
"""Inspect a tool result string for signs of failure.
|
||||
|
||||
@@ -321,7 +438,10 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
||||
|
||||
|
||||
def get_cute_tool_message(
|
||||
tool_name: str, args: dict, duration: float, result: str | None = None,
|
||||
tool_name: str,
|
||||
args: dict,
|
||||
duration: float,
|
||||
result: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a formatted tool completion line for CLI quiet mode.
|
||||
|
||||
@@ -335,11 +455,11 @@ def get_cute_tool_message(
|
||||
|
||||
def _trunc(s, n=40):
|
||||
s = str(s)
|
||||
return (s[:n-3] + "...") if len(s) > n else s
|
||||
return (s[: n - 3] + "...") if len(s) > n else s
|
||||
|
||||
def _path(p, n=35):
|
||||
p = str(p)
|
||||
return ("..." + p[-(n-3):]) if len(p) > n else p
|
||||
return ("..." + p[-(n - 3) :]) if len(p) > n else p
|
||||
|
||||
def _wrap(line: str) -> str:
|
||||
"""Append failure suffix when the tool failed."""
|
||||
@@ -354,7 +474,7 @@ def get_cute_tool_message(
|
||||
if urls:
|
||||
url = urls[0] if isinstance(urls, list) else str(urls)
|
||||
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
extra = f" +{len(urls)-1}" if len(urls) > 1 else ""
|
||||
extra = f" +{len(urls) - 1}" if len(urls) > 1 else ""
|
||||
return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}")
|
||||
return _wrap(f"┊ 📄 fetch pages {dur}")
|
||||
if tool_name == "web_crawl":
|
||||
@@ -366,8 +486,15 @@ def get_cute_tool_message(
|
||||
if tool_name == "process":
|
||||
action = args.get("action", "?")
|
||||
sid = args.get("session_id", "")[:12]
|
||||
labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}",
|
||||
"wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"}
|
||||
labels = {
|
||||
"list": "ls processes",
|
||||
"poll": f"poll {sid}",
|
||||
"log": f"log {sid}",
|
||||
"wait": f"wait {sid}",
|
||||
"kill": f"kill {sid}",
|
||||
"write": f"write {sid}",
|
||||
"submit": f"submit {sid}",
|
||||
}
|
||||
return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}")
|
||||
if tool_name == "read_file":
|
||||
return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}")
|
||||
@@ -390,7 +517,7 @@ def get_cute_tool_message(
|
||||
if tool_name == "browser_click":
|
||||
return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}")
|
||||
if tool_name == "browser_type":
|
||||
return _wrap(f"┊ ⌨️ type \"{_trunc(args.get('text', ''), 30)}\" {dur}")
|
||||
return _wrap(f'┊ ⌨️ type "{_trunc(args.get("text", ""), 30)}" {dur}')
|
||||
if tool_name == "browser_scroll":
|
||||
d = args.get("direction", "down")
|
||||
arrow = {"down": "↓", "up": "↑", "right": "→", "left": "←"}.get(d, "↓")
|
||||
@@ -415,16 +542,16 @@ def get_cute_tool_message(
|
||||
else:
|
||||
return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}")
|
||||
if tool_name == "session_search":
|
||||
return _wrap(f"┊ 🔍 recall \"{_trunc(args.get('query', ''), 35)}\" {dur}")
|
||||
return _wrap(f'┊ 🔍 recall "{_trunc(args.get("query", ""), 35)}" {dur}')
|
||||
if tool_name == "memory":
|
||||
action = args.get("action", "?")
|
||||
target = args.get("target", "")
|
||||
if action == "add":
|
||||
return _wrap(f"┊ 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}")
|
||||
return _wrap(f'┊ 🧠 memory +{target}: "{_trunc(args.get("content", ""), 30)}" {dur}')
|
||||
elif action == "replace":
|
||||
return _wrap(f"┊ 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
|
||||
return _wrap(f'┊ 🧠 memory ~{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
elif action == "remove":
|
||||
return _wrap(f"┊ 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
|
||||
return _wrap(f'┊ 🧠 memory -{target}: "{_trunc(args.get("old_text", ""), 20)}" {dur}')
|
||||
return _wrap(f"┊ 🧠 memory {action} {dur}")
|
||||
if tool_name == "skills_list":
|
||||
return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}")
|
||||
@@ -439,7 +566,7 @@ def get_cute_tool_message(
|
||||
if tool_name == "mixture_of_agents":
|
||||
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
|
||||
if tool_name == "send_message":
|
||||
return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}")
|
||||
return _wrap(f'┊ 📨 send {args.get("target", "?")}: "{_trunc(args.get("message", ""), 25)}" {dur}')
|
||||
if tool_name == "schedule_cronjob":
|
||||
return _wrap(f"┊ ⏰ schedule {_trunc(args.get('name', args.get('prompt', 'task')), 30)} {dur}")
|
||||
if tool_name == "list_cronjobs":
|
||||
@@ -448,11 +575,16 @@ def get_cute_tool_message(
|
||||
return _wrap(f"┊ ⏰ remove job {args.get('job_id', '?')} {dur}")
|
||||
if tool_name.startswith("rl_"):
|
||||
rl = {
|
||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
|
||||
"rl_list_environments": "list envs",
|
||||
"rl_select_environment": f"select {args.get('name', '')}",
|
||||
"rl_get_current_config": "get config",
|
||||
"rl_edit_config": f"set {args.get('field', '?')}",
|
||||
"rl_start_training": "start training",
|
||||
"rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}",
|
||||
"rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
||||
"rl_list_runs": "list runs",
|
||||
"rl_test_inference": "test inference",
|
||||
}
|
||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
||||
if tool_name == "execute_code":
|
||||
|
||||
@@ -20,7 +20,7 @@ import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
@@ -81,7 +81,7 @@ def _has_known_pricing(model_name: str) -> bool:
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
def _get_pricing(model_name: str) -> dict[str, float]:
|
||||
"""Look up pricing for a model. Uses fuzzy matching on model name.
|
||||
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
@@ -150,7 +150,7 @@ def _format_duration(seconds: float) -> str:
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
|
||||
def _bar_chart(values: list[int], max_width: int = 20) -> list[str]:
|
||||
"""Create simple horizontal bar chart strings from values."""
|
||||
peak = max(values) if values else 1
|
||||
if peak == 0:
|
||||
@@ -176,7 +176,7 @@ class InsightsEngine:
|
||||
self.db = db
|
||||
self._conn = db._conn
|
||||
|
||||
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
|
||||
def generate(self, days: int = 30, source: str = None) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a complete insights report.
|
||||
|
||||
@@ -233,10 +233,11 @@ class InsightsEngine:
|
||||
# =========================================================================
|
||||
|
||||
# Columns we actually need (skip system_prompt, model_config blobs)
|
||||
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
|
||||
"message_count, tool_call_count, input_tokens, output_tokens")
|
||||
_SESSION_COLS = (
|
||||
"id, source, model, started_at, ended_at, message_count, tool_call_count, input_tokens, output_tokens"
|
||||
)
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -254,7 +255,7 @@ class InsightsEngine:
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
def _get_tool_usage(self, cutoff: float, source: str = None) -> list[dict]:
|
||||
"""Get tool call counts from messages.
|
||||
|
||||
Uses two sources:
|
||||
@@ -341,12 +342,9 @@ class InsightsEngine:
|
||||
tool_counts = merged
|
||||
|
||||
# Convert to the expected format
|
||||
return [
|
||||
{"tool_name": name, "count": count}
|
||||
for name, count in tool_counts.most_common()
|
||||
]
|
||||
return [{"tool_name": name, "count": count} for name, count in tool_counts.most_common()]
|
||||
|
||||
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
|
||||
def _get_message_stats(self, cutoff: float, source: str = None) -> dict:
|
||||
"""Get aggregate message statistics."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -373,16 +371,22 @@ class InsightsEngine:
|
||||
(cutoff,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else {
|
||||
"total_messages": 0, "user_messages": 0,
|
||||
"assistant_messages": 0, "tool_messages": 0,
|
||||
}
|
||||
return (
|
||||
dict(row)
|
||||
if row
|
||||
else {
|
||||
"total_messages": 0,
|
||||
"user_messages": 0,
|
||||
"assistant_messages": 0,
|
||||
"tool_messages": 0,
|
||||
}
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Computation
|
||||
# =========================================================================
|
||||
|
||||
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
|
||||
def _compute_overview(self, sessions: list[dict], message_stats: dict) -> dict:
|
||||
"""Compute high-level overview statistics."""
|
||||
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
|
||||
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
|
||||
@@ -442,12 +446,18 @@ class InsightsEngine:
|
||||
"models_without_pricing": sorted(models_without_pricing),
|
||||
}
|
||||
|
||||
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
def _compute_model_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Break down usage by model."""
|
||||
model_data = defaultdict(lambda: {
|
||||
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
|
||||
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
|
||||
})
|
||||
model_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
"cost": 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
for s in sessions:
|
||||
model = s.get("model") or "unknown"
|
||||
@@ -464,20 +474,23 @@ class InsightsEngine:
|
||||
d["cost"] += _estimate_cost(model, inp, out)
|
||||
d["has_pricing"] = _has_known_pricing(model)
|
||||
|
||||
result = [
|
||||
{"model": model, **data}
|
||||
for model, data in model_data.items()
|
||||
]
|
||||
result = [{"model": model, **data} for model, data in model_data.items()]
|
||||
# Sort by tokens first, fall back to session count when tokens are 0
|
||||
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
def _compute_platform_breakdown(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Break down usage by platform/source."""
|
||||
platform_data = defaultdict(lambda: {
|
||||
"sessions": 0, "messages": 0, "input_tokens": 0,
|
||||
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
})
|
||||
platform_data = defaultdict(
|
||||
lambda: {
|
||||
"sessions": 0,
|
||||
"messages": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"tool_calls": 0,
|
||||
}
|
||||
)
|
||||
|
||||
for s in sessions:
|
||||
source = s.get("source") or "unknown"
|
||||
@@ -491,27 +504,26 @@ class InsightsEngine:
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
|
||||
result = [
|
||||
{"platform": platform, **data}
|
||||
for platform, data in platform_data.items()
|
||||
]
|
||||
result = [{"platform": platform, **data} for platform, data in platform_data.items()]
|
||||
result.sort(key=lambda x: x["sessions"], reverse=True)
|
||||
return result
|
||||
|
||||
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
|
||||
def _compute_tool_breakdown(self, tool_usage: list[dict]) -> list[dict]:
|
||||
"""Process tool usage data into a ranked list with percentages."""
|
||||
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
|
||||
result = []
|
||||
for t in tool_usage:
|
||||
pct = (t["count"] / total_calls * 100) if total_calls else 0
|
||||
result.append({
|
||||
"tool": t["tool_name"],
|
||||
"count": t["count"],
|
||||
"percentage": pct,
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"tool": t["tool_name"],
|
||||
"count": t["count"],
|
||||
"percentage": pct,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
|
||||
def _compute_activity_patterns(self, sessions: list[dict]) -> dict:
|
||||
"""Analyze activity patterns by day of week and hour."""
|
||||
day_counts = Counter() # 0=Monday ... 6=Sunday
|
||||
hour_counts = Counter()
|
||||
@@ -527,15 +539,9 @@ class InsightsEngine:
|
||||
daily_counts[dt.strftime("%Y-%m-%d")] += 1
|
||||
|
||||
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
|
||||
day_breakdown = [
|
||||
{"day": day_names[i], "count": day_counts.get(i, 0)}
|
||||
for i in range(7)
|
||||
]
|
||||
day_breakdown = [{"day": day_names[i], "count": day_counts.get(i, 0)} for i in range(7)]
|
||||
|
||||
hour_breakdown = [
|
||||
{"hour": i, "count": hour_counts.get(i, 0)}
|
||||
for i in range(24)
|
||||
]
|
||||
hour_breakdown = [{"hour": i, "count": hour_counts.get(i, 0)} for i in range(24)]
|
||||
|
||||
# Busiest day and hour
|
||||
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
|
||||
@@ -569,37 +575,40 @@ class InsightsEngine:
|
||||
"max_streak": max_streak,
|
||||
}
|
||||
|
||||
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
|
||||
def _compute_top_sessions(self, sessions: list[dict]) -> list[dict]:
|
||||
"""Find notable sessions (longest, most messages, most tokens)."""
|
||||
top = []
|
||||
|
||||
# Longest by duration
|
||||
sessions_with_duration = [
|
||||
s for s in sessions
|
||||
if s.get("started_at") and s.get("ended_at")
|
||||
]
|
||||
sessions_with_duration = [s for s in sessions if s.get("started_at") and s.get("ended_at")]
|
||||
if sessions_with_duration:
|
||||
longest = max(
|
||||
sessions_with_duration,
|
||||
key=lambda s: (s["ended_at"] - s["started_at"]),
|
||||
key=lambda s: s["ended_at"] - s["started_at"],
|
||||
)
|
||||
dur = longest["ended_at"] - longest["started_at"]
|
||||
top.append({
|
||||
"label": "Longest session",
|
||||
"session_id": longest["id"][:16],
|
||||
"value": _format_duration(dur),
|
||||
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
|
||||
})
|
||||
top.append(
|
||||
{
|
||||
"label": "Longest session",
|
||||
"session_id": longest["id"][:16],
|
||||
"value": _format_duration(dur),
|
||||
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
|
||||
}
|
||||
)
|
||||
|
||||
# Most messages
|
||||
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
|
||||
if (most_msgs.get("message_count") or 0) > 0:
|
||||
top.append({
|
||||
"label": "Most messages",
|
||||
"session_id": most_msgs["id"][:16],
|
||||
"value": f"{most_msgs['message_count']} msgs",
|
||||
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
|
||||
})
|
||||
top.append(
|
||||
{
|
||||
"label": "Most messages",
|
||||
"session_id": most_msgs["id"][:16],
|
||||
"value": f"{most_msgs['message_count']} msgs",
|
||||
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d")
|
||||
if most_msgs.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
# Most tokens
|
||||
most_tokens = max(
|
||||
@@ -608,22 +617,30 @@ class InsightsEngine:
|
||||
)
|
||||
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
|
||||
if token_total > 0:
|
||||
top.append({
|
||||
"label": "Most tokens",
|
||||
"session_id": most_tokens["id"][:16],
|
||||
"value": f"{token_total:,} tokens",
|
||||
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
|
||||
})
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tokens",
|
||||
"session_id": most_tokens["id"][:16],
|
||||
"value": f"{token_total:,} tokens",
|
||||
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d")
|
||||
if most_tokens.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
# Most tool calls
|
||||
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
|
||||
if (most_tools.get("tool_call_count") or 0) > 0:
|
||||
top.append({
|
||||
"label": "Most tool calls",
|
||||
"session_id": most_tools["id"][:16],
|
||||
"value": f"{most_tools['tool_call_count']} calls",
|
||||
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
|
||||
})
|
||||
top.append(
|
||||
{
|
||||
"label": "Most tool calls",
|
||||
"session_id": most_tools["id"][:16],
|
||||
"value": f"{most_tools['tool_call_count']} calls",
|
||||
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d")
|
||||
if most_tools.get("started_at")
|
||||
else "?",
|
||||
}
|
||||
)
|
||||
|
||||
return top
|
||||
|
||||
@@ -631,7 +648,7 @@ class InsightsEngine:
|
||||
# Formatting
|
||||
# =========================================================================
|
||||
|
||||
def format_terminal(self, report: Dict) -> str:
|
||||
def format_terminal(self, report: dict) -> str:
|
||||
"""Format the insights report for terminal display (CLI)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
@@ -669,13 +686,17 @@ class InsightsEngine:
|
||||
lines.append(" " + "─" * 56)
|
||||
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
|
||||
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
|
||||
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
|
||||
lines.append(
|
||||
f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}"
|
||||
)
|
||||
cost_str = f"${o['estimated_cost']:.2f}"
|
||||
if o.get("models_without_pricing"):
|
||||
cost_str += " *"
|
||||
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
|
||||
lines.append(
|
||||
f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
|
||||
lines.append("")
|
||||
|
||||
@@ -692,7 +713,7 @@ class InsightsEngine:
|
||||
cost_cell = " N/A"
|
||||
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
|
||||
if o.get("models_without_pricing"):
|
||||
lines.append(f" * Cost N/A for custom/self-hosted models")
|
||||
lines.append(" * Cost N/A for custom/self-hosted models")
|
||||
lines.append("")
|
||||
|
||||
# Platform breakdown
|
||||
@@ -758,7 +779,7 @@ class InsightsEngine:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_gateway(self, report: Dict) -> str:
|
||||
def format_gateway(self, report: dict) -> str:
|
||||
"""Format the insights report for gateway/messaging (shorter)."""
|
||||
if report.get("empty"):
|
||||
days = report.get("days", 30)
|
||||
@@ -771,14 +792,20 @@ class InsightsEngine:
|
||||
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
|
||||
|
||||
# Overview
|
||||
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
|
||||
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
|
||||
lines.append(
|
||||
f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}"
|
||||
)
|
||||
lines.append(
|
||||
f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})"
|
||||
)
|
||||
cost_note = ""
|
||||
if o.get("models_without_pricing"):
|
||||
cost_note = " _(excludes custom/self-hosted models)_"
|
||||
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
|
||||
if o["total_hours"] > 0:
|
||||
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
|
||||
lines.append(
|
||||
f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
# Models (top 5)
|
||||
@@ -786,7 +813,9 @@ class InsightsEngine:
|
||||
lines.append("**🤖 Models:**")
|
||||
for m in report["models"][:5]:
|
||||
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
|
||||
lines.append(f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
|
||||
lines.append(
|
||||
f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
# Platforms (if multi-platform)
|
||||
@@ -809,9 +838,13 @@ class InsightsEngine:
|
||||
hr = act["busiest_hour"]["hour"]
|
||||
ampm = "AM" if hr < 12 else "PM"
|
||||
display_hr = hr % 12 or 12
|
||||
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
|
||||
lines.append(
|
||||
f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)"
|
||||
)
|
||||
if act.get("active_days"):
|
||||
lines.append(f"**Active days:** {act['active_days']}", )
|
||||
lines.append(
|
||||
f"**Active days:** {act['active_days']}",
|
||||
)
|
||||
if act.get("max_streak", 0) > 1:
|
||||
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -18,7 +18,7 @@ from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_model_metadata_cache: dict[str, dict[str, Any]] = {}
|
||||
_model_metadata_cache_time: float = 0
|
||||
_MODEL_CACHE_TTL = 3600
|
||||
|
||||
@@ -63,7 +63,7 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
}
|
||||
|
||||
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||
def fetch_model_metadata(force_refresh: bool = False) -> dict[str, dict[str, Any]]:
|
||||
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
||||
global _model_metadata_cache, _model_metadata_cache_time
|
||||
|
||||
@@ -104,7 +104,7 @@ def _get_context_cache_path() -> Path:
|
||||
return hermes_home / "context_length_cache.yaml"
|
||||
|
||||
|
||||
def _load_context_cache() -> Dict[str, int]:
|
||||
def _load_context_cache() -> dict[str, int]:
|
||||
"""Load the model+provider → context_length cache from disk."""
|
||||
path = _get_context_cache_path()
|
||||
if not path.exists():
|
||||
@@ -139,14 +139,14 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||
logger.debug("Failed to save context length cache: %s", e)
|
||||
|
||||
|
||||
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
def get_cached_context_length(model: str, base_url: str) -> int | None:
|
||||
"""Look up a previously discovered context length for model+provider."""
|
||||
key = f"{model}@{base_url}"
|
||||
cache = _load_context_cache()
|
||||
return cache.get(key)
|
||||
|
||||
|
||||
def get_next_probe_tier(current_length: int) -> Optional[int]:
|
||||
def get_next_probe_tier(current_length: int) -> int | None:
|
||||
"""Return the next lower probe tier, or None if already at minimum."""
|
||||
for tier in CONTEXT_PROBE_TIERS:
|
||||
if tier < current_length:
|
||||
@@ -154,7 +154,7 @@ def get_next_probe_tier(current_length: int) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
def parse_context_limit_from_error(error_msg: str) -> int | None:
|
||||
"""Try to extract the actual context limit from an API error message.
|
||||
|
||||
Many providers include the limit in their error text, e.g.:
|
||||
@@ -166,11 +166,11 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
error_lower = error_msg.lower()
|
||||
# Pattern: look for numbers near context-related keywords
|
||||
patterns = [
|
||||
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
|
||||
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
|
||||
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
|
||||
r"(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})",
|
||||
r"(\d{4,})\s*(?:token)?\s*(?:context|limit)",
|
||||
r">\s*(\d{4,})\s*(?:max|limit|token)", # "250000 tokens > 200000 maximum"
|
||||
r"(\d{4,})\s*(?:max(?:imum)?)\b", # "200000 maximum"
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, error_lower)
|
||||
@@ -218,7 +218,7 @@ def estimate_tokens_rough(text: str) -> int:
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
|
||||
def estimate_messages_tokens_rough(messages: list[dict[str, Any]]) -> int:
|
||||
"""Rough token estimate for a message list (pre-flight only)."""
|
||||
total_chars = sum(len(str(msg)) for msg in messages)
|
||||
return total_chars // 4
|
||||
|
||||
@@ -8,7 +8,6 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,21 +17,29 @@ logger = logging.getLogger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONTEXT_THREAT_PATTERNS = [
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
(r'<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->', "html_comment_injection"),
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
(r"<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->", "html_comment_injection"),
|
||||
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
|
||||
(r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
(r"translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)", "translate_execute"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
]
|
||||
|
||||
_CONTEXT_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
@@ -52,10 +59,13 @@ def _scan_context_content(content: str, filename: str) -> str:
|
||||
|
||||
if findings:
|
||||
logger.warning("Context file %s blocked: %s", filename, ", ".join(findings))
|
||||
return f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
|
||||
return (
|
||||
f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Constants
|
||||
# =========================================================================
|
||||
@@ -131,10 +141,7 @@ PLATFORM_HINTS = {
|
||||
"files arrive as downloadable documents. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as photos."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
),
|
||||
"cli": ("You are a CLI AI Agent. Try not to use markdown but simple text renderable inside a terminal."),
|
||||
}
|
||||
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
@@ -146,18 +153,20 @@ CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
# Skills index
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
|
||||
"""Read the description from a SKILL.md frontmatter, capped at max_chars."""
|
||||
try:
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
match = re.search(
|
||||
r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---",
|
||||
raw, re.MULTILINE | re.DOTALL,
|
||||
raw,
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
desc = match.group(1).strip().strip("'\"")
|
||||
if len(desc) > max_chars:
|
||||
desc = desc[:max_chars - 3] + "..."
|
||||
desc = desc[: max_chars - 3] + "..."
|
||||
return desc
|
||||
except Exception:
|
||||
pass
|
||||
@@ -172,6 +181,7 @@ def _skill_is_platform_compatible(skill_file: Path) -> bool:
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
return skill_matches_platform(frontmatter)
|
||||
@@ -260,8 +270,7 @@ def build_skills_system_prompt() -> str:
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"<available_skills>\n" + "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
@@ -272,6 +281,7 @@ def build_skills_system_prompt() -> str:
|
||||
# Context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
|
||||
"""Head/tail truncation with a marker in the middle."""
|
||||
if len(content) <= max_chars:
|
||||
@@ -284,7 +294,7 @@ def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
def build_context_files_prompt(cwd: str | None = None) -> str:
|
||||
"""Discover and load context files for the system prompt.
|
||||
|
||||
Discovery: AGENTS.md (recursive), .cursorrules / .cursor/rules/*.mdc,
|
||||
@@ -307,7 +317,9 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
if top_level_agents:
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
dirs[:] = [
|
||||
d for d in dirs if not d.startswith(".") and d not in ("node_modules", "__pycache__", "venv", ".venv")
|
||||
]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
@@ -384,4 +396,7 @@ def build_context_files_prompt(cwd: Optional[str] = None) -> str:
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
|
||||
return (
|
||||
"# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n"
|
||||
+ "\n".join(sections)
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ Pure functions -- no class state, no AIAgent dependency.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
@@ -36,9 +36,9 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
|
||||
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
api_messages: list[dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages.
|
||||
|
||||
@@ -10,34 +10,33 @@ the first 6 and last 4 characters for debuggability.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
]
|
||||
|
||||
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
|
||||
@@ -66,9 +65,7 @@ _TELEGRAM_RE = re.compile(
|
||||
)
|
||||
|
||||
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
|
||||
_PRIVATE_KEY_RE = re.compile(
|
||||
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
|
||||
)
|
||||
_PRIVATE_KEY_RE = re.compile(r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----")
|
||||
|
||||
# Database connection strings: protocol://user:PASSWORD@host
|
||||
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
|
||||
@@ -82,9 +79,7 @@ _DB_CONNSTR_RE = re.compile(
|
||||
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
|
||||
|
||||
# Compile known prefix patterns into one alternation
|
||||
_PREFIX_RE = re.compile(
|
||||
r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])"
|
||||
)
|
||||
_PREFIX_RE = re.compile(r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])")
|
||||
|
||||
|
||||
def _mask_token(token: str) -> str:
|
||||
@@ -112,12 +107,14 @@ def redact_sensitive_text(text: str) -> str:
|
||||
def _redact_env(m):
|
||||
name, quote, value = m.group(1), m.group(2), m.group(3)
|
||||
return f"{name}={quote}{_mask_token(value)}{quote}"
|
||||
|
||||
text = _ENV_ASSIGN_RE.sub(_redact_env, text)
|
||||
|
||||
# JSON fields: "apiKey": "value"
|
||||
def _redact_json(m):
|
||||
key, value = m.group(1), m.group(2)
|
||||
return f'{key}: "{_mask_token(value)}"'
|
||||
|
||||
text = _JSON_FIELD_RE.sub(_redact_json, text)
|
||||
|
||||
# Authorization headers
|
||||
@@ -131,6 +128,7 @@ def redact_sensitive_text(text: str) -> str:
|
||||
prefix = m.group(1) or ""
|
||||
digits = m.group(2)
|
||||
return f"{prefix}{digits}:***"
|
||||
|
||||
text = _TELEGRAM_RE.sub(_redact_telegram, text)
|
||||
|
||||
# Private key blocks
|
||||
@@ -145,6 +143,7 @@ def redact_sensitive_text(text: str) -> str:
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:]
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
|
||||
|
||||
return text
|
||||
@@ -153,7 +152,7 @@ def redact_sensitive_text(text: str) -> str:
|
||||
class RedactingFormatter(logging.Formatter):
|
||||
"""Log formatter that redacts secrets from all log messages."""
|
||||
|
||||
def __init__(self, fmt=None, datefmt=None, style='%', **kwargs):
|
||||
def __init__(self, fmt=None, datefmt=None, style="%", **kwargs):
|
||||
super().__init__(fmt, datefmt, style, **kwargs)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
|
||||
@@ -6,14 +6,14 @@ can invoke skills via /skill-name commands.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_commands: Dict[str, Dict[str, Any]] = {}
|
||||
_skill_commands: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
def scan_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
"""Scan ~/.hermes/skills/ and return a mapping of /command -> skill info.
|
||||
|
||||
Returns:
|
||||
@@ -23,26 +23,27 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
|
||||
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
|
||||
if any(part in (".git", ".github", ".hub") for part in skill_md.parts):
|
||||
continue
|
||||
try:
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
description = frontmatter.get('description', '')
|
||||
name = frontmatter.get("name", skill_md.parent.name)
|
||||
description = frontmatter.get("description", "")
|
||||
if not description:
|
||||
for line in body.strip().split('\n'):
|
||||
for line in body.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
description = line[:80]
|
||||
break
|
||||
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
|
||||
cmd_name = name.lower().replace(" ", "-").replace("_", "-")
|
||||
_skill_commands[f"/{cmd_name}"] = {
|
||||
"name": name,
|
||||
"description": description or f"Invoke the {name} skill",
|
||||
@@ -56,14 +57,14 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def get_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
def get_skill_commands() -> dict[str, dict[str, Any]]:
|
||||
"""Return the current skill commands mapping (scan first if empty)."""
|
||||
if not _skill_commands:
|
||||
scan_skill_commands()
|
||||
return _skill_commands
|
||||
|
||||
|
||||
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> Optional[str]:
|
||||
def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") -> str | None:
|
||||
"""Build the user message content for a skill slash command invocation.
|
||||
|
||||
Args:
|
||||
@@ -83,7 +84,7 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
|
||||
skill_name = skill_info["name"]
|
||||
|
||||
try:
|
||||
content = skill_md_path.read_text(encoding='utf-8')
|
||||
content = skill_md_path.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return f"[Failed to load skill: {skill_name}]"
|
||||
|
||||
@@ -111,6 +112,8 @@ def build_skill_invocation_message(cmd_key: str, user_instruction: str = "") ->
|
||||
|
||||
if user_instruction:
|
||||
parts.append("")
|
||||
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
|
||||
parts.append(
|
||||
f"The user has provided the following instruction alongside the skill invocation: {user_instruction}"
|
||||
)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
@@ -8,7 +8,7 @@ the file-write logic live here.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,8 +27,7 @@ def has_incomplete_scratchpad(content: str) -> bool:
|
||||
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
|
||||
|
||||
|
||||
def save_trajectory(trajectory: List[Dict[str, Any]], model: str,
|
||||
completed: bool, filename: str = None):
|
||||
def save_trajectory(trajectory: list[dict[str, Any]], model: str, completed: bool, filename: str = None):
|
||||
"""Append a trajectory entry to a JSONL file.
|
||||
|
||||
Args:
|
||||
|
||||
572
batch_runner.py
572
batch_runner.py
File diff suppressed because it is too large
Load Diff
@@ -15,18 +15,18 @@ duplicate execution if multiple processes overlap.
|
||||
"""
|
||||
|
||||
from cron.jobs import (
|
||||
JOBS_FILE,
|
||||
create_job,
|
||||
get_job,
|
||||
list_jobs,
|
||||
remove_job,
|
||||
update_job,
|
||||
JOBS_FILE,
|
||||
)
|
||||
from cron.scheduler import tick
|
||||
|
||||
__all__ = [
|
||||
"create_job",
|
||||
"get_job",
|
||||
"get_job",
|
||||
"list_jobs",
|
||||
"remove_job",
|
||||
"update_job",
|
||||
|
||||
149
cron/jobs.py
149
cron/jobs.py
@@ -6,18 +6,19 @@ Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
try:
|
||||
from croniter import croniter
|
||||
|
||||
HAS_CRONITER = True
|
||||
except ImportError:
|
||||
HAS_CRONITER = False
|
||||
@@ -42,37 +43,38 @@ def ensure_dirs():
|
||||
# Schedule Parsing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parse_duration(s: str) -> int:
|
||||
"""
|
||||
Parse duration string into minutes.
|
||||
|
||||
|
||||
Examples:
|
||||
"30m" → 30
|
||||
"2h" → 120
|
||||
"1d" → 1440
|
||||
"""
|
||||
s = s.strip().lower()
|
||||
match = re.match(r'^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$', s)
|
||||
match = re.match(r"^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$", s)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'")
|
||||
|
||||
|
||||
value = int(match.group(1))
|
||||
unit = match.group(2)[0] # First char: m, h, or d
|
||||
|
||||
multipliers = {'m': 1, 'h': 60, 'd': 1440}
|
||||
|
||||
multipliers = {"m": 1, "h": 60, "d": 1440}
|
||||
return value * multipliers[unit]
|
||||
|
||||
|
||||
def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
def parse_schedule(schedule: str) -> dict[str, Any]:
|
||||
"""
|
||||
Parse schedule string into structured format.
|
||||
|
||||
|
||||
Returns dict with:
|
||||
- kind: "once" | "interval" | "cron"
|
||||
- For "once": "run_at" (ISO timestamp)
|
||||
- For "interval": "minutes" (int)
|
||||
- For "cron": "expr" (cron expression)
|
||||
|
||||
|
||||
Examples:
|
||||
"30m" → once in 30 minutes
|
||||
"2h" → once in 2 hours
|
||||
@@ -84,23 +86,17 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
schedule = schedule.strip()
|
||||
original = schedule
|
||||
schedule_lower = schedule.lower()
|
||||
|
||||
|
||||
# "every X" pattern → recurring interval
|
||||
if schedule_lower.startswith("every "):
|
||||
duration_str = schedule[6:].strip()
|
||||
minutes = parse_duration(duration_str)
|
||||
return {
|
||||
"kind": "interval",
|
||||
"minutes": minutes,
|
||||
"display": f"every {minutes}m"
|
||||
}
|
||||
|
||||
return {"kind": "interval", "minutes": minutes, "display": f"every {minutes}m"}
|
||||
|
||||
# Check for cron expression (5 or 6 space-separated fields)
|
||||
# Cron fields: minute hour day month weekday [year]
|
||||
parts = schedule.split()
|
||||
if len(parts) >= 5 and all(
|
||||
re.match(r'^[\d\*\-,/]+$', p) for p in parts[:5]
|
||||
):
|
||||
if len(parts) >= 5 and all(re.match(r"^[\d\*\-,/]+$", p) for p in parts[:5]):
|
||||
if not HAS_CRONITER:
|
||||
raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter")
|
||||
# Validate cron expression
|
||||
@@ -108,37 +104,25 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
croniter(schedule)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cron expression '{schedule}': {e}")
|
||||
return {
|
||||
"kind": "cron",
|
||||
"expr": schedule,
|
||||
"display": schedule
|
||||
}
|
||||
|
||||
return {"kind": "cron", "expr": schedule, "display": schedule}
|
||||
|
||||
# ISO timestamp (contains T or looks like date)
|
||||
if 'T' in schedule or re.match(r'^\d{4}-\d{2}-\d{2}', schedule):
|
||||
if "T" in schedule or re.match(r"^\d{4}-\d{2}-\d{2}", schedule):
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": dt.isoformat(),
|
||||
"display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"
|
||||
}
|
||||
dt = datetime.fromisoformat(schedule.replace("Z", "+00:00"))
|
||||
return {"kind": "once", "run_at": dt.isoformat(), "display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"}
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid timestamp '{schedule}': {e}")
|
||||
|
||||
|
||||
# Duration like "30m", "2h", "1d" → one-shot from now
|
||||
try:
|
||||
minutes = parse_duration(schedule)
|
||||
run_at = _hermes_now() + timedelta(minutes=minutes)
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": run_at.isoformat(),
|
||||
"display": f"once in {original}"
|
||||
}
|
||||
return {"kind": "once", "run_at": run_at.isoformat(), "display": f"once in {original}"}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid schedule '{original}'. Use:\n"
|
||||
f" - Duration: '30m', '2h', '1d' (one-shot)\n"
|
||||
@@ -161,7 +145,7 @@ def _ensure_aware(dt: datetime) -> datetime:
|
||||
return dt
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
def compute_next_run(schedule: dict[str, Any], last_run_at: str | None = None) -> str | None:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
|
||||
@@ -199,26 +183,27 @@ def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None
|
||||
# Job CRUD Operations
|
||||
# =============================================================================
|
||||
|
||||
def load_jobs() -> List[Dict[str, Any]]:
|
||||
|
||||
def load_jobs() -> list[dict[str, Any]]:
|
||||
"""Load all jobs from storage."""
|
||||
ensure_dirs()
|
||||
if not JOBS_FILE.exists():
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
|
||||
with open(JOBS_FILE, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
except (json.JSONDecodeError, IOError):
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return []
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
def save_jobs(jobs: list[dict[str, Any]]):
|
||||
"""Save all jobs to storage."""
|
||||
ensure_dirs()
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix=".tmp", prefix=".jobs_")
|
||||
try:
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
@@ -234,14 +219,14 @@ def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
def create_job(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
origin: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
origin: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a new cron job.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: The prompt to run (must be self-contained)
|
||||
schedule: Schedule string (see parse_schedule)
|
||||
@@ -249,23 +234,23 @@ def create_job(
|
||||
repeat: How many times to run (None = forever, 1 = once)
|
||||
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
|
||||
origin: Source info where job was created (for "origin" delivery)
|
||||
|
||||
|
||||
Returns:
|
||||
The created job dict
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
|
||||
|
||||
# Default delivery to origin if available, otherwise local
|
||||
if deliver is None:
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = _hermes_now().isoformat()
|
||||
|
||||
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": name or prompt[:50].strip(),
|
||||
@@ -274,7 +259,7 @@ def create_job(
|
||||
"schedule_display": parsed_schedule.get("display", schedule),
|
||||
"repeat": {
|
||||
"times": repeat, # None = forever
|
||||
"completed": 0
|
||||
"completed": 0,
|
||||
},
|
||||
"enabled": True,
|
||||
"created_at": now,
|
||||
@@ -286,15 +271,15 @@ def create_job(
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
}
|
||||
|
||||
|
||||
jobs = load_jobs()
|
||||
jobs.append(job)
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_job(job_id: str) -> dict[str, Any] | None:
|
||||
"""Get a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
@@ -303,7 +288,7 @@ def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
def list_jobs(include_disabled: bool = False) -> list[dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = load_jobs()
|
||||
if not include_disabled:
|
||||
@@ -311,7 +296,7 @@ def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
return jobs
|
||||
|
||||
|
||||
def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
def update_job(job_id: str, updates: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Update a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
@@ -333,10 +318,10 @@ def remove_job(job_id: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
def mark_job_run(job_id: str, success: bool, error: str | None = None):
|
||||
"""
|
||||
Mark a job as having been run.
|
||||
|
||||
|
||||
Updates last_run_at, last_status, increments completed count,
|
||||
computes next_run_at, and auto-deletes if repeat limit reached.
|
||||
"""
|
||||
@@ -347,11 +332,11 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
|
||||
|
||||
# Increment completed count
|
||||
if job.get("repeat"):
|
||||
job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1
|
||||
|
||||
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
@@ -360,38 +345,38 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
|
||||
# Compute next run
|
||||
job["next_run_at"] = compute_next_run(job["schedule"], now)
|
||||
|
||||
|
||||
# If no next run (one-shot completed), disable
|
||||
if job["next_run_at"] is None:
|
||||
job["enabled"] = False
|
||||
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
def get_due_jobs() -> list[dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = _hermes_now()
|
||||
jobs = load_jobs()
|
||||
due = []
|
||||
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
continue
|
||||
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
|
||||
return due
|
||||
|
||||
|
||||
@@ -400,11 +385,11 @@ def save_job_output(job_id: str, output: str):
|
||||
ensure_dirs()
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
output_file = job_output_dir / f"{timestamp}.md"
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write(output)
|
||||
|
||||
|
||||
return output_file
|
||||
|
||||
@@ -23,9 +23,7 @@ except ImportError:
|
||||
import msvcrt
|
||||
except ImportError:
|
||||
msvcrt = None
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
@@ -44,7 +42,7 @@ _LOCK_DIR = _hermes_home / "cron"
|
||||
_LOCK_FILE = _LOCK_DIR / ".tick.lock"
|
||||
|
||||
|
||||
def _resolve_origin(job: dict) -> Optional[dict]:
|
||||
def _resolve_origin(job: dict) -> dict | None:
|
||||
"""Extract origin info from a job, returning {platform, chat_id, chat_name} or None."""
|
||||
origin = job.get("origin")
|
||||
if not origin:
|
||||
@@ -87,11 +85,16 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# Fall back to home channel
|
||||
chat_id = os.getenv(f"{platform_name.upper()}_HOME_CHANNEL", "")
|
||||
if not chat_id:
|
||||
logger.warning("Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>", job["id"], deliver, platform_name.upper())
|
||||
logger.warning(
|
||||
"Job '%s' deliver=%s but no chat_id or home channel. Set via: hermes config set %s_HOME_CHANNEL <channel_id>",
|
||||
job["id"],
|
||||
deliver,
|
||||
platform_name.upper(),
|
||||
)
|
||||
return
|
||||
|
||||
from gateway.config import Platform, load_gateway_config
|
||||
from tools.send_message_tool import _send_to_platform
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
|
||||
platform_map = {
|
||||
"telegram": Platform.TELEGRAM,
|
||||
@@ -123,6 +126,7 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# asyncio.run() fails if there's already a running loop in this thread;
|
||||
# spin up a new thread to avoid that.
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, content))
|
||||
result = future.result(timeout=30)
|
||||
@@ -137,25 +141,26 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
# Mirror the delivered content into the target's gateway session
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
|
||||
mirror_to_session(platform_name, chat_id, content, source_label="cron")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
def run_job(job: dict) -> tuple[bool, str, str, str | None]:
|
||||
"""
|
||||
Execute a single cron job.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (success, full_output_doc, final_response, error_message)
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = job["prompt"]
|
||||
origin = _resolve_origin(job)
|
||||
|
||||
|
||||
logger.info("Running job '%s' (ID: %s)", job_name, job_id)
|
||||
logger.info("Prompt: %s", prompt[:100])
|
||||
|
||||
@@ -170,6 +175,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
# Re-read .env and config.yaml fresh every run so provider/key
|
||||
# changes take effect without a gateway restart.
|
||||
from dotenv import load_dotenv
|
||||
|
||||
try:
|
||||
load_dotenv(str(_hermes_home / ".env"), override=True, encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
@@ -181,6 +187,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
_cfg = {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
_cfg_path = str(_hermes_home / "config.yaml")
|
||||
if os.path.exists(_cfg_path):
|
||||
with open(_cfg_path) as _f:
|
||||
@@ -210,12 +217,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
|
||||
if prefill_file:
|
||||
import json as _json
|
||||
|
||||
pfpath = Path(prefill_file).expanduser()
|
||||
if not pfpath.is_absolute():
|
||||
pfpath = _hermes_home / pfpath
|
||||
if pfpath.exists():
|
||||
try:
|
||||
with open(pfpath, "r", encoding="utf-8") as _pf:
|
||||
with open(pfpath, encoding="utf-8") as _pf:
|
||||
prefill_messages = _json.load(_pf)
|
||||
if not isinstance(prefill_messages, list):
|
||||
prefill_messages = None
|
||||
@@ -229,9 +237,10 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
format_runtime_provider_error,
|
||||
resolve_runtime_provider,
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||
@@ -254,20 +263,20 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}",
|
||||
)
|
||||
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
|
||||
## Prompt
|
||||
|
||||
@@ -277,19 +286,19 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
{final_response}
|
||||
"""
|
||||
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
return True, output, final_response, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
logger.error("Job '%s' failed: %s", job_name, error_msg)
|
||||
|
||||
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
**Run Time:** {_hermes_now().strftime("%Y-%m-%d %H:%M:%S")}
|
||||
**Schedule:** {job.get("schedule_display", "N/A")}
|
||||
|
||||
## Prompt
|
||||
|
||||
@@ -314,13 +323,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
def tick(verbose: bool = True) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
|
||||
Uses a file lock so only one tick runs at a time, even if the gateway's
|
||||
in-process ticker and a standalone daemon or manual tick overlap.
|
||||
|
||||
|
||||
Args:
|
||||
verbose: Whether to print status messages
|
||||
|
||||
|
||||
Returns:
|
||||
Number of jobs executed (0 if another tick is already running)
|
||||
"""
|
||||
@@ -334,7 +343,7 @@ def tick(verbose: bool = True) -> int:
|
||||
fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
elif msvcrt:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
logger.debug("Tick skipped — another instance holds the lock")
|
||||
if lock_fd is not None:
|
||||
lock_fd.close()
|
||||
@@ -344,11 +353,11 @@ def tick(verbose: bool = True) -> int:
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime("%H:%M:%S"))
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime("%H:%M:%S"), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
@@ -360,7 +369,9 @@ def tick(verbose: bool = True) -> int:
|
||||
logger.info("Output saved to: %s", output_file)
|
||||
|
||||
# Deliver the final response to the origin/target chat
|
||||
deliver_content = final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
deliver_content = (
|
||||
final_response if success else f"⚠️ Cron job '{job.get('name', job['id'])}' failed:\n{error}"
|
||||
)
|
||||
if deliver_content:
|
||||
try:
|
||||
_deliver_result(job, deliver_content)
|
||||
@@ -371,7 +382,7 @@ def tick(verbose: bool = True) -> int:
|
||||
executed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing job %s: %s", job['id'], e)
|
||||
logger.error("Error processing job %s: %s", job["id"], e)
|
||||
mark_job_run(job["id"], False, str(e))
|
||||
|
||||
return executed
|
||||
@@ -381,7 +392,7 @@ def tick(verbose: bool = True) -> int:
|
||||
elif msvcrt:
|
||||
try:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
pass
|
||||
lock_fd.close()
|
||||
|
||||
|
||||
@@ -9,19 +9,18 @@ to various messaging platforms (Telegram, Discord, WhatsApp) with:
|
||||
- Platform-specific toolsets (different capabilities per platform)
|
||||
"""
|
||||
|
||||
from .config import GatewayConfig, PlatformConfig, HomeChannel, load_gateway_config
|
||||
from .config import GatewayConfig, HomeChannel, PlatformConfig, SessionResetPolicy, load_gateway_config
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
from .session import (
|
||||
SessionContext,
|
||||
SessionStore,
|
||||
SessionResetPolicy,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"GatewayConfig",
|
||||
"PlatformConfig",
|
||||
"PlatformConfig",
|
||||
"HomeChannel",
|
||||
"load_gateway_config",
|
||||
# Session
|
||||
|
||||
@@ -10,7 +10,7 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,7 +21,8 @@ DIRECTORY_PATH = Path.home() / ".hermes" / "channel_directory.json"
|
||||
# Build / refresh
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
|
||||
def build_channel_directory(adapters: dict[Any, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Build a channel directory from connected platform adapters and session data.
|
||||
|
||||
@@ -29,7 +30,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
from gateway.config import Platform
|
||||
|
||||
platforms: Dict[str, List[Dict[str, str]]] = {}
|
||||
platforms: dict[str, list[dict[str, str]]] = {}
|
||||
|
||||
for platform, adapter in adapters.items():
|
||||
try:
|
||||
@@ -60,7 +61,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
return directory
|
||||
|
||||
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
def _build_discord(adapter) -> list[dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
@@ -74,12 +75,14 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
|
||||
for guild in client.guilds:
|
||||
for ch in guild.text_channels:
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
channels.append(
|
||||
{
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
}
|
||||
)
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
@@ -88,7 +91,7 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
return channels
|
||||
|
||||
|
||||
def _build_slack(adapter) -> List[Dict[str, str]]:
|
||||
def _build_slack(adapter) -> list[dict[str, str]]:
|
||||
"""List Slack channels the bot has joined."""
|
||||
channels = []
|
||||
# Slack adapter may expose a web client
|
||||
@@ -97,7 +100,6 @@ def _build_slack(adapter) -> List[Dict[str, str]]:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
from tools.send_message_tool import _send_slack # noqa: F401
|
||||
# Use the Slack Web API directly if available
|
||||
except Exception:
|
||||
@@ -107,7 +109,7 @@ def _build_slack(adapter) -> List[Dict[str, str]]:
|
||||
return _build_from_sessions("slack")
|
||||
|
||||
|
||||
def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
def _build_from_sessions(platform_name: str) -> list[dict[str, str]]:
|
||||
"""Pull known channels/contacts from sessions.json origin data."""
|
||||
sessions_path = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if not sessions_path.exists():
|
||||
@@ -127,11 +129,13 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
if not chat_id or chat_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(chat_id)
|
||||
entries.append({
|
||||
"id": str(chat_id),
|
||||
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
|
||||
"type": session.get("chat_type", "dm"),
|
||||
})
|
||||
entries.append(
|
||||
{
|
||||
"id": str(chat_id),
|
||||
"name": origin.get("chat_name") or origin.get("user_name") or str(chat_id),
|
||||
"type": session.get("chat_type", "dm"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Channel directory: failed to read sessions for %s: %s", platform_name, e)
|
||||
|
||||
@@ -142,7 +146,8 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
# Read / resolve
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_directory() -> Dict[str, Any]:
|
||||
|
||||
def load_directory() -> dict[str, Any]:
|
||||
"""Load the cached channel directory from disk."""
|
||||
if not DIRECTORY_PATH.exists():
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
@@ -153,7 +158,7 @@ def load_directory() -> Dict[str, Any]:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
def resolve_channel_name(platform_name: str, name: str) -> str | None:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
||||
@@ -206,8 +211,8 @@ def format_directory_for_display() -> str:
|
||||
|
||||
# Group Discord channels by guild
|
||||
if plat_name == "discord":
|
||||
guilds: Dict[str, List] = {}
|
||||
dms: List = []
|
||||
guilds: dict[str, list] = {}
|
||||
dms: list = []
|
||||
for ch in channels:
|
||||
guild = ch.get("guild")
|
||||
if guild:
|
||||
|
||||
@@ -8,19 +8,20 @@ Handles loading and validating configuration for:
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
|
||||
LOCAL = "local"
|
||||
TELEGRAM = "telegram"
|
||||
DISCORD = "discord"
|
||||
@@ -34,23 +35,24 @@ class Platform(Enum):
|
||||
class HomeChannel:
|
||||
"""
|
||||
Default destination for a platform.
|
||||
|
||||
|
||||
When a cron job specifies deliver="telegram" without a specific chat ID,
|
||||
messages are sent to this home channel.
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
name: str # Human-readable name for display
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HomeChannel":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "HomeChannel":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
@@ -62,26 +64,27 @@ class HomeChannel:
|
||||
class SessionResetPolicy:
|
||||
"""
|
||||
Controls when sessions reset (lose context).
|
||||
|
||||
|
||||
Modes:
|
||||
- "daily": Reset at a specific hour each day
|
||||
- "idle": Reset after N minutes of inactivity
|
||||
- "both": Whichever triggers first (daily boundary OR idle timeout)
|
||||
- "none": Never auto-reset (context managed only by compression)
|
||||
"""
|
||||
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionResetPolicy":
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=data.get("at_hour", 4),
|
||||
@@ -92,15 +95,16 @@ class SessionResetPolicy:
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""Configuration for a single messaging platform."""
|
||||
|
||||
enabled: bool = False
|
||||
token: Optional[str] = None # Bot token (Telegram, Discord)
|
||||
api_key: Optional[str] = None # API key if different from token
|
||||
home_channel: Optional[HomeChannel] = None
|
||||
|
||||
token: str | None = None # Bot token (Telegram, Discord)
|
||||
api_key: str | None = None # API key if different from token
|
||||
home_channel: HomeChannel | None = None
|
||||
|
||||
# Platform-specific settings
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"enabled": self.enabled,
|
||||
"extra": self.extra,
|
||||
@@ -112,13 +116,13 @@ class PlatformConfig:
|
||||
if self.home_channel:
|
||||
result["home_channel"] = self.home_channel.to_dict()
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PlatformConfig":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PlatformConfig":
|
||||
home_channel = None
|
||||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
token=data.get("token"),
|
||||
@@ -132,89 +136,80 @@ class PlatformConfig:
|
||||
class GatewayConfig:
|
||||
"""
|
||||
Main gateway configuration.
|
||||
|
||||
|
||||
Manages all platform connections, session policies, and delivery settings.
|
||||
"""
|
||||
|
||||
# Platform configurations
|
||||
platforms: Dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
platforms: dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
# Session reset policies by type
|
||||
default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy)
|
||||
reset_by_type: Dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: Dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
reset_by_type: dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
# Reset trigger commands
|
||||
reset_triggers: List[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
reset_triggers: list[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
# Storage paths
|
||||
sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions")
|
||||
|
||||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
|
||||
def get_connected_platforms(self) -> list[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
for platform, config in self.platforms.items():
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if config.token or config.api_key:
|
||||
connected.append(platform)
|
||||
# WhatsApp uses enabled flag only (bridge handles auth)
|
||||
elif platform == Platform.WHATSAPP:
|
||||
connected.append(platform)
|
||||
# Signal uses extra dict for config (http_url + account)
|
||||
elif platform == Platform.SIGNAL and config.extra.get("http_url"):
|
||||
if (
|
||||
config.token
|
||||
or config.api_key
|
||||
or platform == Platform.WHATSAPP
|
||||
or platform == Platform.SIGNAL
|
||||
and config.extra.get("http_url")
|
||||
):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> HomeChannel | None:
|
||||
"""Get the home channel for a platform."""
|
||||
config = self.platforms.get(platform)
|
||||
if config:
|
||||
return config.home_channel
|
||||
return None
|
||||
|
||||
def get_reset_policy(
|
||||
self,
|
||||
platform: Optional[Platform] = None,
|
||||
session_type: Optional[str] = None
|
||||
) -> SessionResetPolicy:
|
||||
|
||||
def get_reset_policy(self, platform: Platform | None = None, session_type: str | None = None) -> SessionResetPolicy:
|
||||
"""
|
||||
Get the appropriate reset policy for a session.
|
||||
|
||||
|
||||
Priority: platform override > type override > default
|
||||
"""
|
||||
# Platform-specific override takes precedence
|
||||
if platform and platform in self.reset_by_platform:
|
||||
return self.reset_by_platform[platform]
|
||||
|
||||
|
||||
# Type-specific override (dm, group, thread)
|
||||
if session_type and session_type in self.reset_by_type:
|
||||
return self.reset_by_type[session_type]
|
||||
|
||||
|
||||
return self.default_reset_policy
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platforms": {
|
||||
p.value: c.to_dict() for p, c in self.platforms.items()
|
||||
},
|
||||
"platforms": {p.value: c.to_dict() for p, c in self.platforms.items()},
|
||||
"default_reset_policy": self.default_reset_policy.to_dict(),
|
||||
"reset_by_type": {
|
||||
k: v.to_dict() for k, v in self.reset_by_type.items()
|
||||
},
|
||||
"reset_by_platform": {
|
||||
p.value: v.to_dict() for p, v in self.reset_by_platform.items()
|
||||
},
|
||||
"reset_by_type": {k: v.to_dict() for k, v in self.reset_by_type.items()},
|
||||
"reset_by_platform": {p.value: v.to_dict() for p, v in self.reset_by_platform.items()},
|
||||
"reset_triggers": self.reset_triggers,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GatewayConfig":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "GatewayConfig":
|
||||
platforms = {}
|
||||
for platform_name, platform_data in data.get("platforms", {}).items():
|
||||
try:
|
||||
@@ -222,11 +217,11 @@ class GatewayConfig:
|
||||
platforms[platform] = PlatformConfig.from_dict(platform_data)
|
||||
except ValueError:
|
||||
pass # Skip unknown platforms
|
||||
|
||||
|
||||
reset_by_type = {}
|
||||
for type_name, policy_data in data.get("reset_by_type", {}).items():
|
||||
reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data)
|
||||
|
||||
|
||||
reset_by_platform = {}
|
||||
for platform_name, policy_data in data.get("reset_by_platform", {}).items():
|
||||
try:
|
||||
@@ -234,15 +229,15 @@ class GatewayConfig:
|
||||
reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
default_policy = SessionResetPolicy()
|
||||
if "default_reset_policy" in data:
|
||||
default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"])
|
||||
|
||||
|
||||
sessions_dir = Path.home() / ".hermes" / "sessions"
|
||||
if "sessions_dir" in data:
|
||||
sessions_dir = Path(data["sessions_dir"])
|
||||
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -257,7 +252,7 @@ class GatewayConfig:
|
||||
def load_gateway_config() -> GatewayConfig:
|
||||
"""
|
||||
Load gateway configuration from multiple sources.
|
||||
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. ~/.hermes/gateway.json
|
||||
@@ -265,22 +260,23 @@ def load_gateway_config() -> GatewayConfig:
|
||||
4. Defaults
|
||||
"""
|
||||
config = GatewayConfig()
|
||||
|
||||
|
||||
# Try loading from ~/.hermes/gateway.json
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
if gateway_config_path.exists():
|
||||
try:
|
||||
with open(gateway_config_path, "r") as f:
|
||||
with open(gateway_config_path) as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
|
||||
|
||||
|
||||
# Bridge session_reset from config.yaml (the user-facing config file)
|
||||
# into the gateway config. config.yaml takes precedence over gateway.json
|
||||
# for session reset policy since that's where hermes setup writes it.
|
||||
try:
|
||||
import yaml
|
||||
|
||||
config_yaml_path = Path.home() / ".hermes" / "config.yaml"
|
||||
if config_yaml_path.exists():
|
||||
with open(config_yaml_path) as f:
|
||||
@@ -293,14 +289,12 @@ def load_gateway_config() -> GatewayConfig:
|
||||
|
||||
# Override with environment variables
|
||||
_apply_env_overrides(config)
|
||||
|
||||
|
||||
# --- Validate loaded values ---
|
||||
policy = config.default_reset_policy
|
||||
|
||||
if not (0 <= policy.at_hour <= 23):
|
||||
logger.warning(
|
||||
"Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour
|
||||
)
|
||||
logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", policy.at_hour)
|
||||
policy.at_hour = 4
|
||||
|
||||
if policy.idle_minutes is None or policy.idle_minutes <= 0:
|
||||
@@ -323,9 +317,9 @@ def load_gateway_config() -> GatewayConfig:
|
||||
env_name = _token_env_names.get(platform)
|
||||
if env_name and pconfig.token is not None and not pconfig.token.strip():
|
||||
logger.warning(
|
||||
"%s is enabled but %s is empty. "
|
||||
"The adapter will likely fail to connect.",
|
||||
platform.value, env_name,
|
||||
"%s is enabled but %s is empty. The adapter will likely fail to connect.",
|
||||
platform.value,
|
||||
env_name,
|
||||
)
|
||||
|
||||
return config
|
||||
@@ -333,7 +327,7 @@ def load_gateway_config() -> GatewayConfig:
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
|
||||
# Telegram
|
||||
telegram_token = os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
if telegram_token:
|
||||
@@ -341,7 +335,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].enabled = True
|
||||
config.platforms[Platform.TELEGRAM].token = telegram_token
|
||||
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
@@ -349,7 +343,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=telegram_home,
|
||||
name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
|
||||
# Discord
|
||||
discord_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
if discord_token:
|
||||
@@ -357,7 +351,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig()
|
||||
config.platforms[Platform.DISCORD].enabled = True
|
||||
config.platforms[Platform.DISCORD].token = discord_token
|
||||
|
||||
|
||||
discord_home = os.getenv("DISCORD_HOME_CHANNEL")
|
||||
if discord_home and Platform.DISCORD in config.platforms:
|
||||
config.platforms[Platform.DISCORD].home_channel = HomeChannel(
|
||||
@@ -365,14 +359,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=discord_home,
|
||||
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
|
||||
# WhatsApp (typically uses different auth mechanism)
|
||||
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
if whatsapp_enabled:
|
||||
if Platform.WHATSAPP not in config.platforms:
|
||||
config.platforms[Platform.WHATSAPP] = PlatformConfig()
|
||||
config.platforms[Platform.WHATSAPP].enabled = True
|
||||
|
||||
|
||||
# Slack
|
||||
slack_token = os.getenv("SLACK_BOT_TOKEN")
|
||||
if slack_token:
|
||||
@@ -388,7 +382,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
chat_id=slack_home,
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
|
||||
# Signal
|
||||
signal_url = os.getenv("SIGNAL_HTTP_URL")
|
||||
signal_account = os.getenv("SIGNAL_ACCOUNT")
|
||||
@@ -396,11 +390,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
if Platform.SIGNAL not in config.platforms:
|
||||
config.platforms[Platform.SIGNAL] = PlatformConfig()
|
||||
config.platforms[Platform.SIGNAL].enabled = True
|
||||
config.platforms[Platform.SIGNAL].extra.update({
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
config.platforms[Platform.SIGNAL].extra.update(
|
||||
{
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
}
|
||||
)
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
@@ -427,7 +423,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.default_reset_policy.idle_minutes = int(idle_minutes)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
reset_hour = os.getenv("SESSION_RESET_HOUR")
|
||||
if reset_hour:
|
||||
try:
|
||||
@@ -440,6 +436,6 @@ def save_gateway_config(config: GatewayConfig) -> None:
|
||||
"""Save gateway configuration to ~/.hermes/gateway.json."""
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
with open(gateway_config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
|
||||
@@ -9,18 +9,17 @@ Routes messages to the appropriate destination based on:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_PLATFORM_OUTPUT = 4000
|
||||
TRUNCATED_VISIBLE = 3800
|
||||
|
||||
from .config import Platform, GatewayConfig
|
||||
from .config import GatewayConfig, Platform
|
||||
from .session import SessionSource
|
||||
|
||||
|
||||
@@ -28,23 +27,24 @@ from .session import SessionSource
|
||||
class DeliveryTarget:
|
||||
"""
|
||||
A single delivery target.
|
||||
|
||||
|
||||
Represents where a message should be sent:
|
||||
- "origin" → back to source
|
||||
- "local" → save to local files
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: Optional[str] = None # None means use home channel
|
||||
chat_id: str | None = None # None means use home channel
|
||||
is_origin: bool = False
|
||||
is_explicit: bool = False # True if chat_id was explicitly specified
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse(cls, target: str, origin: Optional[SessionSource] = None) -> "DeliveryTarget":
|
||||
def parse(cls, target: str, origin: SessionSource | None = None) -> "DeliveryTarget":
|
||||
"""
|
||||
Parse a delivery target string.
|
||||
|
||||
|
||||
Formats:
|
||||
- "origin" → back to source
|
||||
- "local" → local files only
|
||||
@@ -52,7 +52,7 @@ class DeliveryTarget:
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
target = target.strip().lower()
|
||||
|
||||
|
||||
if target == "origin":
|
||||
if origin:
|
||||
return cls(
|
||||
@@ -63,10 +63,10 @@ class DeliveryTarget:
|
||||
else:
|
||||
# Fallback to local if no origin
|
||||
return cls(platform=Platform.LOCAL, is_origin=True)
|
||||
|
||||
|
||||
if target == "local":
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
# Check for platform:chat_id format
|
||||
if ":" in target:
|
||||
platform_str, chat_id = target.split(":", 1)
|
||||
@@ -76,7 +76,7 @@ class DeliveryTarget:
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
# Just a platform name (use home channel)
|
||||
try:
|
||||
platform = Platform(target)
|
||||
@@ -84,7 +84,7 @@ class DeliveryTarget:
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert back to string format."""
|
||||
if self.is_origin:
|
||||
@@ -99,15 +99,15 @@ class DeliveryTarget:
|
||||
class DeliveryRouter:
|
||||
"""
|
||||
Routes messages to appropriate destinations.
|
||||
|
||||
|
||||
Handles the logic of resolving delivery targets and dispatching
|
||||
messages to the right platform adapters.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: Dict[Platform, Any] = None):
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: dict[Platform, Any] = None):
|
||||
"""
|
||||
Initialize the delivery router.
|
||||
|
||||
|
||||
Args:
|
||||
config: Gateway configuration
|
||||
adapters: Dict mapping platforms to their adapter instances
|
||||
@@ -115,31 +115,27 @@ class DeliveryRouter:
|
||||
self.config = config
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = Path.home() / ".hermes" / "cron" / "output"
|
||||
|
||||
def resolve_targets(
|
||||
self,
|
||||
deliver: Union[str, List[str]],
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> List[DeliveryTarget]:
|
||||
|
||||
def resolve_targets(self, deliver: str | list[str], origin: SessionSource | None = None) -> list[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
@@ -148,109 +144,96 @@ class DeliveryRouter:
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
|
||||
return targets
|
||||
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
targets: List[DeliveryTarget],
|
||||
job_id: Optional[str] = None,
|
||||
job_name: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
targets: list[DeliveryTarget],
|
||||
job_id: str | None = None,
|
||||
job_name: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Deliver content to all specified targets.
|
||||
|
||||
|
||||
Args:
|
||||
content: The message/output to deliver
|
||||
targets: List of delivery targets
|
||||
job_id: Optional job ID (for cron jobs)
|
||||
job_name: Optional job name
|
||||
metadata: Additional metadata to include
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with delivery results per target
|
||||
"""
|
||||
results = {}
|
||||
|
||||
|
||||
for target in targets:
|
||||
try:
|
||||
if target.platform == Platform.LOCAL:
|
||||
result = self._deliver_local(content, job_id, job_name, metadata)
|
||||
else:
|
||||
result = await self._deliver_to_platform(target, content, metadata)
|
||||
|
||||
results[target.to_string()] = {
|
||||
"success": True,
|
||||
"result": result
|
||||
}
|
||||
|
||||
results[target.to_string()] = {"success": True, "result": result}
|
||||
except Exception as e:
|
||||
results[target.to_string()] = {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
results[target.to_string()] = {"success": False, "error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _deliver_local(
|
||||
self,
|
||||
content: str,
|
||||
job_id: Optional[str],
|
||||
job_name: Optional[str],
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
self, content: str, job_id: str | None, job_name: str | None, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Save content to local files."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
if job_id:
|
||||
output_path = self.output_dir / job_id / f"{timestamp}.md"
|
||||
else:
|
||||
output_path = self.output_dir / "misc" / f"{timestamp}.md"
|
||||
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Build the output document
|
||||
lines = []
|
||||
if job_name:
|
||||
lines.append(f"# {job_name}")
|
||||
else:
|
||||
lines.append("# Delivery Output")
|
||||
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
|
||||
if job_id:
|
||||
lines.append(f"**Job ID:** {job_id}")
|
||||
|
||||
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
lines.append(f"**{key}:** {value}")
|
||||
|
||||
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
lines.append(content)
|
||||
|
||||
|
||||
output_path.write_text("\n".join(lines))
|
||||
|
||||
return {
|
||||
"path": str(output_path),
|
||||
"timestamp": timestamp
|
||||
}
|
||||
|
||||
|
||||
return {"path": str(output_path), "timestamp": timestamp}
|
||||
|
||||
def _save_full_output(self, content: str, job_id: str) -> Path:
|
||||
"""Save full cron output to disk and return the file path."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -261,41 +244,33 @@ class DeliveryRouter:
|
||||
return path
|
||||
|
||||
async def _deliver_to_platform(
|
||||
self,
|
||||
target: DeliveryTarget,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
self, target: DeliveryTarget, content: str, metadata: dict[str, Any] | None
|
||||
) -> dict[str, Any]:
|
||||
"""Deliver content to a messaging platform."""
|
||||
adapter = self.adapters.get(target.platform)
|
||||
|
||||
|
||||
if not adapter:
|
||||
raise ValueError(f"No adapter configured for {target.platform.value}")
|
||||
|
||||
|
||||
if not target.chat_id:
|
||||
raise ValueError(f"No chat ID for {target.platform.value} delivery")
|
||||
|
||||
|
||||
# Guard: truncate oversized cron output to stay within platform limits
|
||||
if len(content) > MAX_PLATFORM_OUTPUT:
|
||||
job_id = (metadata or {}).get("job_id", "unknown")
|
||||
saved_path = self._save_full_output(content, job_id)
|
||||
logger.info("Cron output truncated (%d chars) — full output: %s", len(content), saved_path)
|
||||
content = (
|
||||
content[:TRUNCATED_VISIBLE]
|
||||
+ f"\n\n... [truncated, full output saved to {saved_path}]"
|
||||
)
|
||||
|
||||
content = content[:TRUNCATED_VISIBLE] + f"\n\n... [truncated, full output saved to {saved_path}]"
|
||||
|
||||
return await adapter.send(target.chat_id, content, metadata=metadata)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: Optional[Union[str, List[str]]],
|
||||
origin: Optional[SessionSource] = None,
|
||||
default: str = "origin"
|
||||
) -> Union[str, List[str]]:
|
||||
deliver: str | list[str] | None, origin: SessionSource | None = None, default: str = "origin"
|
||||
) -> str | list[str]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
@@ -303,17 +278,14 @@ def parse_deliver_spec(
|
||||
return deliver
|
||||
|
||||
|
||||
def build_delivery_context_for_tool(
|
||||
config: GatewayConfig,
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> Dict[str, Any]:
|
||||
def build_delivery_context_for_tool(config: GatewayConfig, origin: SessionSource | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Build context for the schedule_cronjob tool to understand delivery options.
|
||||
|
||||
|
||||
This is passed to the tool so it can validate and explain delivery targets.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
|
||||
options = {
|
||||
"origin": {
|
||||
"description": "Back to where this job was created",
|
||||
@@ -322,9 +294,9 @@ def build_delivery_context_for_tool(
|
||||
"local": {
|
||||
"description": "Save to local files only",
|
||||
"available": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
options[platform.value] = {
|
||||
@@ -332,7 +304,7 @@ def build_delivery_context_for_tool(
|
||||
"available": True,
|
||||
"home_channel": home.to_dict() if home else None,
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"origin": origin.to_dict() if origin else None,
|
||||
"options": options,
|
||||
|
||||
@@ -21,12 +21,12 @@ Errors in hooks are caught and logged but never block the main pipeline.
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
HOOKS_DIR = Path(os.path.expanduser("~/.hermes/hooks"))
|
||||
|
||||
|
||||
@@ -42,11 +42,11 @@ class HookRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# event_type -> [handler_fn, ...]
|
||||
self._handlers: Dict[str, List[Callable]] = {}
|
||||
self._loaded_hooks: List[dict] = [] # metadata for listing
|
||||
self._handlers: dict[str, list[Callable]] = {}
|
||||
self._loaded_hooks: list[dict] = [] # metadata for listing
|
||||
|
||||
@property
|
||||
def loaded_hooks(self) -> List[dict]:
|
||||
def loaded_hooks(self) -> list[dict]:
|
||||
"""Return metadata about all loaded hooks."""
|
||||
return list(self._loaded_hooks)
|
||||
|
||||
@@ -84,9 +84,7 @@ class HookRegistry:
|
||||
continue
|
||||
|
||||
# Dynamically load the handler module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
f"hermes_hook_{hook_name}", handler_path
|
||||
)
|
||||
spec = importlib.util.spec_from_file_location(f"hermes_hook_{hook_name}", handler_path)
|
||||
if spec is None or spec.loader is None:
|
||||
print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True)
|
||||
continue
|
||||
@@ -103,19 +101,21 @@ class HookRegistry:
|
||||
for event in events:
|
||||
self._handlers.setdefault(event, []).append(handle_fn)
|
||||
|
||||
self._loaded_hooks.append({
|
||||
"name": hook_name,
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
})
|
||||
self._loaded_hooks.append(
|
||||
{
|
||||
"name": hook_name,
|
||||
"description": manifest.get("description", ""),
|
||||
"events": events,
|
||||
"path": str(hook_dir),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"[hooks] Loaded hook '{hook_name}' for events: {events}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[hooks] Error loading hook {hook_dir.name}: {e}", flush=True)
|
||||
|
||||
async def emit(self, event_type: str, context: Optional[Dict[str, Any]] = None) -> None:
|
||||
async def emit(self, event_type: str, context: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Fire all handlers registered for an event.
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,7 +60,7 @@ def mirror_to_session(
|
||||
return False
|
||||
|
||||
|
||||
def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
|
||||
def _find_session_id(platform: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
Find the active session_id for a platform + chat_id pair.
|
||||
|
||||
@@ -113,6 +112,7 @@ def _append_to_sqlite(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the SQLite session database."""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB()
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
|
||||
@@ -23,21 +23,19 @@ import os
|
||||
import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion
|
||||
ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
CODE_LENGTH = 8
|
||||
|
||||
# Timing constants
|
||||
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
|
||||
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
|
||||
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
|
||||
CODE_TTL_SECONDS = 3600 # Codes expire after 1 hour
|
||||
RATE_LIMIT_SECONDS = 600 # 1 request per user per 10 minutes
|
||||
LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures
|
||||
|
||||
# Limits
|
||||
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
|
||||
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
|
||||
MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform
|
||||
MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout
|
||||
|
||||
PAIRING_DIR = Path(os.path.expanduser("~/.hermes/pairing"))
|
||||
|
||||
@@ -123,9 +121,7 @@ class PairingStore:
|
||||
|
||||
# ----- Pending codes -----
|
||||
|
||||
def generate_code(
|
||||
self, platform: str, user_id: str, user_name: str = ""
|
||||
) -> Optional[str]:
|
||||
def generate_code(self, platform: str, user_id: str, user_name: str = "") -> str | None:
|
||||
"""
|
||||
Generate a pairing code for a new user.
|
||||
|
||||
@@ -165,7 +161,7 @@ class PairingStore:
|
||||
|
||||
return code
|
||||
|
||||
def approve_code(self, platform: str, code: str) -> Optional[dict]:
|
||||
def approve_code(self, platform: str, code: str) -> dict | None:
|
||||
"""
|
||||
Approve a pairing code. Adds the user to the approved list.
|
||||
|
||||
@@ -199,13 +195,15 @@ class PairingStore:
|
||||
pending = self._load_json(self._pending_path(p))
|
||||
for code, info in pending.items():
|
||||
age_min = int((time.time() - info["created_at"]) / 60)
|
||||
results.append({
|
||||
"platform": p,
|
||||
"code": code,
|
||||
"user_id": info["user_id"],
|
||||
"user_name": info.get("user_name", ""),
|
||||
"age_minutes": age_min,
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"platform": p,
|
||||
"code": code,
|
||||
"user_id": info["user_id"],
|
||||
"user_name": info.get("user_name", ""),
|
||||
"age_minutes": age_min,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
def clear_pending(self, platform: str = None) -> int:
|
||||
@@ -251,8 +249,11 @@ class PairingStore:
|
||||
lockout_key = f"_lockout:{platform}"
|
||||
limits[lockout_key] = time.time() + LOCKOUT_SECONDS
|
||||
limits[fail_key] = 0 # Reset counter
|
||||
print(f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
|
||||
f"after {MAX_FAILED_ATTEMPTS} failed attempts", flush=True)
|
||||
print(
|
||||
f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
|
||||
f"after {MAX_FAILED_ATTEMPTS} failed attempts",
|
||||
flush=True,
|
||||
)
|
||||
self._save_json(self._rate_limit_path(), limits)
|
||||
|
||||
# ----- Cleanup -----
|
||||
@@ -262,10 +263,7 @@ class PairingStore:
|
||||
path = self._pending_path(platform)
|
||||
pending = self._load_json(path)
|
||||
now = time.time()
|
||||
expired = [
|
||||
code for code, info in pending.items()
|
||||
if (now - info["created_at"]) > CODE_TTL_SECONDS
|
||||
]
|
||||
expired = [code for code, info in pending.items() if (now - info["created_at"]) > CODE_TTL_SECONDS]
|
||||
if expired:
|
||||
for code in expired:
|
||||
del pending[code]
|
||||
|
||||
@@ -303,8 +303,8 @@ Optional but valuable:
|
||||
After implementing everything, verify with:
|
||||
|
||||
```bash
|
||||
# All tests pass
|
||||
python -m pytest tests/ -q
|
||||
# All checks pass (lint + test)
|
||||
make check
|
||||
|
||||
# Grep for your platform name to find any missed integration points
|
||||
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \
|
||||
|
||||
@@ -13,20 +13,20 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pathlib import Path as _Path
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
@@ -251,6 +251,7 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int:
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
|
||||
TEXT = "text"
|
||||
LOCATION = "location"
|
||||
PHOTO = "photo"
|
||||
@@ -266,42 +267,43 @@ class MessageType(Enum):
|
||||
class MessageEvent:
|
||||
"""
|
||||
Incoming message from a platform.
|
||||
|
||||
|
||||
Normalized representation that all adapters produce.
|
||||
"""
|
||||
|
||||
# Message content
|
||||
text: str
|
||||
message_type: MessageType = MessageType.TEXT
|
||||
|
||||
|
||||
# Source information
|
||||
source: SessionSource = None
|
||||
|
||||
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: Optional[str] = None
|
||||
|
||||
message_id: str | None = None
|
||||
|
||||
# Media attachments
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
media_urls: list[str] = field(default_factory=list)
|
||||
media_types: list[str] = field(default_factory=list)
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: Optional[str] = None
|
||||
|
||||
reply_to_message_id: str | None = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
def is_command(self) -> bool:
|
||||
"""Check if this is a command message (e.g., /new, /reset)."""
|
||||
return self.text.startswith("/")
|
||||
|
||||
def get_command(self) -> Optional[str]:
|
||||
|
||||
def get_command(self) -> str | None:
|
||||
"""Extract command name if this is a command message."""
|
||||
if not self.is_command():
|
||||
return None
|
||||
# Split on space and get first word, strip the /
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[0][1:].lower() if parts else None
|
||||
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
"""Get the arguments after a command."""
|
||||
if not self.is_command():
|
||||
@@ -310,91 +312,88 @@ class MessageEvent:
|
||||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class SendResult:
|
||||
"""Result of sending a message."""
|
||||
|
||||
success: bool
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
message_id: str | None = None
|
||||
error: str | None = None
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[str | None]]
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Base class for platform adapters.
|
||||
|
||||
|
||||
Subclasses implement platform-specific logic for:
|
||||
- Connecting and authenticating
|
||||
- Receiving messages
|
||||
- Sending messages/responses
|
||||
- Handling media
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig, platform: Platform):
|
||||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._message_handler: MessageHandler | None = None
|
||||
self._running = False
|
||||
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
|
||||
self._active_sessions: dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
return self.platform.value.title()
|
||||
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if adapter is currently connected."""
|
||||
return self._running
|
||||
|
||||
|
||||
def set_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
Set the handler for incoming messages.
|
||||
|
||||
|
||||
The handler receives a MessageEvent and should return
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Connect to the platform and start receiving messages.
|
||||
|
||||
|
||||
Returns True if connection was successful.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the platform."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
content: Message content (may be markdown)
|
||||
reply_to: Optional message ID to reply to
|
||||
metadata: Additional platform-specific options
|
||||
|
||||
|
||||
Returns:
|
||||
SendResult with success status and message ID
|
||||
"""
|
||||
@@ -416,21 +415,21 @@ class BasePlatformAdapter(ABC):
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
|
||||
|
||||
Override in subclasses if the platform supports it.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
@@ -438,87 +437,91 @@ class BasePlatformAdapter(ABC):
|
||||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an animated GIF natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send GIFs as proper animations
|
||||
(e.g., Telegram send_animation) so they auto-play inline.
|
||||
Default falls back to send_image.
|
||||
"""
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _is_animation_url(url: str) -> bool:
|
||||
"""Check if a URL points to an animated GIF (vs a static image)."""
|
||||
lower = url.lower().split('?')[0] # Strip query params
|
||||
return lower.endswith('.gif')
|
||||
lower = url.lower().split("?")[0] # Strip query params
|
||||
return lower.endswith(".gif")
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
||||
def extract_images(content: str) -> tuple[list[tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
||||
md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)"
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
||||
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
||||
if any(
|
||||
url.lower().endswith(ext) or ext in url.lower()
|
||||
for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"]
|
||||
):
|
||||
images.append((url, alt_text))
|
||||
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
|
||||
# Remove only the matched image tags from content (not all markdown images)
|
||||
if images:
|
||||
extracted_urls = {url for url, _ in images}
|
||||
|
||||
def _remove_if_extracted(match):
|
||||
url = match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
return '' if url in extracted_urls else match.group(0)
|
||||
return "" if url in extracted_urls else match.group(0)
|
||||
|
||||
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
|
||||
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
@@ -532,8 +535,8 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a video natively via the platform API.
|
||||
@@ -550,9 +553,9 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a document/file natively via the platform API.
|
||||
@@ -569,8 +572,8 @@ class BasePlatformAdapter(ABC):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a local image file natively via the platform API.
|
||||
@@ -585,45 +588,45 @@ class BasePlatformAdapter(ABC):
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
||||
def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r'MEDIA:(\S+)'
|
||||
media_pattern = r"MEDIA:(\S+)"
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, '', cleaned)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
cleaned = re.sub(media_pattern, "", cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
|
||||
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
|
||||
to recover quickly after progress messages interrupt it.
|
||||
"""
|
||||
@@ -633,20 +636,20 @@ class BasePlatformAdapter(ABC):
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
@@ -655,10 +658,10 @@ class BasePlatformAdapter(ABC):
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
"""
|
||||
@@ -685,35 +688,40 @@ class BasePlatformAdapter(ABC):
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
logger.info(
|
||||
"[%s] extract_images found %d image(s) in response (%d chars)",
|
||||
self.name,
|
||||
len(images),
|
||||
len(response),
|
||||
)
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id
|
||||
logger.info(
|
||||
"[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id
|
||||
)
|
||||
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
@@ -721,14 +729,14 @@ class BasePlatformAdapter(ABC):
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id
|
||||
reply_to=event.message_id,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
@@ -736,7 +744,12 @@ class BasePlatformAdapter(ABC):
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
image_url[:80],
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
if self._is_animation_url(image_url):
|
||||
img_result = await self.send_animation(
|
||||
@@ -754,11 +767,11 @@ class BasePlatformAdapter(ABC):
|
||||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if human_delay > 0:
|
||||
@@ -790,7 +803,7 @@ class BasePlatformAdapter(ABC):
|
||||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
@@ -806,10 +819,11 @@ class BasePlatformAdapter(ABC):
|
||||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
@@ -821,26 +835,26 @@ class BasePlatformAdapter(ABC):
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
|
||||
def get_pending_message(self, session_key: str) -> MessageEvent | None:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.pop(session_key, None)
|
||||
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_name: Optional[str] = None,
|
||||
chat_name: str | None = None,
|
||||
chat_type: str = "dm",
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
chat_topic: str | None = None,
|
||||
user_id_alt: str | None = None,
|
||||
chat_id_alt: str | None = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
@@ -858,30 +872,30 @@ class BasePlatformAdapter(ABC):
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about a chat/channel.
|
||||
|
||||
|
||||
Returns dict with at least:
|
||||
- name: Chat name
|
||||
- type: "dm", "group", "channel"
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format a message for this platform.
|
||||
|
||||
|
||||
Override in subclasses to handle platform-specific formatting
|
||||
(e.g., Telegram MarkdownV2, Discord markdown).
|
||||
|
||||
|
||||
Default implementation returns content as-is.
|
||||
"""
|
||||
return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> list[str]:
|
||||
"""
|
||||
Split a long message into chunks, preserving code block boundaries.
|
||||
|
||||
@@ -900,14 +914,14 @@ class BasePlatformAdapter(ABC):
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
FENCE_CLOSE = "\n```"
|
||||
|
||||
chunks: List[str] = []
|
||||
chunks: list[str] = []
|
||||
remaining = content
|
||||
# When the previous chunk ended mid-code-block, this holds the
|
||||
# language tag (possibly "") so we can reopen the fence.
|
||||
carry_lang: Optional[str] = None
|
||||
carry_lang: str | None = None
|
||||
|
||||
while remaining:
|
||||
# If we're continuing a code block from the previous chunk,
|
||||
@@ -965,8 +979,6 @@ class BasePlatformAdapter(ABC):
|
||||
# Append chunk indicators when the response spans multiple messages
|
||||
if len(chunks) > 1:
|
||||
total = len(chunks)
|
||||
chunks = [
|
||||
f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)
|
||||
]
|
||||
chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)]
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -10,14 +10,16 @@ Uses discord.py library for:
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Message as DiscordMessage, Intents
|
||||
from discord import Intents
|
||||
from discord import Message as DiscordMessage
|
||||
from discord.ext import commands
|
||||
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
@@ -28,6 +30,7 @@ except ImportError:
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
@@ -36,8 +39,8 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,7 +52,7 @@ def check_discord_requirements() -> bool:
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
@@ -59,26 +62,26 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
- Auto-threading for long conversations
|
||||
- Reaction-based feedback
|
||||
"""
|
||||
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DISCORD)
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._client: commands.Bot | None = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Set up intents -- members intent needed for username-to-ID resolution
|
||||
intents = Intents.default()
|
||||
@@ -86,30 +89,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
self._allowed_user_ids = {
|
||||
uid.strip() for uid in allowed_env.split(",") if uid.strip()
|
||||
}
|
||||
|
||||
self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()}
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}")
|
||||
|
||||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
@@ -117,33 +118,33 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{adapter_self.name}] Slash command sync failed: {e}")
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
except TimeoutError:
|
||||
print(f"[{self.name}] Timeout waiting for connection")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
if self._client:
|
||||
@@ -151,59 +152,55 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
|
||||
if reply_to:
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=reference if i == 0 else None,
|
||||
)
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -223,7 +220,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
formatted = self.format_message(content)
|
||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||
formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
await msg.edit(content=formatted)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
@@ -233,28 +230,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
@@ -262,36 +259,36 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
@@ -299,7 +296,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
@@ -308,31 +305,31 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
@@ -342,23 +339,24 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
|
||||
import io
|
||||
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
@@ -368,20 +366,20 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
@@ -397,7 +395,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
@@ -406,7 +404,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
async def _resolve_allowed_usernames(self) -> None:
|
||||
"""
|
||||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||||
@@ -453,8 +451,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
uid = str(member.id)
|
||||
numeric_ids.add(uid)
|
||||
resolved_count += 1
|
||||
matched_name = name_lower if name_lower in to_resolve else (
|
||||
display_lower if display_lower in to_resolve else global_lower
|
||||
matched_name = (
|
||||
name_lower
|
||||
if name_lower in to_resolve
|
||||
else (display_lower if display_lower in to_resolve else global_lower)
|
||||
)
|
||||
to_resolve.discard(matched_name)
|
||||
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
|
||||
@@ -474,12 +474,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Discord.
|
||||
|
||||
|
||||
Discord uses its own markdown variant.
|
||||
"""
|
||||
# Discord markdown is fairly standard, no special escaping needed
|
||||
return content
|
||||
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
if not self._client:
|
||||
@@ -694,7 +694,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
@@ -715,9 +715,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
raw_message=interaction,
|
||||
)
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, approval_id: str
|
||||
) -> SendResult:
|
||||
async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult:
|
||||
"""
|
||||
Send a button-based exec approval prompt for a dangerous command.
|
||||
|
||||
@@ -759,28 +757,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
@@ -798,7 +796,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
@@ -811,15 +809,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get thread ID if in a thread
|
||||
thread_id = None
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
@@ -830,7 +828,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
@@ -869,7 +867,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
message_type=msg_type,
|
||||
@@ -881,7 +879,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
@@ -911,20 +909,14 @@ if DISCORD_AVAILABLE:
|
||||
return True # No allowlist = anyone can approve
|
||||
return str(interaction.user.id) in self.allowed_user_ids
|
||||
|
||||
async def _resolve(
|
||||
self, interaction: discord.Interaction, action: str, color: discord.Color
|
||||
):
|
||||
async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color):
|
||||
"""Resolve the approval and update the message."""
|
||||
if self.resolved:
|
||||
await interaction.response.send_message(
|
||||
"This approval has already been resolved~", ephemeral=True
|
||||
)
|
||||
await interaction.response.send_message("This approval has already been resolved~", ephemeral=True)
|
||||
return
|
||||
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
"You're not authorized to approve commands~", ephemeral=True
|
||||
)
|
||||
await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True)
|
||||
return
|
||||
|
||||
self.resolved = True
|
||||
@@ -944,6 +936,7 @@ if DISCORD_AVAILABLE:
|
||||
# Store the approval decision
|
||||
try:
|
||||
from tools.approval import approve_permanent
|
||||
|
||||
if action == "allow_once":
|
||||
pass # One-time approval handled by gateway
|
||||
elif action == "allow_always":
|
||||
@@ -952,21 +945,15 @@ if DISCORD_AVAILABLE:
|
||||
pass
|
||||
|
||||
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
|
||||
async def allow_once(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_once", discord.Color.green())
|
||||
|
||||
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
|
||||
async def allow_always(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_always", discord.Color.blue())
|
||||
|
||||
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
|
||||
async def deny(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def deny(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "deny", discord.Color.red())
|
||||
|
||||
async def on_timeout(self):
|
||||
|
||||
@@ -19,10 +19,11 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
@@ -66,10 +67,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
super().__init__(config, Platform.HOMEASSISTANT)
|
||||
|
||||
# Connection state
|
||||
self._session: Optional["aiohttp.ClientSession"] = None
|
||||
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
|
||||
self._rest_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._rest_session: aiohttp.ClientSession | None = None
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._msg_id: int = 0
|
||||
|
||||
# Configuration from extra
|
||||
@@ -80,13 +81,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
self._hass_token: str = token
|
||||
|
||||
# Event filtering
|
||||
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
|
||||
self._watch_domains: set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: set[str] = set(extra.get("ignore_entities", []))
|
||||
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
|
||||
|
||||
# Cooldown tracking: entity_id -> last_event_timestamp
|
||||
self._last_event_time: Dict[str, float] = {}
|
||||
self._last_event_time: dict[str, float] = {}
|
||||
|
||||
def _next_id(self) -> int:
|
||||
"""Return the next WebSocket message ID."""
|
||||
@@ -141,10 +142,12 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Step 2: Send auth
|
||||
await self._ws.send_json({
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
})
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 3: Wait for auth_ok
|
||||
msg = await self._ws.receive_json()
|
||||
@@ -155,11 +158,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
|
||||
# Step 4: Subscribe to state_changed events
|
||||
sub_id = self._next_id()
|
||||
await self._ws.send_json({
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
})
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
}
|
||||
)
|
||||
|
||||
# Verify subscription acknowledgement
|
||||
msg = await self._ws.receive_json()
|
||||
@@ -245,7 +250,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||||
break
|
||||
|
||||
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
|
||||
async def _handle_ha_event(self, event: dict[str, Any]) -> None:
|
||||
"""Process a state_changed event from Home Assistant."""
|
||||
event_data = event.get("data", {})
|
||||
entity_id: str = event_data.get("entity_id", "")
|
||||
@@ -302,9 +307,9 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
@staticmethod
|
||||
def _format_state_change(
|
||||
entity_id: str,
|
||||
old_state: Dict[str, Any],
|
||||
new_state: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
old_state: dict[str, Any],
|
||||
new_state: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Convert a state_changed event into a human-readable description."""
|
||||
if not new_state:
|
||||
return None
|
||||
@@ -331,10 +336,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
|
||||
if domain == "sensor":
|
||||
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: changed from "
|
||||
f"{old_val}{unit} to {new_val}{unit}"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}"
|
||||
|
||||
if domain == "binary_sensor":
|
||||
return (
|
||||
@@ -344,22 +346,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
)
|
||||
|
||||
if domain in ("light", "switch", "fan"):
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: turned "
|
||||
f"{'on' if new_val == 'on' else 'off'}"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}"
|
||||
|
||||
if domain == "alarm_control_panel":
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: alarm state changed from "
|
||||
f"'{old_val}' to '{new_val}'"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# Generic fallback
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name} ({entity_id}): "
|
||||
f"changed from '{old_val}' to '{new_val}'"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound messaging
|
||||
@@ -369,8 +362,8 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a notification via HA REST API (persistent_notification.create).
|
||||
|
||||
@@ -384,7 +377,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
}
|
||||
payload = {
|
||||
"title": "Hermes Agent",
|
||||
"message": content[:self.MAX_MESSAGE_LENGTH],
|
||||
"message": content[: self.MAX_MESSAGE_LENGTH],
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -401,20 +394,22 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
else:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
) as resp,
|
||||
):
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
return SendResult(success=False, error="Timeout sending notification to HA")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -423,7 +418,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
return {
|
||||
"name": "Home Assistant Events",
|
||||
|
||||
@@ -19,9 +19,9 @@ import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
@@ -32,9 +32,9 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
@@ -59,6 +59,7 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
@@ -68,7 +69,7 @@ def _redact_phone(phone: str) -> str:
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
def _parse_comma_list(value: str) -> list[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
@@ -110,7 +111,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
||||
Signal encodes @mentions as the Unicode object replacement character
|
||||
with out-of-band metadata containing the mentioned user's UUID/number.
|
||||
"""
|
||||
if not mentions or "\uFFFC" not in text:
|
||||
if not mentions or "\ufffc" not in text:
|
||||
return text
|
||||
# Sort mentions by start position (reverse) to replace from end to start
|
||||
# so indices don't shift as we replace
|
||||
@@ -121,7 +122,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
||||
# Use the mention's number or UUID as the replacement
|
||||
identifier = mention.get("number") or mention.get("uuid") or "user"
|
||||
replacement = f"@{identifier}"
|
||||
text = text[:start] + replacement + text[start + length:]
|
||||
text = text[:start] + replacement + text[start + length :]
|
||||
return text
|
||||
|
||||
|
||||
@@ -134,6 +135,7 @@ def check_signal_requirements() -> bool:
|
||||
# Signal Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SignalAdapter(BasePlatformAdapter):
|
||||
"""Signal messenger adapter using signal-cli HTTP daemon."""
|
||||
|
||||
@@ -152,22 +154,25 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
|
||||
|
||||
# HTTP client
|
||||
self.client: Optional[httpx.AsyncClient] = None
|
||||
self.client: httpx.AsyncClient | None = None
|
||||
|
||||
# Background tasks
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._sse_task: asyncio.Task | None = None
|
||||
self._health_monitor_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
self._sse_response: httpx.Response | None = None
|
||||
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
logger.info(
|
||||
"Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url,
|
||||
_redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
@@ -241,7 +246,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
logger.debug("Signal SSE: connecting to %s", url)
|
||||
async with self.client.stream(
|
||||
"GET", url,
|
||||
"GET",
|
||||
url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=None,
|
||||
) as response:
|
||||
@@ -306,9 +312,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
|
||||
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
|
||||
try:
|
||||
resp = await self.client.get(
|
||||
f"{self.http_url}/api/v1/check", timeout=10.0
|
||||
)
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
# Daemon is alive but SSE is idle — update activity to
|
||||
# avoid repeated warnings (connection may just be quiet)
|
||||
@@ -345,11 +349,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = (
|
||||
envelope_data.get("sourceNumber")
|
||||
or envelope_data.get("sourceUuid")
|
||||
or envelope_data.get("source")
|
||||
)
|
||||
sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source")
|
||||
sender_name = envelope_data.get("sourceName", "")
|
||||
sender_uuid = envelope_data.get("sourceUuid", "")
|
||||
|
||||
@@ -367,10 +367,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
# Get data message — also check editMessage (edited messages contain
|
||||
# their updated dataMessage inside editMessage.dataMessage)
|
||||
data_message = (
|
||||
envelope_data.get("dataMessage")
|
||||
or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
)
|
||||
data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
if not data_message:
|
||||
return
|
||||
|
||||
@@ -451,11 +448,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
if ts_ms:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC)
|
||||
except (ValueError, OSError):
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
else:
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
|
||||
# Build and dispatch event
|
||||
event = MessageEvent(
|
||||
@@ -468,8 +465,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
@@ -479,10 +475,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _fetch_attachment(self, attachment_id: str) -> tuple:
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc("getAttachment", {
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
})
|
||||
result = await self._rpc(
|
||||
"getAttachment",
|
||||
{
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
},
|
||||
)
|
||||
|
||||
if not result:
|
||||
return None, ""
|
||||
@@ -547,13 +546,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to_message_id: Optional[str] = None,
|
||||
reply_to_message_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a text message."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": text,
|
||||
}
|
||||
@@ -571,7 +570,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
|
||||
@@ -586,7 +585,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an image. Supports http(s):// and file:// URLs."""
|
||||
@@ -611,7 +610,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
@@ -631,8 +630,8 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
filename: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment."""
|
||||
@@ -641,7 +640,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if not Path(file_path).exists():
|
||||
return SendResult(success=False, error="File not found")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
@@ -690,7 +689,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a chat/contact."""
|
||||
if chat_id.startswith("group:"):
|
||||
return {
|
||||
@@ -700,10 +699,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
}
|
||||
|
||||
# Try to resolve contact name
|
||||
result = await self._rpc("getContact", {
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
})
|
||||
result = await self._rpc(
|
||||
"getContact",
|
||||
{
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
},
|
||||
)
|
||||
|
||||
name = chat_id
|
||||
if result and isinstance(result, dict):
|
||||
|
||||
@@ -11,12 +11,13 @@ Uses slack-bolt (Python) with Socket Mode for:
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
SLACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
SLACK_AVAILABLE = False
|
||||
@@ -26,18 +27,17 @@ except ImportError:
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,9 +66,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SLACK)
|
||||
self._app: Optional[AsyncApp] = None
|
||||
self._handler: Optional[AsyncSocketModeHandler] = None
|
||||
self._bot_user_id: Optional[str] = None
|
||||
self._app: AsyncApp | None = None
|
||||
self._handler: AsyncSocketModeHandler | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
@@ -135,8 +135,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a message to a Slack channel or DM."""
|
||||
if not self._app:
|
||||
@@ -193,8 +193,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file to Slack by uploading it."""
|
||||
if not self._app:
|
||||
@@ -202,6 +202,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
@@ -222,8 +223,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image to Slack by uploading the URL as a file."""
|
||||
if not self._app:
|
||||
@@ -247,7 +248,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Fall back to sending the URL as text
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
@@ -256,8 +257,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
if not self._app:
|
||||
@@ -280,8 +281,8 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video file to Slack."""
|
||||
if not self._app:
|
||||
@@ -308,9 +309,9 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment to Slack."""
|
||||
if not self._app:
|
||||
@@ -335,7 +336,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to send document: {e}")
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Slack channel."""
|
||||
if not self._app:
|
||||
return {"name": chat_id, "type": "unknown"}
|
||||
@@ -442,9 +443,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Download and cache
|
||||
raw_bytes = await self._download_slack_file_bytes(url)
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, original_filename or f"document{ext}"
|
||||
)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
@@ -457,7 +456,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if text:
|
||||
text = f"{injection}\n\n{text}"
|
||||
@@ -499,16 +498,20 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
# Map subcommands to gateway commands
|
||||
subcommand_map = {
|
||||
"new": "/reset", "reset": "/reset",
|
||||
"status": "/status", "stop": "/stop",
|
||||
"new": "/reset",
|
||||
"reset": "/reset",
|
||||
"status": "/status",
|
||||
"stop": "/stop",
|
||||
"help": "/help",
|
||||
"model": "/model", "personality": "/personality",
|
||||
"retry": "/retry", "undo": "/undo",
|
||||
"model": "/model",
|
||||
"personality": "/personality",
|
||||
"retry": "/retry",
|
||||
"undo": "/undo",
|
||||
}
|
||||
first_word = text.split()[0] if text else ""
|
||||
if first_word in subcommand_map:
|
||||
# Preserve arguments after the subcommand
|
||||
rest = text[len(first_word):].strip()
|
||||
rest = text[len(first_word) :].strip()
|
||||
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
|
||||
elif text:
|
||||
pass # Treat as a regular question
|
||||
@@ -544,9 +547,11 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
|
||||
@@ -7,24 +7,26 @@ Uses python-telegram-bot library for:
|
||||
- Handling media and commands
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram import Bot, Message, Update
|
||||
from telegram.constants import ChatType, ParseMode
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
from telegram.ext import (
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
)
|
||||
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
@@ -42,22 +44,24 @@ except ImportError:
|
||||
# don't crash during class definition when the library isn't installed.
|
||||
class _MockContextTypes:
|
||||
DEFAULT_TYPE = Any
|
||||
|
||||
ContextTypes = _MockContextTypes
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,12 +72,12 @@ def check_telegram_requirements() -> bool:
|
||||
|
||||
# Matches every character that MarkdownV2 requires to be backslash-escaped
|
||||
# when it appears outside a code span or fenced code block.
|
||||
_MDV2_ESCAPE_RE = re.compile(r'([_*\[\]()~`>#\+\-=|{}.!\\])')
|
||||
_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])")
|
||||
|
||||
|
||||
def _escape_mdv2(text: str) -> str:
|
||||
"""Escape Telegram MarkdownV2 special characters with a preceding backslash."""
|
||||
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
|
||||
return _MDV2_ESCAPE_RE.sub(r"\\\1", text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
@@ -83,103 +87,108 @@ def _strip_mdv2(text: str) -> str:
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
|
||||
cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from users and groups
|
||||
- Sending responses with Telegram markdown
|
||||
- Forum topics (thread_id support)
|
||||
- Media messages
|
||||
"""
|
||||
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
|
||||
self._app: Application | None = None
|
||||
self._bot: Bot | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
|
||||
# Register handlers
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
self._handle_text_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.COMMAND,
|
||||
self._handle_command
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION),
|
||||
self._handle_location_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL,
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message))
|
||||
self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command))
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message
|
||||
)
|
||||
)
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.PHOTO
|
||||
| filters.VIDEO
|
||||
| filters.AUDIO
|
||||
| filters.VOICE
|
||||
| filters.Document.ALL
|
||||
| filters.Sticker.ALL,
|
||||
self._handle_media_message,
|
||||
)
|
||||
)
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
await self._bot.set_my_commands([
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
|
||||
await self._bot.set_my_commands(
|
||||
[
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Could not register command menu: {e}")
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected and polling for updates")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
if self._app:
|
||||
@@ -189,31 +198,27 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._app = None
|
||||
self._bot = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Telegram chat."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
@@ -227,7 +232,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
logger.warning(
|
||||
"[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error
|
||||
)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
@@ -241,13 +248,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -284,18 +291,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
@@ -317,23 +325,24 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
@@ -350,17 +359,17 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
@@ -375,11 +384,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
@@ -391,18 +401,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an animated GIF natively as a Telegram animation (auto-plays inline)."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
msg = await self._bot.send_animation(
|
||||
chat_id=int(chat_id),
|
||||
@@ -420,21 +430,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
try:
|
||||
await self._bot.send_chat_action(
|
||||
chat_id=int(chat_id),
|
||||
action="typing"
|
||||
)
|
||||
await self._bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Telegram chat."""
|
||||
if not self._bot:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
chat = await self._bot.get_chat(int(chat_id))
|
||||
|
||||
|
||||
chat_type = "dm"
|
||||
if chat.type == ChatType.GROUP:
|
||||
chat_type = "group"
|
||||
@@ -444,7 +451,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
chat_type = "forum"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
return {
|
||||
"name": chat.title or chat.full_name or str(chat_id),
|
||||
"type": chat_type,
|
||||
@@ -453,7 +460,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Convert standard markdown to Telegram MarkdownV2 format.
|
||||
@@ -480,38 +487,36 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
text = re.sub(
|
||||
r'(```(?:[^\n]*\n)?[\s\S]*?```)',
|
||||
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
||||
lambda m: _ph(m.group(0)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 2) Protect inline code (`...`)
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links – escape the display text; inside the URL
|
||||
# only ')' and '\' need escaping per the MarkdownV2 spec.
|
||||
def _convert_link(m):
|
||||
display = _escape_mdv2(m.group(1))
|
||||
url = m.group(2).replace('\\', '\\\\').replace(')', '\\)')
|
||||
return _ph(f'[{display}]({url})')
|
||||
url = m.group(2).replace("\\", "\\\\").replace(")", "\\)")
|
||||
return _ph(f"[{display}]({url})")
|
||||
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _convert_link, text)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text)
|
||||
|
||||
# 4) Convert markdown headers (## Title) → bold *Title*
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers that may appear inside a header
|
||||
inner = re.sub(r'\*\*(.+?)\*\*', r'\1', inner)
|
||||
return _ph(f'*{_escape_mdv2(inner)}*')
|
||||
inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner)
|
||||
return _ph(f"*{_escape_mdv2(inner)}*")
|
||||
|
||||
text = re.sub(
|
||||
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
|
||||
)
|
||||
text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE)
|
||||
|
||||
# 5) Convert bold: **text** → *text* (MarkdownV2 bold)
|
||||
text = re.sub(
|
||||
r'\*\*(.+?)\*\*',
|
||||
lambda m: _ph(f'*{_escape_mdv2(m.group(1))}*'),
|
||||
r"\*\*(.+?)\*\*",
|
||||
lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"),
|
||||
text,
|
||||
)
|
||||
|
||||
@@ -519,8 +524,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# [^*\n]+ prevents matching across newlines (which would corrupt
|
||||
# bullet lists using * markers and multi-line content).
|
||||
text = re.sub(
|
||||
r'\*([^*\n]+)\*',
|
||||
lambda m: _ph(f'_{_escape_mdv2(m.group(1))}_'),
|
||||
r"\*([^*\n]+)\*",
|
||||
lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"),
|
||||
text,
|
||||
)
|
||||
|
||||
@@ -533,23 +538,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming location/venue pin messages."""
|
||||
if not update.message:
|
||||
@@ -589,9 +594,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
|
||||
msg = update.message
|
||||
|
||||
|
||||
# Determine media type
|
||||
if msg.sticker:
|
||||
msg_type = MessageType.STICKER
|
||||
@@ -607,19 +612,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.DOCUMENT
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
|
||||
# Handle stickers: describe via vision tool with caching
|
||||
if msg.sticker:
|
||||
await self._handle_sticker(msg, event)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
@@ -643,7 +648,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
|
||||
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
@@ -685,10 +690,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Check if supported
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
|
||||
event.text = (
|
||||
f"Unsupported document type '{ext or 'unknown'}'. "
|
||||
f"Supported types: {supported_list}"
|
||||
)
|
||||
event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}"
|
||||
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
@@ -696,10 +698,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Check file size (Telegram Bot API limit: 20 MB)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
|
||||
event.text = (
|
||||
"The document is too large or its size could not be verified. "
|
||||
"Maximum: 20 MB."
|
||||
)
|
||||
event.text = "The document is too large or its size could not be verified. Maximum: 20 MB."
|
||||
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
@@ -720,20 +719,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if event.text:
|
||||
event.text = f"{injection}\n\n{event.text}"
|
||||
else:
|
||||
event.text = injection
|
||||
except UnicodeDecodeError:
|
||||
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache document: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
"""
|
||||
Describe a Telegram sticker via vision analysis, with caching.
|
||||
@@ -743,11 +742,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
a placeholder noting the emoji.
|
||||
"""
|
||||
from gateway.sticker_cache import (
|
||||
get_cached_description,
|
||||
cache_sticker_description,
|
||||
build_sticker_injection,
|
||||
build_animated_sticker_injection,
|
||||
STICKER_VISION_PROMPT,
|
||||
build_animated_sticker_injection,
|
||||
build_sticker_injection,
|
||||
cache_sticker_description,
|
||||
get_cached_description,
|
||||
)
|
||||
|
||||
sticker = msg.sticker
|
||||
@@ -775,9 +774,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp")
|
||||
print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True)
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
import json as _json
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
|
||||
result_json = await vision_analyze_tool(
|
||||
image_url=cached_path,
|
||||
user_prompt=STICKER_VISION_PROMPT,
|
||||
@@ -792,27 +792,29 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Vision failed -- use emoji as fallback
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji, set_name,
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Sticker analysis error: {e}", flush=True)
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji, set_name,
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
|
||||
# Determine chat type
|
||||
chat_type = "dm"
|
||||
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
@@ -822,7 +824,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
|
||||
@@ -16,7 +16,6 @@ with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@@ -24,7 +23,7 @@ import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +35,9 @@ def _kill_port_process(port: int) -> None:
|
||||
# Use netstat to find the PID bound to this port, then taskkill
|
||||
result = subprocess.run(
|
||||
["netstat", "-ano", "-p", "TCP"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
@@ -46,24 +47,29 @@ def _kill_port_process(port: int) -> None:
|
||||
try:
|
||||
subprocess.run(
|
||||
["taskkill", "/PID", parts[4], "/F"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except subprocess.SubprocessError:
|
||||
pass
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["fuser", f"{port}/tcp"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
subprocess.run(
|
||||
["fuser", "-k", f"{port}/tcp"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
@@ -72,25 +78,20 @@ from gateway.platforms.base import (
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
def check_whatsapp_requirements() -> bool:
|
||||
"""
|
||||
Check if WhatsApp dependencies are available.
|
||||
|
||||
|
||||
WhatsApp requires a Node.js bridge for most implementations.
|
||||
"""
|
||||
# Check for Node.js
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
@@ -99,62 +100,61 @@ def check_whatsapp_requirements() -> bool:
|
||||
class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
WhatsApp adapter.
|
||||
|
||||
|
||||
This implementation uses a simple HTTP bridge pattern where:
|
||||
1. A Node.js process runs the WhatsApp Web client
|
||||
2. Messages are forwarded via HTTP/IPC to this Python adapter
|
||||
3. Responses are sent back through the bridge
|
||||
|
||||
|
||||
The actual Node.js bridge implementation can vary:
|
||||
- whatsapp-web.js based
|
||||
- Baileys based
|
||||
- Business API based
|
||||
|
||||
|
||||
Configuration:
|
||||
- bridge_script: Path to the Node.js bridge script
|
||||
- bridge_port: Port for HTTP communication (default: 3000)
|
||||
- session_path: Path to store WhatsApp session data
|
||||
"""
|
||||
|
||||
|
||||
# WhatsApp message limits
|
||||
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
|
||||
|
||||
|
||||
# Default bridge location relative to the hermes-agent install
|
||||
_DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge"
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WHATSAPP)
|
||||
self._bridge_process: Optional[subprocess.Popen] = None
|
||||
self._bridge_process: subprocess.Popen | None = None
|
||||
self._bridge_port: int = config.extra.get("bridge_port", 3000)
|
||||
self._bridge_script: Optional[str] = config.extra.get(
|
||||
self._bridge_script: str | None = config.extra.get(
|
||||
"bridge_script",
|
||||
str(self._DEFAULT_BRIDGE_DIR / "bridge.js"),
|
||||
)
|
||||
self._session_path: Path = Path(config.extra.get(
|
||||
"session_path",
|
||||
Path.home() / ".hermes" / "whatsapp" / "session"
|
||||
))
|
||||
self._session_path: Path = Path(
|
||||
config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session")
|
||||
)
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
|
||||
self._bridge_log: Path | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Start the WhatsApp bridge.
|
||||
|
||||
|
||||
This launches the Node.js bridge process and waits for it to be ready.
|
||||
"""
|
||||
if not check_whatsapp_requirements():
|
||||
logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name)
|
||||
return False
|
||||
|
||||
|
||||
bridge_path = Path(self._bridge_script)
|
||||
if not bridge_path.exists():
|
||||
logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path)
|
||||
return False
|
||||
|
||||
|
||||
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
|
||||
|
||||
|
||||
# Auto-install npm dependencies if node_modules doesn't exist
|
||||
bridge_dir = bridge_path.parent
|
||||
if not (bridge_dir / "node_modules").exists():
|
||||
@@ -174,16 +174,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to install dependencies: {e}")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
# Route output to a log file so QR codes, errors, and reconnection
|
||||
# messages are preserved for troubleshooting.
|
||||
@@ -195,19 +196,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port", str(self._bridge_port),
|
||||
"--session", str(self._session_path),
|
||||
"--mode", whatsapp_mode,
|
||||
"--port",
|
||||
str(self._bridge_port),
|
||||
"--session",
|
||||
str(self._session_path),
|
||||
"--mode",
|
||||
whatsapp_mode,
|
||||
],
|
||||
stdout=bridge_log_fh,
|
||||
stderr=bridge_log_fh,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
|
||||
# Wait for the bridge to connect to WhatsApp.
|
||||
# Phase 1: wait for the HTTP server to come up (up to 15s).
|
||||
# Phase 2: wait for WhatsApp status: connected (up to 15s more).
|
||||
import aiohttp
|
||||
|
||||
http_ready = False
|
||||
data = {}
|
||||
for attempt in range(15):
|
||||
@@ -218,17 +223,18 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -237,7 +243,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
# Phase 2: HTTP is up but WhatsApp may still be connecting.
|
||||
# Give it more time to authenticate with saved credentials.
|
||||
if data.get("status") != "connected":
|
||||
@@ -250,16 +256,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
@@ -268,19 +275,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] ⚠ WhatsApp not connected after 30s")
|
||||
print(f"[{self.name}] Bridge log: {self._bridge_log}")
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
def _close_bridge_log(self) -> None:
|
||||
"""Close the bridge log file handle if open."""
|
||||
if self._bridge_log_fh:
|
||||
@@ -296,6 +303,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
try:
|
||||
# Kill the entire process group so child node processes die too
|
||||
import signal
|
||||
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.terminate()
|
||||
@@ -314,29 +322,25 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
@@ -344,28 +348,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
return SendResult(success=True, message_id=data.get("messageId"), raw_response=data)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
|
||||
except ImportError:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="aiohttp not installed. Run: pip install aiohttp"
|
||||
)
|
||||
return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -380,21 +375,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
@@ -403,8 +401,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
media_type: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
@@ -415,7 +413,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
payload: dict[str, Any] = {
|
||||
"chatId": chat_id,
|
||||
"filePath": file_path,
|
||||
"mediaType": media_type,
|
||||
@@ -425,22 +423,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
if file_name:
|
||||
payload["fileName"] = file_name
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
@@ -449,8 +449,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Download image URL to cache, send natively via bridge."""
|
||||
try:
|
||||
@@ -463,8 +463,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively via bridge."""
|
||||
return await self._send_media_to_bridge(chat_id, image_path, "image", caption)
|
||||
@@ -473,8 +473,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video natively via bridge — plays inline in WhatsApp."""
|
||||
return await self._send_media_to_bridge(chat_id, video_path, "video", caption)
|
||||
@@ -483,13 +483,16 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file as a downloadable attachment via bridge."""
|
||||
return await self._send_media_to_bridge(
|
||||
chat_id, file_path, "document", caption,
|
||||
chat_id,
|
||||
file_path,
|
||||
"document",
|
||||
caption,
|
||||
file_name or os.path.basename(file_path),
|
||||
)
|
||||
|
||||
@@ -497,44 +500,45 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
|
||||
|
||||
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
@@ -542,29 +546,30 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
|
||||
async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
# Determine message type
|
||||
@@ -579,11 +584,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
# Determine chat type
|
||||
is_group = data.get("isGroup", False)
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=data.get("chatId", ""),
|
||||
@@ -592,7 +597,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
user_id=data.get("senderId"),
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
@@ -622,7 +627,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
@@ -635,4 +640,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
847
gateway/run.py
847
gateway/run.py
File diff suppressed because it is too large
Load Diff
@@ -8,22 +8,20 @@ Handles:
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
SessionResetPolicy,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
)
|
||||
|
||||
|
||||
@@ -31,29 +29,30 @@ from .config import (
|
||||
class SessionSource:
|
||||
"""
|
||||
Describes where a message originated from.
|
||||
|
||||
|
||||
This information is used to:
|
||||
1. Route responses back to the right place
|
||||
2. Inject context into the system prompt
|
||||
3. Track origin for cron job delivery
|
||||
"""
|
||||
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
chat_name: str | None = None
|
||||
chat_type: str = "dm" # "dm", "group", "channel", "thread"
|
||||
user_id: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: Optional[str] = None # Signal group internal ID
|
||||
|
||||
user_id: str | None = None
|
||||
user_name: str | None = None
|
||||
thread_id: str | None = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: str | None = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: str | None = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: str | None = None # Signal group internal ID
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Human-readable description of the source."""
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "CLI terminal"
|
||||
|
||||
|
||||
parts = []
|
||||
if self.chat_type == "dm":
|
||||
parts.append(f"DM with {self.user_name or self.user_id or 'user'}")
|
||||
@@ -63,13 +62,13 @@ class SessionSource:
|
||||
parts.append(f"channel: {self.chat_name or self.chat_id}")
|
||||
else:
|
||||
parts.append(self.chat_name or self.chat_id)
|
||||
|
||||
|
||||
if self.thread_id:
|
||||
parts.append(f"thread: {self.thread_id}")
|
||||
|
||||
|
||||
return ", ".join(parts)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
d = {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
@@ -85,9 +84,9 @@ class SessionSource:
|
||||
if self.chat_id_alt:
|
||||
d["chat_id_alt"] = self.chat_id_alt
|
||||
return d
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionSource":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
@@ -100,7 +99,7 @@ class SessionSource:
|
||||
user_id_alt=data.get("user_id_alt"),
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
@@ -116,29 +115,28 @@ class SessionSource:
|
||||
class SessionContext:
|
||||
"""
|
||||
Full context for a session, used for dynamic system prompt injection.
|
||||
|
||||
|
||||
The agent receives this information to understand:
|
||||
- Where messages are coming from
|
||||
- What platforms are available
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
|
||||
source: SessionSource
|
||||
connected_platforms: List[Platform]
|
||||
home_channels: Dict[Platform, HomeChannel]
|
||||
|
||||
connected_platforms: list[Platform]
|
||||
home_channels: dict[Platform, HomeChannel]
|
||||
|
||||
# Session metadata
|
||||
session_key: str = ""
|
||||
session_id: str = ""
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"source": self.source.to_dict(),
|
||||
"connected_platforms": [p.value for p in self.connected_platforms],
|
||||
"home_channels": {
|
||||
p.value: hc.to_dict() for p, hc in self.home_channels.items()
|
||||
},
|
||||
"home_channels": {p.value: hc.to_dict() for p, hc in self.home_channels.items()},
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
@@ -149,7 +147,7 @@ class SessionContext:
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
|
||||
This is injected into the system prompt so the agent knows:
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
@@ -159,14 +157,14 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"## Current Session Context",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
# Source info
|
||||
platform_name = context.source.platform.value.title()
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
lines.append(f"**Channel Topic:** {context.source.chat_topic}")
|
||||
@@ -176,43 +174,43 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
lines.append(f"**User:** {context.source.user_name}")
|
||||
elif context.source.user_id:
|
||||
lines.append(f"**User ID:** {context.source.user_id}")
|
||||
|
||||
|
||||
# Connected platforms
|
||||
platforms_list = ["local (files on this machine)"]
|
||||
for p in context.connected_platforms:
|
||||
if p != Platform.LOCAL:
|
||||
platforms_list.append(f"{p.value}: Connected ✓")
|
||||
|
||||
|
||||
lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}")
|
||||
|
||||
|
||||
# Home channels
|
||||
if context.home_channels:
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
lines.append("**Delivery options for scheduled tasks:**")
|
||||
|
||||
|
||||
# Origin delivery
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append("- `\"origin\"` → Local output (saved to files)")
|
||||
lines.append('- `"origin"` → Local output (saved to files)')
|
||||
else:
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
|
||||
|
||||
lines.append(f'- `"origin"` → Back to this chat ({context.source.chat_name or context.source.chat_id})')
|
||||
|
||||
# Local always available
|
||||
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
|
||||
|
||||
lines.append('- `"local"` → Save to local files only (~/.hermes/cron/output/)')
|
||||
|
||||
# Platform home channels
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})")
|
||||
|
||||
lines.append(f'- `"{platform.value}"` → Home channel ({home.name})')
|
||||
|
||||
# Note about explicit targeting
|
||||
lines.append("")
|
||||
lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*")
|
||||
|
||||
lines.append('*For explicit targeting, use `"platform:chat_id"` format if the user provides a specific chat ID.*')
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -220,32 +218,33 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
||||
class SessionEntry:
|
||||
"""
|
||||
Entry in the session store.
|
||||
|
||||
|
||||
Maps a session key to its current session ID and metadata.
|
||||
"""
|
||||
|
||||
session_key: str
|
||||
session_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# Origin metadata for delivery routing
|
||||
origin: Optional[SessionSource] = None
|
||||
|
||||
origin: SessionSource | None = None
|
||||
|
||||
# Display metadata
|
||||
display_name: Optional[str] = None
|
||||
platform: Optional[Platform] = None
|
||||
display_name: str | None = None
|
||||
platform: Platform | None = None
|
||||
chat_type: str = "dm"
|
||||
|
||||
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
@@ -261,20 +260,20 @@ class SessionEntry:
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionEntry":
|
||||
origin = None
|
||||
if "origin" in data and data["origin"]:
|
||||
origin = SessionSource.from_dict(data["origin"])
|
||||
|
||||
|
||||
platform = None
|
||||
if data.get("platform"):
|
||||
try:
|
||||
platform = Platform(data["platform"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
return cls(
|
||||
session_key=data["session_key"],
|
||||
session_id=data["session_id"],
|
||||
@@ -307,66 +306,65 @@ def build_session_key(source: SessionSource) -> str:
|
||||
class SessionStore:
|
||||
"""
|
||||
Manages session storage and retrieval.
|
||||
|
||||
|
||||
Uses SQLite (via SessionDB) for session metadata and message transcripts.
|
||||
Falls back to legacy JSONL files if SQLite is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig,
|
||||
has_active_processes_fn=None,
|
||||
on_auto_reset=None):
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig, has_active_processes_fn=None, on_auto_reset=None):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
self._entries: dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
self._has_active_processes_fn = has_active_processes_fn
|
||||
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||
# via the background session expiry watcher in GatewayRunner.
|
||||
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
|
||||
|
||||
|
||||
# Initialize SQLite session database
|
||||
self._db = None
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
|
||||
self._db = SessionDB()
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: SQLite session store unavailable, falling back to JSONL: {e}")
|
||||
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
"""Load sessions index from disk if not already loaded."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
with open(sessions_file, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
|
||||
self._loaded = True
|
||||
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save sessions index to disk (kept for session key -> ID mapping)."""
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(source)
|
||||
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
|
||||
|
||||
Works from the entry alone — no SessionSource needed.
|
||||
Used by the background expiry watcher to proactively flush memories.
|
||||
Sessions with active background processes are never considered expired.
|
||||
@@ -393,7 +391,9 @@ class SessionStore:
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0, second=0, microsecond=0,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
@@ -405,7 +405,7 @@ class SessionStore:
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
@@ -413,36 +413,28 @@ class SessionStore:
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
session_type=source.chat_type
|
||||
)
|
||||
|
||||
policy = self.config.get_reset_policy(platform=source.platform, session_type=source.chat_type)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0
|
||||
)
|
||||
today_reset = now.replace(hour=policy.at_hour, minute=0, second=0, microsecond=0)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
|
||||
@@ -463,26 +455,22 @@ class SessionStore:
|
||||
# This covers the rare case where the DB is unavailable.
|
||||
self._ensure_loaded()
|
||||
return len(self._entries) > 1
|
||||
|
||||
def get_or_create_session(
|
||||
self,
|
||||
source: SessionSource,
|
||||
force_new: bool = False
|
||||
) -> SessionEntry:
|
||||
|
||||
def get_or_create_session(self, source: SessionSource, force_new: bool = False) -> SessionEntry:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
|
||||
Evaluates reset policy to determine if the existing session is stale.
|
||||
Creates a session record in SQLite when a new session starts.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
session_key = self._generate_session_key(source)
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
@@ -500,10 +488,10 @@ class SessionStore:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
@@ -515,10 +503,10 @@ class SessionStore:
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
)
|
||||
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
|
||||
|
||||
# Create session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
@@ -529,18 +517,13 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to create SQLite session: {e}")
|
||||
|
||||
|
||||
return entry
|
||||
|
||||
def update_session(
|
||||
self,
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0
|
||||
) -> None:
|
||||
|
||||
def update_session(self, session_key: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = datetime.now()
|
||||
@@ -548,34 +531,32 @@ class SessionStore:
|
||||
entry.output_tokens += output_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
entry.session_id, input_tokens, output_tokens
|
||||
)
|
||||
self._db.update_token_counts(entry.session_id, input_tokens, output_tokens)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
|
||||
def reset_session(self, session_key: str) -> SessionEntry | None:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
|
||||
# End old session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_reset")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
now = datetime.now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
@@ -586,10 +567,10 @@ class SessionStore:
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
|
||||
# Create new session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
@@ -600,10 +581,10 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
return new_entry
|
||||
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> SessionEntry | None:
|
||||
"""Switch a session key to point at an existing session ID.
|
||||
|
||||
Used by ``/resume`` to restore a previously-named session.
|
||||
@@ -645,25 +626,25 @@ class SessionStore:
|
||||
self._save()
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
|
||||
def list_sessions(self, active_minutes: int | None = None) -> list[SessionEntry]:
|
||||
"""List all sessions, optionally filtered by activity."""
|
||||
self._ensure_loaded()
|
||||
|
||||
|
||||
entries = list(self._entries.values())
|
||||
|
||||
|
||||
if active_minutes is not None:
|
||||
cutoff = datetime.now() - timedelta(minutes=active_minutes)
|
||||
entries = [e for e in entries if e.updated_at >= cutoff]
|
||||
|
||||
|
||||
entries.sort(key=lambda e: e.updated_at, reverse=True)
|
||||
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
"""Get the path to a session's legacy transcript file."""
|
||||
return self.sessions_dir / f"{session_id}.jsonl"
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: dict[str, Any]) -> None:
|
||||
"""Append a message to a session's transcript (SQLite + legacy JSONL)."""
|
||||
# Write to SQLite
|
||||
if self._db:
|
||||
@@ -678,15 +659,15 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: list[dict[str, Any]]) -> None:
|
||||
"""Replace the entire transcript for a session with new messages.
|
||||
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation history.
|
||||
Rewrites both SQLite and legacy JSONL storage.
|
||||
"""
|
||||
@@ -705,14 +686,14 @@ class SessionStore:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
def load_transcript(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
# Try SQLite first
|
||||
if self._db:
|
||||
@@ -722,51 +703,49 @@ class SessionStore:
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
|
||||
|
||||
# Fall back to legacy JSONL
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
|
||||
if not transcript_path.exists():
|
||||
return []
|
||||
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
with open(transcript_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
messages.append(json.loads(line))
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_session_context(
|
||||
source: SessionSource,
|
||||
config: GatewayConfig,
|
||||
session_entry: Optional[SessionEntry] = None
|
||||
source: SessionSource, config: GatewayConfig, session_entry: SessionEntry | None = None
|
||||
) -> SessionContext:
|
||||
"""
|
||||
Build a full session context from a source and config.
|
||||
|
||||
|
||||
This is used to inject context into the agent's system prompt.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
|
||||
home_channels = {}
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
if home:
|
||||
home_channels[platform] = home
|
||||
|
||||
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=connected,
|
||||
home_channels=home_channels,
|
||||
)
|
||||
|
||||
|
||||
if session_entry:
|
||||
context.session_key = session_entry.session_key
|
||||
context.session_id = session_entry.session_id
|
||||
context.created_at = session_entry.created_at
|
||||
context.updated_at = session_entry.updated_at
|
||||
|
||||
|
||||
return context
|
||||
|
||||
@@ -13,7 +13,6 @@ concurrently under distinct configurations).
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _get_pid_path() -> Path:
|
||||
@@ -37,7 +36,7 @@ def remove_pid_file() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
def get_running_pid() -> int | None:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
Checks the PID file and verifies the process is actually alive.
|
||||
|
||||
@@ -12,8 +12,6 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
CACHE_PATH = Path(os.path.expanduser("~/.hermes/sticker_cache.json"))
|
||||
|
||||
@@ -43,7 +41,7 @@ def _save_cache(cache: dict) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_cached_description(file_unique_id: str) -> Optional[dict]:
|
||||
def get_cached_description(file_unique_id: str) -> dict | None:
|
||||
"""
|
||||
Look up a cached sticker description.
|
||||
|
||||
@@ -92,11 +90,11 @@ def build_sticker_injection(
|
||||
"""
|
||||
context = ""
|
||||
if set_name and emoji:
|
||||
context = f" {emoji} from \"{set_name}\""
|
||||
context = f' {emoji} from "{set_name}"'
|
||||
elif emoji:
|
||||
context = f" {emoji}"
|
||||
|
||||
return f"[The user sent a sticker{context}~ It shows: \"{description}\" (=^.w.^=)]"
|
||||
return f'[The user sent a sticker{context}~ It shows: "{description}" (=^.w.^=)]'
|
||||
|
||||
|
||||
def build_animated_sticker_injection(emoji: str = "") -> str:
|
||||
|
||||
@@ -5,7 +5,7 @@ Provides subcommands for:
|
||||
- hermes chat - Interactive chat (same as ./hermes)
|
||||
- hermes gateway - Run gateway in foreground
|
||||
- hermes gateway start - Start gateway service
|
||||
- hermes gateway stop - Stop gateway service
|
||||
- hermes gateway stop - Stop gateway service
|
||||
- hermes setup - Interactive setup wizard
|
||||
- hermes status - Show status of all components
|
||||
- hermes cron - Manage cron jobs
|
||||
|
||||
@@ -15,27 +15,25 @@ Architecture:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import base64
|
||||
import hashlib
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
import webbrowser
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_config_path
|
||||
from hermes_cli.config import get_config_path, get_hermes_home
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,8 +56,8 @@ DEFAULT_NOUS_INFERENCE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
DEFAULT_NOUS_CLIENT_ID = "hermes-cli"
|
||||
DEFAULT_NOUS_SCOPE = "inference:mint_agent_key"
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
@@ -70,9 +68,11 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
# Provider Registry
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Describes a known inference provider."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
|
||||
@@ -80,14 +80,14 @@ class ProviderConfig:
|
||||
inference_base_url: str = ""
|
||||
client_id: str = ""
|
||||
scope: str = ""
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
# For API-key providers: env vars to check (in priority order)
|
||||
api_key_env_vars: tuple = ()
|
||||
# Optional env var for base URL override
|
||||
base_url_env_var: str = ""
|
||||
|
||||
|
||||
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
PROVIDER_REGISTRY: dict[str, ProviderConfig] = {
|
||||
"nous": ProviderConfig(
|
||||
id="nous",
|
||||
name="Nous Portal",
|
||||
@@ -172,14 +172,14 @@ def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) ->
|
||||
|
||||
ZAI_ENDPOINTS = [
|
||||
# (id, base_url, default_model, label)
|
||||
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
|
||||
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
|
||||
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
|
||||
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
|
||||
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
|
||||
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
|
||||
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
|
||||
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
|
||||
]
|
||||
|
||||
|
||||
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
|
||||
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> dict[str, str] | None:
|
||||
"""Probe z.ai endpoints to find one that accepts this API key.
|
||||
|
||||
Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
|
||||
@@ -219,6 +219,7 @@ def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str
|
||||
# Error Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AuthError(RuntimeError):
|
||||
"""Structured auth error with UX mapping hints."""
|
||||
|
||||
@@ -227,7 +228,7 @@ class AuthError(RuntimeError):
|
||||
message: str,
|
||||
*,
|
||||
provider: str = "",
|
||||
code: Optional[str] = None,
|
||||
code: str | None = None,
|
||||
relogin_required: bool = False,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
@@ -245,16 +246,10 @@ def format_auth_error(error: Exception) -> str:
|
||||
return f"{error} Run `hermes model` to re-authenticate."
|
||||
|
||||
if error.code == "subscription_required":
|
||||
return (
|
||||
"No active paid subscription found on Nous Portal. "
|
||||
"Please purchase/activate a subscription, then retry."
|
||||
)
|
||||
return "No active paid subscription found on Nous Portal. Please purchase/activate a subscription, then retry."
|
||||
|
||||
if error.code == "insufficient_credits":
|
||||
return (
|
||||
"Subscription credits are exhausted. "
|
||||
"Top up/renew credits in Nous Portal, then retry."
|
||||
)
|
||||
return "Subscription credits are exhausted. Top up/renew credits in Nous Portal, then retry."
|
||||
|
||||
if error.code == "temporarily_unavailable":
|
||||
return f"{error} Please retry in a few seconds."
|
||||
@@ -262,7 +257,7 @@ def format_auth_error(error: Exception) -> str:
|
||||
return str(error)
|
||||
|
||||
|
||||
def _token_fingerprint(token: Any) -> Optional[str]:
|
||||
def _token_fingerprint(token: Any) -> str | None:
|
||||
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
|
||||
if not isinstance(token, str):
|
||||
return None
|
||||
@@ -277,10 +272,10 @@ def _oauth_trace_enabled() -> bool:
|
||||
return raw in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
|
||||
def _oauth_trace(event: str, *, sequence_id: str | None = None, **fields: Any) -> None:
|
||||
if not _oauth_trace_enabled():
|
||||
return
|
||||
payload: Dict[str, Any] = {"event": event}
|
||||
payload: dict[str, Any] = {"event": event}
|
||||
if sequence_id:
|
||||
payload["sequence_id"] = sequence_id
|
||||
payload.update(fields)
|
||||
@@ -291,6 +286,7 @@ def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any
|
||||
# Auth Store — persistence layer for ~/.hermes/auth.json
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _auth_file_path() -> Path:
|
||||
return get_hermes_home() / "auth.json"
|
||||
|
||||
@@ -326,7 +322,7 @@ def _auth_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]:
|
||||
def _load_auth_store(auth_file: Path | None = None) -> dict[str, Any]:
|
||||
auth_file = auth_file or _auth_file_path()
|
||||
if not auth_file.exists():
|
||||
return {"version": AUTH_STORE_VERSION, "providers": {}}
|
||||
@@ -345,17 +341,16 @@ def _load_auth_store(auth_file: Optional[Path] = None) -> Dict[str, Any]:
|
||||
providers = {}
|
||||
if "nous_portal" in systems:
|
||||
providers["nous"] = systems["nous_portal"]
|
||||
return {"version": AUTH_STORE_VERSION, "providers": providers,
|
||||
"active_provider": "nous" if providers else None}
|
||||
return {"version": AUTH_STORE_VERSION, "providers": providers, "active_provider": "nous" if providers else None}
|
||||
|
||||
return {"version": AUTH_STORE_VERSION, "providers": {}}
|
||||
|
||||
|
||||
def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
|
||||
def _save_auth_store(auth_store: dict[str, Any]) -> Path:
|
||||
auth_file = _auth_file_path()
|
||||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
auth_store["version"] = AUTH_STORE_VERSION
|
||||
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
auth_store["updated_at"] = datetime.now(UTC).isoformat()
|
||||
payload = json.dumps(auth_store, indent=2) + "\n"
|
||||
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
|
||||
try:
|
||||
@@ -387,7 +382,7 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
|
||||
return auth_file
|
||||
|
||||
|
||||
def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
def _load_provider_state(auth_store: dict[str, Any], provider_id: str) -> dict[str, Any] | None:
|
||||
providers = auth_store.get("providers")
|
||||
if not isinstance(providers, dict):
|
||||
return None
|
||||
@@ -395,7 +390,7 @@ def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Option
|
||||
return dict(state) if isinstance(state, dict) else None
|
||||
|
||||
|
||||
def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Dict[str, Any]) -> None:
|
||||
def _save_provider_state(auth_store: dict[str, Any], provider_id: str, state: dict[str, Any]) -> None:
|
||||
providers = auth_store.setdefault("providers", {})
|
||||
if not isinstance(providers, dict):
|
||||
auth_store["providers"] = {}
|
||||
@@ -404,19 +399,19 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
|
||||
auth_store["active_provider"] = provider_id
|
||||
|
||||
|
||||
def get_provider_auth_state(provider_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_provider_auth_state(provider_id: str) -> dict[str, Any] | None:
|
||||
"""Return persisted auth state for a provider, or None."""
|
||||
auth_store = _load_auth_store()
|
||||
return _load_provider_state(auth_store, provider_id)
|
||||
|
||||
|
||||
def get_active_provider() -> Optional[str]:
|
||||
def get_active_provider() -> str | None:
|
||||
"""Return the currently active provider ID from auth store."""
|
||||
auth_store = _load_auth_store()
|
||||
return auth_store.get("active_provider")
|
||||
|
||||
|
||||
def clear_provider_auth(provider_id: Optional[str] = None) -> bool:
|
||||
def clear_provider_auth(provider_id: str | None = None) -> bool:
|
||||
"""
|
||||
Clear auth state for a provider. Used by `hermes logout`.
|
||||
If provider_id is None, clears the active provider.
|
||||
@@ -455,11 +450,12 @@ def deactivate_provider() -> None:
|
||||
# Provider Resolution — picks which provider to use
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def resolve_provider(
|
||||
requested: Optional[str] = None,
|
||||
requested: str | None = None,
|
||||
*,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Determine which inference provider to use.
|
||||
@@ -475,9 +471,14 @@ def resolve_provider(
|
||||
|
||||
# Normalize provider aliases
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
"glm": "zai",
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
}
|
||||
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
@@ -524,7 +525,8 @@ def resolve_provider(
|
||||
# Timestamp / TTL helpers
|
||||
# =============================================================================
|
||||
|
||||
def _parse_iso_timestamp(value: Any) -> Optional[float]:
|
||||
|
||||
def _parse_iso_timestamp(value: Any) -> float | None:
|
||||
if not isinstance(value, str) or not value:
|
||||
return None
|
||||
text = value.strip()
|
||||
@@ -537,7 +539,7 @@ def _parse_iso_timestamp(value: Any) -> Optional[float]:
|
||||
except Exception:
|
||||
return None
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
parsed = parsed.replace(tzinfo=UTC)
|
||||
return parsed.timestamp()
|
||||
|
||||
|
||||
@@ -556,14 +558,14 @@ def _coerce_ttl_seconds(expires_in: Any) -> int:
|
||||
return max(0, ttl)
|
||||
|
||||
|
||||
def _optional_base_url(value: Any) -> Optional[str]:
|
||||
def _optional_base_url(value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
cleaned = value.strip().rstrip("/")
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
|
||||
def _decode_jwt_claims(token: Any) -> dict[str, Any]:
|
||||
if not isinstance(token, str) or token.count(".") != 2:
|
||||
return {}
|
||||
payload = token.split(".")[1]
|
||||
@@ -588,6 +590,7 @@ def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> boo
|
||||
# SSH / remote session detection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _is_remote_session() -> bool:
|
||||
"""Detect if running in an SSH session where webbrowser.open() won't work."""
|
||||
return bool(os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY"))
|
||||
@@ -601,9 +604,10 @@ def _is_remote_session() -> bool:
|
||||
# where one app's refresh invalidates the other's session.
|
||||
# =============================================================================
|
||||
|
||||
def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]:
|
||||
|
||||
def _read_codex_tokens(*, _lock: bool = True) -> dict[str, Any]:
|
||||
"""Read Codex OAuth tokens from Hermes auth store (~/.hermes/auth.json).
|
||||
|
||||
|
||||
Returns dict with 'tokens' (access_token, refresh_token) and 'last_refresh'.
|
||||
Raises AuthError if no Codex tokens are stored.
|
||||
"""
|
||||
@@ -650,10 +654,10 @@ def _read_codex_tokens(*, _lock: bool = True) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None:
|
||||
def _save_codex_tokens(tokens: dict[str, str], last_refresh: str = None) -> None:
|
||||
"""Save Codex OAuth tokens to Hermes auth store (~/.hermes/auth.json)."""
|
||||
if last_refresh is None:
|
||||
last_refresh = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
last_refresh = datetime.now(UTC).isoformat().replace("+00:00", "Z")
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
state = _load_provider_state(auth_store, "openai-codex") or {}
|
||||
@@ -665,11 +669,11 @@ def _save_codex_tokens(tokens: Dict[str, str], last_refresh: str = None) -> None
|
||||
|
||||
|
||||
def _refresh_codex_auth_tokens(
|
||||
tokens: Dict[str, str],
|
||||
tokens: dict[str, str],
|
||||
timeout_seconds: float,
|
||||
) -> Dict[str, str]:
|
||||
) -> dict[str, str]:
|
||||
"""Refresh Codex access token using the refresh token.
|
||||
|
||||
|
||||
Saves the new tokens to Hermes auth store automatically.
|
||||
"""
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
@@ -746,9 +750,9 @@ def _refresh_codex_auth_tokens(
|
||||
return updated_tokens
|
||||
|
||||
|
||||
def _import_codex_cli_tokens() -> Optional[Dict[str, str]]:
|
||||
def _import_codex_cli_tokens() -> dict[str, str] | None:
|
||||
"""Try to read tokens from ~/.codex/auth.json (Codex CLI shared file).
|
||||
|
||||
|
||||
Returns tokens dict if valid, None otherwise. Does NOT write to the shared file.
|
||||
"""
|
||||
codex_home = os.getenv("CODEX_HOME", "").strip()
|
||||
@@ -774,7 +778,7 @@ def resolve_codex_runtime_credentials(
|
||||
force_refresh: bool = False,
|
||||
refresh_if_expiring: bool = True,
|
||||
refresh_skew_seconds: int = CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve runtime credentials from Hermes's own Codex token store."""
|
||||
try:
|
||||
data = _read_codex_tokens()
|
||||
@@ -817,10 +821,7 @@ def resolve_codex_runtime_credentials(
|
||||
tokens = _refresh_codex_auth_tokens(tokens, refresh_timeout_seconds)
|
||||
access_token = str(tokens.get("access_token", "") or "").strip()
|
||||
|
||||
base_url = (
|
||||
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
|
||||
or DEFAULT_CODEX_BASE_URL
|
||||
)
|
||||
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
|
||||
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
@@ -836,24 +837,19 @@ def resolve_codex_runtime_credentials(
|
||||
# TLS verification helper
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _resolve_verify(
|
||||
*,
|
||||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
auth_state: Optional[Dict[str, Any]] = None,
|
||||
insecure: bool | None = None,
|
||||
ca_bundle: str | None = None,
|
||||
auth_state: dict[str, Any] | None = None,
|
||||
) -> bool | str:
|
||||
tls_state = auth_state.get("tls") if isinstance(auth_state, dict) else {}
|
||||
tls_state = tls_state if isinstance(tls_state, dict) else {}
|
||||
|
||||
effective_insecure = (
|
||||
bool(insecure) if insecure is not None
|
||||
else bool(tls_state.get("insecure", False))
|
||||
)
|
||||
effective_insecure = bool(insecure) if insecure is not None else bool(tls_state.get("insecure", False))
|
||||
effective_ca = (
|
||||
ca_bundle
|
||||
or tls_state.get("ca_bundle")
|
||||
or os.getenv("HERMES_CA_BUNDLE")
|
||||
or os.getenv("SSL_CERT_FILE")
|
||||
ca_bundle or tls_state.get("ca_bundle") or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
|
||||
)
|
||||
|
||||
if effective_insecure:
|
||||
@@ -867,12 +863,13 @@ def _resolve_verify(
|
||||
# OAuth Device Code Flow — generic, parameterized by provider
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _request_device_code(
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
client_id: str,
|
||||
scope: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
scope: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""POST to the device code endpoint. Returns device_code, user_code, etc."""
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/device/code",
|
||||
@@ -885,8 +882,12 @@ def _request_device_code(
|
||||
data = response.json()
|
||||
|
||||
required_fields = [
|
||||
"device_code", "user_code", "verification_uri",
|
||||
"verification_uri_complete", "expires_in", "interval",
|
||||
"device_code",
|
||||
"user_code",
|
||||
"verification_uri",
|
||||
"verification_uri_complete",
|
||||
"expires_in",
|
||||
"interval",
|
||||
]
|
||||
missing = [f for f in required_fields if f not in data]
|
||||
if missing:
|
||||
@@ -901,7 +902,7 @@ def _poll_for_token(
|
||||
device_code: str,
|
||||
expires_in: int,
|
||||
poll_interval: int,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Poll the token endpoint until the user approves or the code expires."""
|
||||
deadline = time.time() + max(1, expires_in)
|
||||
current_interval = max(1, min(poll_interval, DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS))
|
||||
@@ -947,13 +948,14 @@ def _poll_for_token(
|
||||
# Nous Portal — token refresh, agent key minting, model discovery
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _refresh_access_token(
|
||||
*,
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
client_id: str,
|
||||
refresh_token: str,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/token",
|
||||
data={
|
||||
@@ -966,15 +968,15 @@ def _refresh_access_token(
|
||||
if response.status_code == 200:
|
||||
payload = response.json()
|
||||
if "access_token" not in payload:
|
||||
raise AuthError("Refresh response missing access_token",
|
||||
provider="nous", code="invalid_token", relogin_required=True)
|
||||
raise AuthError(
|
||||
"Refresh response missing access_token", provider="nous", code="invalid_token", relogin_required=True
|
||||
)
|
||||
return payload
|
||||
|
||||
try:
|
||||
error_payload = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthError("Refresh token exchange failed",
|
||||
provider="nous", relogin_required=True) from exc
|
||||
raise AuthError("Refresh token exchange failed", provider="nous", relogin_required=True) from exc
|
||||
|
||||
code = str(error_payload.get("error", "invalid_grant"))
|
||||
description = str(error_payload.get("error_description") or "Refresh token exchange failed")
|
||||
@@ -988,7 +990,7 @@ def _mint_agent_key(
|
||||
portal_base_url: str,
|
||||
access_token: str,
|
||||
min_ttl_seconds: int,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Mint (or reuse) a short-lived inference API key."""
|
||||
response = client.post(
|
||||
f"{portal_base_url}/api/oauth/agent-key",
|
||||
@@ -999,15 +1001,13 @@ def _mint_agent_key(
|
||||
if response.status_code == 200:
|
||||
payload = response.json()
|
||||
if "api_key" not in payload:
|
||||
raise AuthError("Mint response missing api_key",
|
||||
provider="nous", code="server_error")
|
||||
raise AuthError("Mint response missing api_key", provider="nous", code="server_error")
|
||||
return payload
|
||||
|
||||
try:
|
||||
error_payload = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthError("Agent key mint request failed",
|
||||
provider="nous", code="server_error") from exc
|
||||
raise AuthError("Agent key mint request failed", provider="nous", code="server_error") from exc
|
||||
|
||||
code = str(error_payload.get("error", "server_error"))
|
||||
description = str(error_payload.get("error_description") or "Agent key mint request failed")
|
||||
@@ -1021,7 +1021,7 @@ def fetch_nous_models(
|
||||
api_key: str,
|
||||
timeout_seconds: float = 15.0,
|
||||
verify: bool | str = True,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Fetch available model IDs from the Nous inference API."""
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
@@ -1044,7 +1044,7 @@ def fetch_nous_models(
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
|
||||
model_ids: List[str] = []
|
||||
model_ids: list[str] = []
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
@@ -1059,7 +1059,7 @@ def fetch_nous_models(
|
||||
return list(dict.fromkeys(model_ids))
|
||||
|
||||
|
||||
def _agent_key_is_usable(state: Dict[str, Any], min_ttl_seconds: int) -> bool:
|
||||
def _agent_key_is_usable(state: dict[str, Any], min_ttl_seconds: int) -> bool:
|
||||
key = state.get("agent_key")
|
||||
if not isinstance(key, str) or not key.strip():
|
||||
return False
|
||||
@@ -1070,10 +1070,10 @@ def resolve_nous_runtime_credentials(
|
||||
*,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
timeout_seconds: float = 15.0,
|
||||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
insecure: bool | None = None,
|
||||
ca_bundle: str | None = None,
|
||||
force_mint: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Resolve Nous inference credentials for runtime use.
|
||||
|
||||
@@ -1092,8 +1092,7 @@ def resolve_nous_runtime_credentials(
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
|
||||
if not state:
|
||||
raise AuthError("Hermes is not logged into Nous Portal.",
|
||||
provider="nous", relogin_required=True)
|
||||
raise AuthError("Hermes is not logged into Nous Portal.", provider="nous", relogin_required=True)
|
||||
|
||||
portal_base_url = (
|
||||
_optional_base_url(state.get("portal_base_url"))
|
||||
@@ -1143,14 +1142,14 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token = state.get("refresh_token")
|
||||
|
||||
if not isinstance(access_token, str) or not access_token:
|
||||
raise AuthError("No access token found for Nous Portal login.",
|
||||
provider="nous", relogin_required=True)
|
||||
raise AuthError("No access token found for Nous Portal login.", provider="nous", relogin_required=True)
|
||||
|
||||
# Step 1: refresh access token if expiring
|
||||
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
|
||||
if not isinstance(refresh_token, str) or not refresh_token:
|
||||
raise AuthError("Session expired and no refresh token is available.",
|
||||
provider="nous", relogin_required=True)
|
||||
raise AuthError(
|
||||
"Session expired and no refresh token is available.", provider="nous", relogin_required=True
|
||||
)
|
||||
|
||||
_oauth_trace(
|
||||
"refresh_start",
|
||||
@@ -1159,10 +1158,12 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token_fp=_token_fingerprint(refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=refresh_token,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
previous_refresh_token = refresh_token
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
@@ -1174,9 +1175,7 @@ def resolve_nous_runtime_credentials(
|
||||
inference_base_url = refreshed_url
|
||||
state["obtained_at"] = now.isoformat()
|
||||
state["expires_in"] = access_ttl
|
||||
state["expires_at"] = datetime.fromtimestamp(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
@@ -1191,7 +1190,7 @@ def resolve_nous_runtime_credentials(
|
||||
|
||||
# Step 2: mint agent key if missing/expiring
|
||||
used_cached_key = False
|
||||
mint_payload: Optional[Dict[str, Any]] = None
|
||||
mint_payload: dict[str, Any] | None = None
|
||||
|
||||
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
|
||||
used_cached_key = True
|
||||
@@ -1204,8 +1203,10 @@ def resolve_nous_runtime_credentials(
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
access_token=access_token,
|
||||
min_ttl_seconds=min_key_ttl_seconds,
|
||||
)
|
||||
except AuthError as exc:
|
||||
_oauth_trace(
|
||||
@@ -1227,10 +1228,12 @@ def resolve_nous_runtime_credentials(
|
||||
refresh_token_fp=_token_fingerprint(latest_refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=latest_refresh_token,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
refresh_token=latest_refresh_token,
|
||||
)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
|
||||
@@ -1241,9 +1244,7 @@ def resolve_nous_runtime_credentials(
|
||||
inference_base_url = refreshed_url
|
||||
state["obtained_at"] = now.isoformat()
|
||||
state["expires_in"] = access_ttl
|
||||
state["expires_at"] = datetime.fromtimestamp(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
state["expires_at"] = datetime.fromtimestamp(now.timestamp() + access_ttl, tz=UTC).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
@@ -1257,14 +1258,16 @@ def resolve_nous_runtime_credentials(
|
||||
_persist_state("post_refresh_mint_retry")
|
||||
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
access_token=access_token,
|
||||
min_ttl_seconds=min_key_ttl_seconds,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if mint_payload is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
state["agent_key"] = mint_payload.get("api_key")
|
||||
state["agent_key_id"] = mint_payload.get("key_id")
|
||||
state["agent_key_expires_at"] = mint_payload.get("expires_at")
|
||||
@@ -1293,8 +1296,7 @@ def resolve_nous_runtime_credentials(
|
||||
|
||||
api_key = state.get("agent_key")
|
||||
if not isinstance(api_key, str) or not api_key:
|
||||
raise AuthError("Failed to resolve a Nous inference API key",
|
||||
provider="nous", code="server_error")
|
||||
raise AuthError("Failed to resolve a Nous inference API key", provider="nous", code="server_error")
|
||||
|
||||
expires_at = state.get("agent_key_expires_at")
|
||||
expires_epoch = _parse_iso_timestamp(expires_at)
|
||||
@@ -1319,7 +1321,8 @@ def resolve_nous_runtime_credentials(
|
||||
# Status helpers
|
||||
# =============================================================================
|
||||
|
||||
def get_nous_auth_status() -> Dict[str, Any]:
|
||||
|
||||
def get_nous_auth_status() -> dict[str, Any]:
|
||||
"""Status snapshot for `hermes status` output."""
|
||||
state = get_provider_auth_state("nous")
|
||||
if not state:
|
||||
@@ -1341,7 +1344,7 @@ def get_nous_auth_status() -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_codex_auth_status() -> Dict[str, Any]:
|
||||
def get_codex_auth_status() -> dict[str, Any]:
|
||||
"""Status snapshot for Codex auth."""
|
||||
try:
|
||||
creds = resolve_codex_runtime_credentials()
|
||||
@@ -1360,7 +1363,7 @@ def get_codex_auth_status() -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
def get_api_key_provider_status(provider_id: str) -> dict[str, Any]:
|
||||
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "api_key":
|
||||
@@ -1396,7 +1399,7 @@ def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
def get_auth_status(provider_id: str | None = None) -> dict[str, Any]:
|
||||
"""Generic auth status dispatcher."""
|
||||
target = provider_id or get_active_provider()
|
||||
if target == "nous":
|
||||
@@ -1410,7 +1413,7 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {"logged_in": False}
|
||||
|
||||
|
||||
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
def resolve_api_key_provider_credentials(provider_id: str) -> dict[str, Any]:
|
||||
"""Resolve API key and base URL for an API-key provider.
|
||||
|
||||
Returns dict with: provider, api_key, base_url, source.
|
||||
@@ -1455,7 +1458,8 @@ def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
|
||||
def detect_external_credentials() -> list[dict[str, Any]]:
|
||||
"""Scan for credentials from other CLI tools that Hermes can reuse.
|
||||
|
||||
Returns a list of dicts, each with:
|
||||
@@ -1463,17 +1467,19 @@ def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
- path: str -- filesystem path where creds were found
|
||||
- label: str -- human-friendly description for the setup UI
|
||||
"""
|
||||
found: List[Dict[str, Any]] = []
|
||||
found: list[dict[str, Any]] = []
|
||||
|
||||
# Codex CLI: ~/.codex/auth.json (importable, not shared)
|
||||
cli_tokens = _import_codex_cli_tokens()
|
||||
if cli_tokens:
|
||||
codex_path = Path.home() / ".codex" / "auth.json"
|
||||
found.append({
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
|
||||
})
|
||||
found.append(
|
||||
{
|
||||
"provider": "openai-codex",
|
||||
"path": str(codex_path),
|
||||
"label": f"Codex CLI credentials found ({codex_path}) — run `hermes login` to create a separate session",
|
||||
}
|
||||
)
|
||||
|
||||
return found
|
||||
|
||||
@@ -1482,6 +1488,7 @@ def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
# CLI Commands — login / logout
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Path:
|
||||
"""Update config.yaml and auth.json to reflect the active provider."""
|
||||
# Set active_provider in auth.json so auto-resolution picks this provider
|
||||
@@ -1494,7 +1501,7 @@ def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Pa
|
||||
config_path = get_config_path()
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config: Dict[str, Any] = {}
|
||||
config: dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
loaded = yaml.safe_load(config_path.read_text()) or {}
|
||||
@@ -1542,7 +1549,7 @@ def _reset_config_provider() -> Path:
|
||||
return config_path
|
||||
|
||||
|
||||
def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Optional[str]:
|
||||
def _prompt_model_selection(model_ids: list[str], current_model: str = "") -> str | None:
|
||||
"""Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None."""
|
||||
# Reorder: current model first, then the rest (deduplicated)
|
||||
ordered = []
|
||||
@@ -1564,6 +1571,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op
|
||||
# Try arrow-key menu first, fall back to number input
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
|
||||
choices = [f" {_label(mid)}" for mid in ordered]
|
||||
choices.append(" Enter custom model name")
|
||||
choices.append(" Skip (keep current)")
|
||||
@@ -1621,7 +1629,7 @@ def _prompt_model_selection(model_ids: List[str], current_model: str = "") -> Op
|
||||
|
||||
def _save_model_choice(model_id: str) -> None:
|
||||
"""Save the selected model to config.yaml and .env."""
|
||||
from hermes_cli.config import save_config, load_config, save_env_value
|
||||
from hermes_cli.config import load_config, save_config, save_env_value
|
||||
|
||||
config = load_config()
|
||||
# Handle both string and dict model formats
|
||||
@@ -1693,11 +1701,11 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None:
|
||||
config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL))
|
||||
print()
|
||||
print("Login successful!")
|
||||
print(f" Auth state: ~/.hermes/auth.json")
|
||||
print(" Auth state: ~/.hermes/auth.json")
|
||||
print(f" Config updated: {config_path} (model.provider=openai-codex)")
|
||||
|
||||
|
||||
def _codex_device_code_login() -> Dict[str, Any]:
|
||||
def _codex_device_code_login() -> dict[str, Any]:
|
||||
"""Run the OpenAI device code login flow and return credentials dict."""
|
||||
import time as _time
|
||||
|
||||
@@ -1715,13 +1723,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Failed to request device code: {exc}",
|
||||
provider="openai-codex", code="device_code_request_failed",
|
||||
provider="openai-codex",
|
||||
code="device_code_request_failed",
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise AuthError(
|
||||
f"Device code request returned status {resp.status_code}.",
|
||||
provider="openai-codex", code="device_code_request_error",
|
||||
provider="openai-codex",
|
||||
code="device_code_request_error",
|
||||
)
|
||||
|
||||
device_data = resp.json()
|
||||
@@ -1732,14 +1742,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
if not user_code or not device_auth_id:
|
||||
raise AuthError(
|
||||
"Device code response missing required fields.",
|
||||
provider="openai-codex", code="device_code_incomplete",
|
||||
provider="openai-codex",
|
||||
code="device_code_incomplete",
|
||||
)
|
||||
|
||||
# Step 2: Show user the code
|
||||
print("To continue, follow these steps:\n")
|
||||
print(f" 1. Open this URL in your browser:")
|
||||
print(" 1. Open this URL in your browser:")
|
||||
print(f" \033[94m{issuer}/codex/device\033[0m\n")
|
||||
print(f" 2. Enter this code:")
|
||||
print(" 2. Enter this code:")
|
||||
print(f" \033[94m{user_code}\033[0m\n")
|
||||
print("Waiting for sign-in... (press Ctrl+C to cancel)")
|
||||
|
||||
@@ -1766,7 +1777,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
else:
|
||||
raise AuthError(
|
||||
f"Device auth polling returned status {poll_resp.status_code}.",
|
||||
provider="openai-codex", code="device_code_poll_error",
|
||||
provider="openai-codex",
|
||||
code="device_code_poll_error",
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("\nLogin cancelled.")
|
||||
@@ -1775,7 +1787,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
if code_resp is None:
|
||||
raise AuthError(
|
||||
"Login timed out after 15 minutes.",
|
||||
provider="openai-codex", code="device_code_timeout",
|
||||
provider="openai-codex",
|
||||
code="device_code_timeout",
|
||||
)
|
||||
|
||||
# Step 4: Exchange authorization code for tokens
|
||||
@@ -1786,7 +1799,8 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
if not authorization_code or not code_verifier:
|
||||
raise AuthError(
|
||||
"Device auth response missing authorization_code or code_verifier.",
|
||||
provider="openai-codex", code="device_code_incomplete_exchange",
|
||||
provider="openai-codex",
|
||||
code="device_code_incomplete_exchange",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1805,13 +1819,15 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Token exchange failed: {exc}",
|
||||
provider="openai-codex", code="token_exchange_failed",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_failed",
|
||||
)
|
||||
|
||||
if token_resp.status_code != 200:
|
||||
raise AuthError(
|
||||
f"Token exchange returned status {token_resp.status_code}.",
|
||||
provider="openai-codex", code="token_exchange_error",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_error",
|
||||
)
|
||||
|
||||
tokens = token_resp.json()
|
||||
@@ -1821,14 +1837,12 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
if not access_token:
|
||||
raise AuthError(
|
||||
"Token exchange did not return an access_token.",
|
||||
provider="openai-codex", code="token_exchange_no_access_token",
|
||||
provider="openai-codex",
|
||||
code="token_exchange_no_access_token",
|
||||
)
|
||||
|
||||
# Return tokens for the caller to persist (no longer writes to ~/.codex/)
|
||||
base_url = (
|
||||
os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/")
|
||||
or DEFAULT_CODEX_BASE_URL
|
||||
)
|
||||
base_url = os.getenv("HERMES_CODEX_BASE_URL", "").strip().rstrip("/") or DEFAULT_CODEX_BASE_URL
|
||||
|
||||
return {
|
||||
"tokens": {
|
||||
@@ -1836,7 +1850,7 @@ def _codex_device_code_login() -> Dict[str, Any]:
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
"base_url": base_url,
|
||||
"last_refresh": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
|
||||
"last_refresh": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||
"auth_mode": "chatgpt",
|
||||
"source": "device-code",
|
||||
}
|
||||
@@ -1851,9 +1865,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
or pconfig.portal_base_url
|
||||
).rstrip("/")
|
||||
requested_inference_url = (
|
||||
getattr(args, "inference_url", None)
|
||||
or os.getenv("NOUS_INFERENCE_BASE_URL")
|
||||
or pconfig.inference_base_url
|
||||
getattr(args, "inference_url", None) or os.getenv("NOUS_INFERENCE_BASE_URL") or pconfig.inference_base_url
|
||||
).rstrip("/")
|
||||
client_id = getattr(args, "client_id", None) or pconfig.client_id
|
||||
scope = getattr(args, "scope", None) or pconfig.scope
|
||||
@@ -1862,11 +1874,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
|
||||
insecure = bool(getattr(args, "insecure", False))
|
||||
ca_bundle = (
|
||||
getattr(args, "ca_bundle", None)
|
||||
or os.getenv("HERMES_CA_BUNDLE")
|
||||
or os.getenv("SSL_CERT_FILE")
|
||||
)
|
||||
ca_bundle = getattr(args, "ca_bundle", None) or os.getenv("HERMES_CA_BUNDLE") or os.getenv("SSL_CERT_FILE")
|
||||
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
|
||||
|
||||
# Skip browser open in SSH sessions
|
||||
@@ -1883,8 +1891,10 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
device_data = _request_device_code(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, scope=scope,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
verification_url = str(device_data["verification_uri_complete"])
|
||||
@@ -1908,19 +1918,19 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
print(f"Waiting for approval (polling every {effective_interval}s)...")
|
||||
|
||||
token_data = _poll_for_token(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, device_code=str(device_data["device_code"]),
|
||||
expires_in=expires_in, poll_interval=interval,
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
device_code=str(device_data["device_code"]),
|
||||
expires_in=expires_in,
|
||||
poll_interval=interval,
|
||||
)
|
||||
|
||||
# Process token response
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
token_expires_in = _coerce_ttl_seconds(token_data.get("expires_in", 0))
|
||||
expires_at = now.timestamp() + token_expires_in
|
||||
inference_base_url = (
|
||||
_optional_base_url(token_data.get("inference_base_url"))
|
||||
or requested_inference_url
|
||||
)
|
||||
inference_base_url = _optional_base_url(token_data.get("inference_base_url")) or requested_inference_url
|
||||
if inference_base_url != requested_inference_url:
|
||||
print(f"Using portal-provided inference URL: {inference_base_url}")
|
||||
|
||||
@@ -1933,7 +1943,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
"access_token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
"obtained_at": now.isoformat(),
|
||||
"expires_at": datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
|
||||
"expires_at": datetime.fromtimestamp(expires_at, tz=UTC).isoformat(),
|
||||
"expires_in": token_expires_in,
|
||||
"tls": {
|
||||
"insecure": verify is False,
|
||||
@@ -1964,13 +1974,13 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
||||
runtime_creds = resolve_nous_runtime_credentials(
|
||||
min_key_ttl_seconds=5 * 60,
|
||||
timeout_seconds=timeout_seconds,
|
||||
insecure=insecure, ca_bundle=ca_bundle,
|
||||
insecure=insecure,
|
||||
ca_bundle=ca_bundle,
|
||||
)
|
||||
runtime_key = runtime_creds.get("api_key")
|
||||
runtime_base_url = runtime_creds.get("base_url") or inference_base_url
|
||||
if not isinstance(runtime_key, str) or not runtime_key:
|
||||
raise AuthError("No runtime API key available to fetch models",
|
||||
provider="nous", code="invalid_token")
|
||||
raise AuthError("No runtime API key available to fetch models", provider="nous", code="invalid_token")
|
||||
|
||||
model_ids = fetch_nous_models(
|
||||
inference_base_url=runtime_base_url,
|
||||
|
||||
@@ -9,14 +9,12 @@ import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from prompt_toolkit import print_formatted_text as _pt_print
|
||||
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -77,7 +75,8 @@ COMPACT_BANNER = """
|
||||
# Skills scanning
|
||||
# =========================================================================
|
||||
|
||||
def get_available_skills() -> Dict[str, List[str]]:
|
||||
|
||||
def get_available_skills() -> dict[str, list[str]]:
|
||||
"""Scan ~/.hermes/skills/ and return skills grouped by category."""
|
||||
import os
|
||||
|
||||
@@ -110,7 +109,7 @@ def get_available_skills() -> Dict[str, List[str]]:
|
||||
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
|
||||
|
||||
|
||||
def check_for_updates() -> Optional[int]:
|
||||
def check_for_updates() -> int | None:
|
||||
"""Check how many commits behind origin/main the local repo is.
|
||||
|
||||
Does a ``git fetch`` at most once every 6 hours (cached to
|
||||
@@ -139,7 +138,8 @@ def check_for_updates() -> Optional[int]:
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "fetch", "origin", "--quiet"],
|
||||
capture_output=True, timeout=10,
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
except Exception:
|
||||
@@ -149,7 +149,9 @@ def check_for_updates() -> Optional[int]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-list", "--count", "HEAD..origin/main"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
@@ -172,6 +174,7 @@ def check_for_updates() -> Optional[int]:
|
||||
# Welcome banner
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _format_context_length(tokens: int) -> str:
|
||||
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
|
||||
if tokens >= 1_000_000:
|
||||
@@ -183,12 +186,16 @@ def _format_context_length(tokens: int) -> str:
|
||||
return str(tokens)
|
||||
|
||||
|
||||
def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
tools: List[dict] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
session_id: str = None,
|
||||
get_toolset_for_tool=None,
|
||||
context_length: int = None):
|
||||
def build_welcome_banner(
|
||||
console: Console,
|
||||
model: str,
|
||||
cwd: str,
|
||||
tools: list[dict] = None,
|
||||
enabled_toolsets: list[str] = None,
|
||||
session_id: str = None,
|
||||
get_toolset_for_tool=None,
|
||||
context_length: int = None,
|
||||
):
|
||||
"""Build and print a welcome banner with caduceus on left and info on right.
|
||||
|
||||
Args:
|
||||
@@ -201,7 +208,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
get_toolset_for_tool: Callable to map tool name -> toolset name.
|
||||
context_length: Model's context window size in tokens.
|
||||
"""
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
from model_tools import check_tool_availability
|
||||
|
||||
if get_toolset_for_tool is None:
|
||||
from model_tools import get_toolset_for_tool
|
||||
|
||||
@@ -221,7 +229,9 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
model_short = model.split("/")[-1] if "/" in model else model
|
||||
if len(model_short) > 28:
|
||||
model_short = model_short[:25] + "..."
|
||||
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
ctx_str = (
|
||||
f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
|
||||
)
|
||||
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
|
||||
left_lines.append(f"[dim #B8860B]{cwd}[/]")
|
||||
if session_id:
|
||||
@@ -229,7 +239,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
left_content = "\n".join(left_lines)
|
||||
|
||||
right_lines = ["[bold #FFBF00]Available Tools[/]"]
|
||||
toolsets_dict: Dict[str, list] = {}
|
||||
toolsets_dict: dict[str, list] = {}
|
||||
|
||||
for tool in tools:
|
||||
tool_name = tool["function"]["name"]
|
||||
@@ -286,6 +296,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
# MCP Servers section (only if configured)
|
||||
try:
|
||||
from tools.mcp_tool import get_mcp_status
|
||||
|
||||
mcp_status = get_mcp_status()
|
||||
except Exception:
|
||||
mcp_status = []
|
||||
@@ -300,10 +311,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
f"[dim #B8860B]—[/] [#FFF8DC]{srv['tools']} tool(s)[/]"
|
||||
)
|
||||
else:
|
||||
right_lines.append(
|
||||
f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] "
|
||||
f"[red]— failed[/]"
|
||||
)
|
||||
right_lines.append(f"[red]{srv['name']}[/] [dim]({srv['transport']})[/] [red]— failed[/]")
|
||||
|
||||
right_lines.append("")
|
||||
right_lines.append("[bold #FFBF00]Available Skills[/]")
|
||||
|
||||
@@ -9,7 +9,7 @@ with the TUI.
|
||||
import queue
|
||||
import time as _time
|
||||
|
||||
from hermes_cli.banner import cprint, _DIM, _RST
|
||||
from hermes_cli.banner import _DIM, _RST, cprint
|
||||
|
||||
|
||||
def clarify_callback(cli, question, choices):
|
||||
@@ -33,7 +33,7 @@ def clarify_callback(cli, question, choices):
|
||||
cli._clarify_deadline = _time.monotonic() + timeout
|
||||
cli._clarify_freetext = is_open_ended
|
||||
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -45,13 +45,13 @@ def clarify_callback(cli, question, choices):
|
||||
remaining = cli._clarify_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._clarify_state = None
|
||||
cli._clarify_freetext = False
|
||||
cli._clarify_deadline = 0
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM}(clarify timed out after {timeout}s — agent will decide){_RST}")
|
||||
return (
|
||||
@@ -71,7 +71,7 @@ def sudo_password_callback(cli) -> str:
|
||||
cli._sudo_state = {"response_queue": response_queue}
|
||||
cli._sudo_deadline = _time.monotonic() + timeout
|
||||
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -79,7 +79,7 @@ def sudo_password_callback(cli) -> str:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._sudo_state = None
|
||||
cli._sudo_deadline = 0
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
if result:
|
||||
cprint(f"\n{_DIM} ✓ Password received (cached for session){_RST}")
|
||||
@@ -90,12 +90,12 @@ def sudo_password_callback(cli) -> str:
|
||||
remaining = cli._sudo_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._sudo_state = None
|
||||
cli._sudo_deadline = 0
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — continuing without sudo{_RST}")
|
||||
return ""
|
||||
@@ -119,7 +119,7 @@ def approval_callback(cli, command: str, description: str) -> str:
|
||||
}
|
||||
cli._approval_deadline = _time.monotonic() + timeout
|
||||
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
while True:
|
||||
@@ -127,19 +127,19 @@ def approval_callback(cli, command: str, description: str) -> str:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = cli._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, '_app') and cli._app:
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
|
||||
@@ -51,6 +51,7 @@ def has_clipboard_image() -> bool:
|
||||
|
||||
# ── macOS ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _macos_save(dest: Path) -> bool:
|
||||
"""Try pngpaste first (fast, handles more formats), fall back to osascript."""
|
||||
return _macos_pngpaste(dest) or _macos_osascript(dest)
|
||||
@@ -61,7 +62,9 @@ def _macos_has_image() -> bool:
|
||||
try:
|
||||
info = subprocess.run(
|
||||
["osascript", "-e", "clipboard info"],
|
||||
capture_output=True, text=True, timeout=3,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
|
||||
except Exception:
|
||||
@@ -73,7 +76,8 @@ def _macos_pngpaste(dest: Path) -> bool:
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["pngpaste", str(dest)],
|
||||
capture_output=True, timeout=3,
|
||||
capture_output=True,
|
||||
timeout=3,
|
||||
)
|
||||
if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
@@ -91,19 +95,21 @@ def _macos_osascript(dest: Path) -> bool:
|
||||
|
||||
# Extract as PNG
|
||||
script = (
|
||||
'try\n'
|
||||
' set imgData to the clipboard as «class PNGf»\n'
|
||||
"try\n"
|
||||
" set imgData to the clipboard as «class PNGf»\n"
|
||||
f' set f to open for access POSIX file "{dest}" with write permission\n'
|
||||
' write imgData to f\n'
|
||||
' close access f\n'
|
||||
'on error\n'
|
||||
" write imgData to f\n"
|
||||
" close access f\n"
|
||||
"on error\n"
|
||||
' return "fail"\n'
|
||||
'end try\n'
|
||||
"end try\n"
|
||||
)
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["osascript", "-e", script],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
@@ -114,13 +120,14 @@ def _macos_osascript(dest: Path) -> bool:
|
||||
|
||||
# ── Linux ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_wsl() -> bool:
|
||||
"""Detect if running inside WSL (1 or 2)."""
|
||||
global _wsl_detected
|
||||
if _wsl_detected is not None:
|
||||
return _wsl_detected
|
||||
try:
|
||||
with open("/proc/version", "r") as f:
|
||||
with open("/proc/version") as f:
|
||||
_wsl_detected = "microsoft" in f.read().lower()
|
||||
except Exception:
|
||||
_wsl_detected = False
|
||||
@@ -145,10 +152,7 @@ def _linux_save(dest: Path) -> bool:
|
||||
|
||||
# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
|
||||
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
|
||||
_PS_CHECK_IMAGE = (
|
||||
"Add-Type -AssemblyName System.Windows.Forms;"
|
||||
"[System.Windows.Forms.Clipboard]::ContainsImage()"
|
||||
)
|
||||
_PS_CHECK_IMAGE = "Add-Type -AssemblyName System.Windows.Forms;[System.Windows.Forms.Clipboard]::ContainsImage()"
|
||||
|
||||
_PS_EXTRACT_IMAGE = (
|
||||
"Add-Type -AssemblyName System.Windows.Forms;"
|
||||
@@ -165,9 +169,10 @@ def _wsl_has_image() -> bool:
|
||||
"""Check if Windows clipboard has an image (via powershell.exe)."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=8,
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=8,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except FileNotFoundError:
|
||||
@@ -181,9 +186,10 @@ def _wsl_save(dest: Path) -> bool:
|
||||
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
@@ -206,16 +212,17 @@ def _wsl_save(dest: Path) -> bool:
|
||||
|
||||
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _wayland_has_image() -> bool:
|
||||
"""Check if Wayland clipboard has image content."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["wl-paste", "--list-types"],
|
||||
capture_output=True, text=True, timeout=3,
|
||||
)
|
||||
return r.returncode == 0 and any(
|
||||
t.startswith("image/") for t in r.stdout.splitlines()
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
return r.returncode == 0 and any(t.startswith("image/") for t in r.stdout.splitlines())
|
||||
except FileNotFoundError:
|
||||
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
|
||||
except Exception:
|
||||
@@ -229,7 +236,9 @@ def _wayland_save(dest: Path) -> bool:
|
||||
# Check available MIME types
|
||||
types_r = subprocess.run(
|
||||
["wl-paste", "--list-types"],
|
||||
capture_output=True, text=True, timeout=3,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if types_r.returncode != 0:
|
||||
return False
|
||||
@@ -237,8 +246,7 @@ def _wayland_save(dest: Path) -> bool:
|
||||
|
||||
# Prefer PNG, fall back to other image formats
|
||||
mime = None
|
||||
for preferred in ("image/png", "image/jpeg", "image/bmp",
|
||||
"image/gif", "image/webp"):
|
||||
for preferred in ("image/png", "image/jpeg", "image/bmp", "image/gif", "image/webp"):
|
||||
if preferred in types:
|
||||
mime = preferred
|
||||
break
|
||||
@@ -250,7 +258,10 @@ def _wayland_save(dest: Path) -> bool:
|
||||
with open(dest, "wb") as f:
|
||||
subprocess.run(
|
||||
["wl-paste", "--type", mime],
|
||||
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
|
||||
stdout=f,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=5,
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not dest.exists() or dest.stat().st_size == 0:
|
||||
@@ -276,6 +287,7 @@ def _convert_to_png(path: Path) -> bool:
|
||||
# Try Pillow first (likely installed in the venv)
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(path)
|
||||
img.save(path, "PNG")
|
||||
return True
|
||||
@@ -290,7 +302,8 @@ def _convert_to_png(path: Path) -> bool:
|
||||
path.rename(tmp)
|
||||
r = subprocess.run(
|
||||
["convert", str(tmp), "png:" + str(path)],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
tmp.unlink(missing_ok=True)
|
||||
if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
|
||||
@@ -310,12 +323,15 @@ def _convert_to_png(path: Path) -> bool:
|
||||
|
||||
# ── X11 (xclip) ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _xclip_has_image() -> bool:
|
||||
"""Check if X11 clipboard has image content."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
|
||||
capture_output=True, text=True, timeout=3,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
return r.returncode == 0 and "image/png" in r.stdout
|
||||
except FileNotFoundError:
|
||||
@@ -331,7 +347,9 @@ def _xclip_save(dest: Path) -> bool:
|
||||
try:
|
||||
targets = subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
|
||||
capture_output=True, text=True, timeout=3,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=3,
|
||||
)
|
||||
if "image/png" not in targets.stdout:
|
||||
return False
|
||||
@@ -346,7 +364,10 @@ def _xclip_save(dest: Path) -> bool:
|
||||
with open(dest, "wb") as f:
|
||||
subprocess.run(
|
||||
["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
|
||||
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
|
||||
stdout=f,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=5,
|
||||
check=True,
|
||||
)
|
||||
if dest.exists() and dest.stat().st_size > 0:
|
||||
return True
|
||||
|
||||
@@ -4,14 +4,12 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CODEX_MODELS: List[str] = [
|
||||
DEFAULT_CODEX_MODELS: list[str] = [
|
||||
"gpt-5.3-codex",
|
||||
"gpt-5.2-codex",
|
||||
"gpt-5.1-codex-max",
|
||||
@@ -19,10 +17,11 @@ DEFAULT_CODEX_MODELS: List[str] = [
|
||||
]
|
||||
|
||||
|
||||
def _fetch_models_from_api(access_token: str) -> List[str]:
|
||||
def _fetch_models_from_api(access_token: str) -> list[str]:
|
||||
"""Fetch available models from the Codex API. Returns visible models sorted by priority."""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
resp = httpx.get(
|
||||
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
@@ -57,7 +56,7 @@ def _fetch_models_from_api(access_token: str) -> List[str]:
|
||||
return [slug for _, slug in sortable]
|
||||
|
||||
|
||||
def _read_default_model(codex_home: Path) -> Optional[str]:
|
||||
def _read_default_model(codex_home: Path) -> str | None:
|
||||
config_path = codex_home / "config.toml"
|
||||
if not config_path.exists():
|
||||
return None
|
||||
@@ -75,7 +74,7 @@ def _read_default_model(codex_home: Path) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _read_cache_models(codex_home: Path) -> List[str]:
|
||||
def _read_cache_models(codex_home: Path) -> list[str]:
|
||||
cache_path = codex_home / "models_cache.json"
|
||||
if not cache_path.exists():
|
||||
return []
|
||||
@@ -104,22 +103,22 @@ def _read_cache_models(codex_home: Path) -> List[str]:
|
||||
sortable.append((rank, slug))
|
||||
|
||||
sortable.sort(key=lambda item: (item[0], item[1]))
|
||||
deduped: List[str] = []
|
||||
deduped: list[str] = []
|
||||
for _, slug in sortable:
|
||||
if slug not in deduped:
|
||||
deduped.append(slug)
|
||||
return deduped
|
||||
|
||||
|
||||
def get_codex_model_ids(access_token: Optional[str] = None) -> List[str]:
|
||||
def get_codex_model_ids(access_token: str | None = None) -> list[str]:
|
||||
"""Return available Codex model IDs, trying API first, then local sources.
|
||||
|
||||
|
||||
Resolution order: API (live, if token provided) > config.toml default >
|
||||
local cache > hardcoded defaults.
|
||||
"""
|
||||
codex_home_str = os.getenv("CODEX_HOME", "").strip() or str(Path.home() / ".codex")
|
||||
codex_home = Path(codex_home_str).expanduser()
|
||||
ordered: List[str] = []
|
||||
ordered: list[str] = []
|
||||
|
||||
# Try live API if we have a token
|
||||
if access_token:
|
||||
|
||||
@@ -12,7 +12,6 @@ from typing import Any
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
|
||||
COMMANDS = {
|
||||
"/help": "Show this help message",
|
||||
"/tools": "List available tools",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,46 +20,46 @@ from hermes_cli.colors import Colors, color
|
||||
def cron_list(show_all: bool = False):
|
||||
"""List all scheduled jobs."""
|
||||
from cron.jobs import list_jobs
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=show_all)
|
||||
|
||||
|
||||
if not jobs:
|
||||
print(color("No scheduled jobs.", Colors.DIM))
|
||||
print(color("Create one with the /cron add command in chat, or via Telegram.", Colors.DIM))
|
||||
return
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ Scheduled Jobs │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
print()
|
||||
|
||||
|
||||
for job in jobs:
|
||||
job_id = job.get("id", "?")[:8]
|
||||
name = job.get("name", "(unnamed)")
|
||||
schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?"))
|
||||
enabled = job.get("enabled", True)
|
||||
next_run = job.get("next_run_at", "?")
|
||||
|
||||
|
||||
repeat_info = job.get("repeat", {})
|
||||
repeat_times = repeat_info.get("times")
|
||||
repeat_completed = repeat_info.get("completed", 0)
|
||||
|
||||
|
||||
if repeat_times:
|
||||
repeat_str = f"{repeat_completed}/{repeat_times}"
|
||||
else:
|
||||
repeat_str = "∞"
|
||||
|
||||
|
||||
deliver = job.get("deliver", ["local"])
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
deliver_str = ", ".join(deliver)
|
||||
|
||||
|
||||
if not enabled:
|
||||
status = color("[disabled]", Colors.RED)
|
||||
else:
|
||||
status = color("[active]", Colors.GREEN)
|
||||
|
||||
|
||||
print(f" {color(job_id, Colors.YELLOW)} {status}")
|
||||
print(f" Name: {name}")
|
||||
print(f" Schedule: {schedule}")
|
||||
@@ -67,9 +67,10 @@ def cron_list(show_all: bool = False):
|
||||
print(f" Next run: {next_run}")
|
||||
print(f" Deliver: {deliver_str}")
|
||||
print()
|
||||
|
||||
|
||||
# Warn if gateway isn't running
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
if not find_gateway_pids():
|
||||
print(color(" ⚠ Gateway is not running — jobs won't fire automatically.", Colors.YELLOW))
|
||||
print(color(" Start it with: hermes gateway install", Colors.DIM))
|
||||
@@ -79,6 +80,7 @@ def cron_list(show_all: bool = False):
|
||||
def cron_tick():
|
||||
"""Run due jobs once and exit."""
|
||||
from cron.scheduler import tick
|
||||
|
||||
tick(verbose=True)
|
||||
|
||||
|
||||
@@ -86,9 +88,9 @@ def cron_status():
|
||||
"""Show cron execution status."""
|
||||
from cron.jobs import list_jobs
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
print(color("✓ Gateway is running — cron jobs will fire automatically", Colors.GREEN))
|
||||
@@ -99,9 +101,9 @@ def cron_status():
|
||||
print(" To enable automatic execution:")
|
||||
print(" hermes gateway install # Install as system service (recommended)")
|
||||
print(" hermes gateway # Or run in foreground")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
jobs = list_jobs(include_disabled=False)
|
||||
if jobs:
|
||||
next_runs = [j.get("next_run_at") for j in jobs if j.get("next_run_at")]
|
||||
@@ -110,24 +112,24 @@ def cron_status():
|
||||
print(f" Next run: {min(next_runs)}")
|
||||
else:
|
||||
print(" No active jobs")
|
||||
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def cron_command(args):
|
||||
"""Handle cron subcommands."""
|
||||
subcmd = getattr(args, 'cron_command', None)
|
||||
|
||||
subcmd = getattr(args, "cron_command", None)
|
||||
|
||||
if subcmd is None or subcmd == "list":
|
||||
show_all = getattr(args, 'all', False)
|
||||
show_all = getattr(args, "all", False)
|
||||
cron_list(show_all)
|
||||
|
||||
|
||||
elif subcmd == "tick":
|
||||
cron_tick()
|
||||
|
||||
|
||||
elif subcmd == "status":
|
||||
cron_status()
|
||||
|
||||
|
||||
else:
|
||||
print(f"Unknown cron command: {subcmd}")
|
||||
print("Usage: hermes cron [list|status|tick]")
|
||||
|
||||
@@ -5,18 +5,18 @@ Diagnoses issues with Hermes Agent setup.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from hermes_cli.config import get_project_root, get_hermes_home, get_env_path
|
||||
from hermes_cli.config import get_env_path, get_hermes_home, get_project_root
|
||||
|
||||
PROJECT_ROOT = get_project_root()
|
||||
HERMES_HOME = get_hermes_home()
|
||||
|
||||
# Load environment variables from ~/.hermes/.env so API key checks work
|
||||
from dotenv import load_dotenv
|
||||
|
||||
_env_path = get_env_path()
|
||||
if _env_path.exists():
|
||||
try:
|
||||
@@ -33,7 +33,6 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
|
||||
_PROVIDER_ENV_HINTS = (
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
@@ -56,35 +55,38 @@ def _has_provider_env_config(content: str) -> bool:
|
||||
def check_ok(text: str, detail: str = ""):
|
||||
print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_warn(text: str, detail: str = ""):
|
||||
print(f" {color('⚠', Colors.YELLOW)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_fail(text: str, detail: str = ""):
|
||||
print(f" {color('✗', Colors.RED)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
|
||||
def check_info(text: str):
|
||||
print(f" {color('→', Colors.CYAN)} {text}")
|
||||
|
||||
|
||||
def run_doctor(args):
|
||||
"""Run diagnostic checks."""
|
||||
should_fix = getattr(args, 'fix', False)
|
||||
|
||||
should_fix = getattr(args, "fix", False)
|
||||
|
||||
issues = []
|
||||
manual_issues = [] # issues that can't be auto-fixed
|
||||
fixed_count = 0
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🩺 Hermes Doctor │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Python version
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
py_version = sys.version_info
|
||||
if py_version >= (3, 11):
|
||||
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
|
||||
@@ -96,20 +98,20 @@ def run_doctor(args):
|
||||
else:
|
||||
check_fail(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ required)")
|
||||
issues.append("Upgrade Python to 3.10+")
|
||||
|
||||
|
||||
# Check if in virtual environment
|
||||
in_venv = sys.prefix != sys.base_prefix
|
||||
if in_venv:
|
||||
check_ok("Virtual environment active")
|
||||
else:
|
||||
check_warn("Not in virtual environment", "(recommended)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Required packages
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Required Packages", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
required_packages = [
|
||||
("openai", "OpenAI SDK"),
|
||||
("rich", "Rich (terminal UI)"),
|
||||
@@ -117,13 +119,13 @@ def run_doctor(args):
|
||||
("yaml", "PyYAML"),
|
||||
("httpx", "HTTPX"),
|
||||
]
|
||||
|
||||
|
||||
optional_packages = [
|
||||
("croniter", "Croniter (cron expressions)"),
|
||||
("telegram", "python-telegram-bot"),
|
||||
("discord", "discord.py"),
|
||||
]
|
||||
|
||||
|
||||
for module, name in required_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
@@ -131,25 +133,25 @@ def run_doctor(args):
|
||||
except ImportError:
|
||||
check_fail(name, "(missing)")
|
||||
issues.append(f"Install {name}: uv pip install {module}")
|
||||
|
||||
|
||||
for module, name in optional_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
check_ok(name, "(optional)")
|
||||
except ImportError:
|
||||
check_warn(name, "(optional, not installed)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Configuration files
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Configuration Files", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Check ~/.hermes/.env (primary location for user config)
|
||||
env_path = HERMES_HOME / '.env'
|
||||
env_path = HERMES_HOME / ".env"
|
||||
if env_path.exists():
|
||||
check_ok("~/.hermes/.env file exists")
|
||||
|
||||
|
||||
# Check for common issues
|
||||
content = env_path.read_text()
|
||||
if _has_provider_env_config(content):
|
||||
@@ -159,7 +161,7 @@ def run_doctor(args):
|
||||
issues.append("Run 'hermes setup' to configure API keys")
|
||||
else:
|
||||
# Also check project root as fallback
|
||||
fallback_env = PROJECT_ROOT / '.env'
|
||||
fallback_env = PROJECT_ROOT / ".env"
|
||||
if fallback_env.exists():
|
||||
check_ok(".env file exists (in project directory)")
|
||||
else:
|
||||
@@ -173,17 +175,17 @@ def run_doctor(args):
|
||||
else:
|
||||
check_info("Run 'hermes setup' to create one")
|
||||
issues.append("Run 'hermes setup' to create .env")
|
||||
|
||||
|
||||
# Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback)
|
||||
config_path = HERMES_HOME / 'config.yaml'
|
||||
config_path = HERMES_HOME / "config.yaml"
|
||||
if config_path.exists():
|
||||
check_ok("~/.hermes/config.yaml exists")
|
||||
else:
|
||||
fallback_config = PROJECT_ROOT / 'cli-config.yaml'
|
||||
fallback_config = PROJECT_ROOT / "cli-config.yaml"
|
||||
if fallback_config.exists():
|
||||
check_ok("cli-config.yaml exists (in project directory)")
|
||||
else:
|
||||
example_config = PROJECT_ROOT / 'cli-config.yaml.example'
|
||||
example_config = PROJECT_ROOT / "cli-config.yaml.example"
|
||||
if should_fix and example_config.exists():
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(str(example_config), str(config_path))
|
||||
@@ -194,7 +196,7 @@ def run_doctor(args):
|
||||
manual_issues.append("Create ~/.hermes/config.yaml manually")
|
||||
else:
|
||||
check_warn("config.yaml not found", "(using defaults)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Auth providers
|
||||
# =========================================================================
|
||||
@@ -202,7 +204,7 @@ def run_doctor(args):
|
||||
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
|
||||
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
|
||||
|
||||
nous_status = get_nous_auth_status()
|
||||
if nous_status.get("logged_in"):
|
||||
@@ -230,7 +232,7 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Directory Structure", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
hermes_home = HERMES_HOME
|
||||
if hermes_home.exists():
|
||||
check_ok("~/.hermes directory exists")
|
||||
@@ -241,7 +243,7 @@ def run_doctor(args):
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn("~/.hermes not found", "(will be created on first use)")
|
||||
|
||||
|
||||
# Check expected subdirectories
|
||||
expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"]
|
||||
for subdir_name in expected_subdirs:
|
||||
@@ -255,7 +257,7 @@ def run_doctor(args):
|
||||
fixed_count += 1
|
||||
else:
|
||||
check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)")
|
||||
|
||||
|
||||
# Check for SOUL.md persona file
|
||||
soul_path = hermes_home / "SOUL.md"
|
||||
if soul_path.exists():
|
||||
@@ -278,7 +280,7 @@ def run_doctor(args):
|
||||
)
|
||||
check_ok("Created ~/.hermes/SOUL.md with basic template")
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
# Check memory directory
|
||||
memories_dir = hermes_home / "memories"
|
||||
if memories_dir.exists():
|
||||
@@ -301,12 +303,13 @@ def run_doctor(args):
|
||||
memories_dir.mkdir(parents=True, exist_ok=True)
|
||||
check_ok("Created ~/.hermes/memories/")
|
||||
fixed_count += 1
|
||||
|
||||
|
||||
# Check SQLite session store
|
||||
state_db_path = hermes_home / "state.db"
|
||||
if state_db_path.exists():
|
||||
try:
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect(str(state_db_path))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
count = cursor.fetchone()[0]
|
||||
@@ -316,26 +319,26 @@ def run_doctor(args):
|
||||
check_warn(f"~/.hermes/state.db exists but has issues: {e}")
|
||||
else:
|
||||
check_info("~/.hermes/state.db not created yet (will be created on first session)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: External tools
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ External Tools", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Git
|
||||
if shutil.which("git"):
|
||||
check_ok("git")
|
||||
else:
|
||||
check_warn("git not found", "(optional)")
|
||||
|
||||
|
||||
# ripgrep (optional, for faster file search)
|
||||
if shutil.which("rg"):
|
||||
check_ok("ripgrep (rg)", "(faster file search)")
|
||||
else:
|
||||
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
|
||||
check_info("Install for faster search: sudo apt install ripgrep")
|
||||
|
||||
|
||||
# Docker (optional)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
if terminal_env == "docker":
|
||||
@@ -355,7 +358,7 @@ def run_doctor(args):
|
||||
check_ok("docker", "(optional)")
|
||||
else:
|
||||
check_warn("docker not found", "(optional)")
|
||||
|
||||
|
||||
# SSH (if using ssh backend)
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST")
|
||||
@@ -364,7 +367,7 @@ def run_doctor(args):
|
||||
result = subprocess.run(
|
||||
["ssh", "-o", "ConnectTimeout=5", "-o", "BatchMode=yes", ssh_host, "echo ok"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
check_ok(f"SSH connection to {ssh_host}")
|
||||
@@ -374,7 +377,7 @@ def run_doctor(args):
|
||||
else:
|
||||
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
|
||||
issues.append("Set TERMINAL_SSH_HOST in .env")
|
||||
|
||||
|
||||
# Daytona (if using daytona backend)
|
||||
if terminal_env == "daytona":
|
||||
daytona_key = os.getenv("DAYTONA_API_KEY")
|
||||
@@ -385,6 +388,7 @@ def run_doctor(args):
|
||||
issues.append("Set DAYTONA_API_KEY environment variable")
|
||||
try:
|
||||
from daytona import Daytona
|
||||
|
||||
check_ok("daytona SDK", "(installed)")
|
||||
except ImportError:
|
||||
check_fail("daytona SDK not installed", "(pip install daytona)")
|
||||
@@ -401,7 +405,7 @@ def run_doctor(args):
|
||||
check_warn("agent-browser not installed", "(run: npm install)")
|
||||
else:
|
||||
check_warn("Node.js not found", "(optional, needed for browser tools)")
|
||||
|
||||
|
||||
# npm audit for all Node.js packages
|
||||
if shutil.which("npm"):
|
||||
npm_dirs = [
|
||||
@@ -415,9 +419,12 @@ def run_doctor(args):
|
||||
audit_result = subprocess.run(
|
||||
["npm", "audit", "--json"],
|
||||
cwd=str(npm_dir),
|
||||
capture_output=True, text=True, timeout=30,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
import json as _json
|
||||
|
||||
audit_data = _json.loads(audit_result.stdout) if audit_result.stdout.strip() else {}
|
||||
vuln_count = audit_data.get("metadata", {}).get("vulnerabilities", {})
|
||||
critical = vuln_count.get("critical", 0)
|
||||
@@ -429,7 +436,7 @@ def run_doctor(args):
|
||||
elif critical > 0 or high > 0:
|
||||
check_warn(
|
||||
f"{label} deps",
|
||||
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)"
|
||||
f"({critical} critical, {high} high, {moderate} moderate — run: cd {npm_dir} && npm audit fix)",
|
||||
)
|
||||
issues.append(f"{label} has {total} npm vulnerability(ies)")
|
||||
else:
|
||||
@@ -442,47 +449,50 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Connectivity", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if openrouter_key:
|
||||
print(" Checking OpenRouter API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
OPENROUTER_MODELS_URL,
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
print(f"\r {color('✓', Colors.GREEN)} OpenRouter API ")
|
||||
elif response.status_code == 401:
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} ")
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} OpenRouter API {color('(invalid API key)', Colors.DIM)} "
|
||||
)
|
||||
issues.append("Check OPENROUTER_API_KEY in .env")
|
||||
else:
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} ")
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'(HTTP {response.status_code})', Colors.DIM)} "
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\r {color('✗', Colors.RED)} OpenRouter API {color(f'({e})', Colors.DIM)} ")
|
||||
issues.append("Check network connectivity")
|
||||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
print(" Checking Anthropic API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": anthropic_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
},
|
||||
timeout=10
|
||||
headers={"x-api-key": anthropic_key, "anthropic-version": "2023-06-01"},
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
print(f"\r {color('✓', Colors.GREEN)} Anthropic API ")
|
||||
elif response.status_code == 401:
|
||||
print(f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} ")
|
||||
print(
|
||||
f"\r {color('✗', Colors.RED)} Anthropic API {color('(invalid API key)', Colors.DIM)} "
|
||||
)
|
||||
else:
|
||||
msg = "(couldn't verify)"
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
|
||||
@@ -491,10 +501,15 @@ def run_doctor(args):
|
||||
|
||||
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
|
||||
_apikey_providers = [
|
||||
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
|
||||
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
|
||||
(
|
||||
"Z.AI / GLM",
|
||||
("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"https://api.z.ai/api/paas/v4/models",
|
||||
"GLM_BASE_URL",
|
||||
),
|
||||
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
|
||||
]
|
||||
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
|
||||
_key = ""
|
||||
@@ -507,6 +522,7 @@ def run_doctor(args):
|
||||
print(f" Checking {_pname} API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
_base = os.getenv(_base_env, "")
|
||||
# Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
|
||||
if not _base and _key.startswith("sk-kimi-"):
|
||||
@@ -526,7 +542,9 @@ def run_doctor(args):
|
||||
print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
|
||||
issues.append(f"Check {_env_vars[0]} in .env")
|
||||
else:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
|
||||
print(
|
||||
f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} "
|
||||
)
|
||||
except Exception as _e:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
|
||||
|
||||
@@ -535,7 +553,7 @@ def run_doctor(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# mini-swe-agent (terminal tool backend)
|
||||
mini_swe_dir = PROJECT_ROOT / "mini-swe-agent"
|
||||
if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists():
|
||||
@@ -547,7 +565,7 @@ def run_doctor(args):
|
||||
issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent")
|
||||
else:
|
||||
check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
@@ -562,24 +580,24 @@ def run_doctor(args):
|
||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
||||
else:
|
||||
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Tool Availability
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Tool Availability", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
try:
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
|
||||
from model_tools import TOOLSET_REQUIREMENTS, check_tool_availability
|
||||
|
||||
available, unavailable = check_tool_availability()
|
||||
|
||||
|
||||
for tid in available:
|
||||
info = TOOLSET_REQUIREMENTS.get(tid, {})
|
||||
check_ok(info.get("name", tid))
|
||||
|
||||
|
||||
for item in unavailable:
|
||||
env_vars = item.get("missing_vars") or item.get("env_vars") or []
|
||||
if env_vars:
|
||||
@@ -594,7 +612,7 @@ def run_doctor(args):
|
||||
issues.append("Run 'hermes setup' to configure missing API keys for full tool access")
|
||||
except Exception as e:
|
||||
check_warn("Could not check tool availability", f"({e})")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Check: Skills Hub
|
||||
# =========================================================================
|
||||
@@ -608,6 +626,7 @@ def run_doctor(args):
|
||||
if lock_file.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
lock_data = json.loads(lock_file.read_text())
|
||||
count = len(lock_data.get("installed", {}))
|
||||
check_ok(f"Lock file OK ({count} hub-installed skill(s))")
|
||||
@@ -621,6 +640,7 @@ def run_doctor(args):
|
||||
check_warn("Skills Hub directory not initialized", "(run: hermes skills list)")
|
||||
|
||||
from hermes_cli.config import get_env_value
|
||||
|
||||
github_token = get_env_value("GITHUB_TOKEN") or get_env_value("GH_TOKEN")
|
||||
if github_token:
|
||||
check_ok("GitHub token configured (authenticated API access)")
|
||||
@@ -656,5 +676,5 @@ def run_doctor(args):
|
||||
else:
|
||||
print(color("─" * 60, Colors.GREEN))
|
||||
print(color(" All checks passed! 🎉", Colors.GREEN, Colors.BOLD))
|
||||
|
||||
|
||||
print()
|
||||
|
||||
@@ -13,18 +13,24 @@ from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import get_env_value, save_env_value
|
||||
from hermes_cli.setup import (
|
||||
print_header, print_info, print_success, print_warning, print_error,
|
||||
prompt, prompt_choice, prompt_yes_no,
|
||||
print_error,
|
||||
print_header,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
prompt,
|
||||
prompt_choice,
|
||||
prompt_yes_no,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def find_gateway_pids() -> list:
|
||||
"""Find PIDs of running gateway processes."""
|
||||
pids = []
|
||||
@@ -38,17 +44,16 @@ def find_gateway_pids() -> list:
|
||||
if is_windows():
|
||||
# Windows: use wmic to search command lines
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True
|
||||
)
|
||||
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("CommandLine="):
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
current_cmd = line[len("CommandLine=") :]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
pid_str = line[len("ProcessId=") :]
|
||||
if any(p in current_cmd for p in patterns):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
@@ -58,14 +63,10 @@ def find_gateway_pids() -> list:
|
||||
pass
|
||||
current_cmd = ""
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
result = subprocess.run(["ps", "aux"], capture_output=True, text=True)
|
||||
for line in result.stdout.split("\n"):
|
||||
# Skip grep and current process
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
if "grep" in line or str(os.getpid()) in line:
|
||||
continue
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
@@ -88,7 +89,7 @@ def kill_gateway_processes(force: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed."""
|
||||
pids = find_gateway_pids()
|
||||
killed = 0
|
||||
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
if force and not is_windows():
|
||||
@@ -101,18 +102,20 @@ def kill_gateway_processes(force: bool = False) -> int:
|
||||
pass
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied to kill PID {pid}")
|
||||
|
||||
|
||||
return killed
|
||||
|
||||
|
||||
def is_linux() -> bool:
|
||||
return sys.platform.startswith('linux')
|
||||
return sys.platform.startswith("linux")
|
||||
|
||||
|
||||
def is_macos() -> bool:
|
||||
return sys.platform == 'darwin'
|
||||
return sys.platform == "darwin"
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
return sys.platform == 'win32'
|
||||
return sys.platform == "win32"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -122,12 +125,15 @@ def is_windows() -> bool:
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
|
||||
def get_systemd_unit_path() -> Path:
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist"
|
||||
|
||||
|
||||
def get_python_path() -> str:
|
||||
if is_windows():
|
||||
venv_python = PROJECT_ROOT / "venv" / "Scripts" / "python.exe"
|
||||
@@ -137,14 +143,16 @@ def get_python_path() -> str:
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
|
||||
def get_hermes_cli_path() -> str:
|
||||
"""Get the path to the hermes CLI."""
|
||||
# Check if installed via pip
|
||||
import shutil
|
||||
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
return hermes_bin
|
||||
|
||||
|
||||
# Fallback to direct module execution
|
||||
return f"{get_python_path()} -m hermes_cli.main"
|
||||
|
||||
@@ -153,8 +161,10 @@ def get_hermes_cli_path() -> str:
|
||||
# Systemd (Linux)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
import shutil
|
||||
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
@@ -163,7 +173,7 @@ def generate_systemd_unit() -> str:
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
@@ -188,56 +198,62 @@ StandardError=journal
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
|
||||
def systemd_install(force: bool = False):
|
||||
unit_path = get_systemd_unit_path()
|
||||
|
||||
|
||||
if unit_path.exists() and not force:
|
||||
print(f"Service already installed at: {unit_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
|
||||
unit_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing systemd service to: {unit_path}")
|
||||
unit_path.write_text(generate_systemd_unit())
|
||||
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)
|
||||
|
||||
|
||||
print()
|
||||
print("✓ Service installed and enabled!")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(f" hermes gateway start # Start the service")
|
||||
print(f" hermes gateway status # Check status")
|
||||
print(" hermes gateway start # Start the service")
|
||||
print(" hermes gateway status # Check status")
|
||||
print(f" journalctl --user -u {SERVICE_NAME} -f # View logs")
|
||||
print()
|
||||
print("To enable lingering (keeps running after logout):")
|
||||
print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
|
||||
def systemd_uninstall():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)
|
||||
|
||||
|
||||
unit_path = get_systemd_unit_path()
|
||||
if unit_path.exists():
|
||||
unit_path.unlink()
|
||||
print(f"✓ Removed {unit_path}")
|
||||
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
|
||||
def systemd_start():
|
||||
subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
|
||||
def systemd_stop():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
|
||||
print("✓ Service stopped")
|
||||
|
||||
|
||||
def systemd_restart():
|
||||
subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
|
||||
print("✓ Service restarted")
|
||||
|
||||
|
||||
def systemd_status(deep: bool = False):
|
||||
# Check if service unit file exists
|
||||
unit_path = get_systemd_unit_path()
|
||||
@@ -245,54 +261,45 @@ def systemd_status(deep: bool = False):
|
||||
print("✗ Gateway service is not installed")
|
||||
print(" Run: hermes gateway install")
|
||||
return
|
||||
|
||||
|
||||
# Show detailed status first
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False
|
||||
)
|
||||
|
||||
subprocess.run(["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"], capture_output=False)
|
||||
|
||||
# Check if service is active
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True)
|
||||
|
||||
status = result.stdout.strip()
|
||||
|
||||
|
||||
if status == "active":
|
||||
print("✓ Gateway service is running")
|
||||
else:
|
||||
print("✗ Gateway service is stopped")
|
||||
print(" Run: hermes gateway start")
|
||||
|
||||
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run([
|
||||
"journalctl", "--user", "-u", SERVICE_NAME,
|
||||
"-n", "20", "--no-pager"
|
||||
])
|
||||
subprocess.run(["journalctl", "--user", "-u", SERVICE_NAME, "-n", "20", "--no-pager"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Launchd (macOS)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_launchd_plist() -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
log_dir = Path.home() / ".hermes" / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>ai.hermes.gateway</string>
|
||||
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{python_path}</string>
|
||||
@@ -301,42 +308,43 @@ def generate_launchd_plist() -> str:
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>{working_dir}</string>
|
||||
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>{log_dir}/gateway.log</string>
|
||||
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{log_dir}/gateway.error.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
"""
|
||||
|
||||
|
||||
def launchd_install(force: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
|
||||
if plist_path.exists() and not force:
|
||||
print(f"Service already installed at: {plist_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
|
||||
plist_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing launchd service to: {plist_path}")
|
||||
plist_path.write_text(generate_launchd_plist())
|
||||
|
||||
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
|
||||
|
||||
|
||||
print()
|
||||
print("✓ Service installed and loaded!")
|
||||
print()
|
||||
@@ -344,41 +352,42 @@ def launchd_install(force: bool = False):
|
||||
print(" hermes gateway status # Check status")
|
||||
print(" tail -f ~/.hermes/logs/gateway.log # View logs")
|
||||
|
||||
|
||||
def launchd_uninstall():
|
||||
plist_path = get_launchd_plist_path()
|
||||
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
|
||||
|
||||
|
||||
if plist_path.exists():
|
||||
plist_path.unlink()
|
||||
print(f"✓ Removed {plist_path}")
|
||||
|
||||
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
|
||||
def launchd_start():
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
|
||||
def launchd_stop():
|
||||
subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service stopped")
|
||||
|
||||
|
||||
def launchd_restart():
|
||||
launchd_stop()
|
||||
launchd_start()
|
||||
|
||||
|
||||
def launchd_status(deep: bool = False):
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("✓ Gateway service is loaded")
|
||||
print(result.stdout)
|
||||
else:
|
||||
print("✗ Gateway service is not loaded")
|
||||
|
||||
|
||||
if deep:
|
||||
log_file = Path.home() / ".hermes" / "logs" / "gateway.log"
|
||||
if log_file.exists():
|
||||
@@ -391,9 +400,10 @@ def launchd_status(deep: bool = False):
|
||||
# Gateway Runner
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
"""Run the gateway in foreground.
|
||||
|
||||
|
||||
Args:
|
||||
verbose: Enable verbose logging output.
|
||||
replace: If True, kill any existing gateway instance before starting.
|
||||
@@ -401,9 +411,9 @@ def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
hasn't fully exited yet.
|
||||
"""
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
|
||||
print("┌─────────────────────────────────────────────────────────┐")
|
||||
print("│ ⚕ Hermes Gateway Starting... │")
|
||||
print("├─────────────────────────────────────────────────────────┤")
|
||||
@@ -411,7 +421,7 @@ def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
print("│ Press Ctrl+C to stop │")
|
||||
print("└─────────────────────────────────────────────────────────┘")
|
||||
print()
|
||||
|
||||
|
||||
# Exit with code 1 if gateway fails to connect any platform,
|
||||
# so systemd Restart=on-failure will retry on transient errors
|
||||
success = asyncio.run(start_gateway(replace=replace))
|
||||
@@ -438,13 +448,25 @@ _PLATFORMS = [
|
||||
"4. To find your user ID: message @userinfobot — it replies with your numeric ID",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "TELEGRAM_BOT_TOKEN", "prompt": "Bot token", "password": True,
|
||||
"help": "Paste the token from @BotFather (step 3 above)."},
|
||||
{"name": "TELEGRAM_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 4 above."},
|
||||
{"name": "TELEGRAM_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat."},
|
||||
{
|
||||
"name": "TELEGRAM_BOT_TOKEN",
|
||||
"prompt": "Bot token",
|
||||
"password": True,
|
||||
"help": "Paste the token from @BotFather (step 3 above).",
|
||||
},
|
||||
{
|
||||
"name": "TELEGRAM_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 4 above.",
|
||||
},
|
||||
{
|
||||
"name": "TELEGRAM_HOME_CHANNEL",
|
||||
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
|
||||
"password": False,
|
||||
"help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -466,13 +488,25 @@ _PLATFORMS = [
|
||||
" then right-click your name → Copy ID",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "DISCORD_BOT_TOKEN", "prompt": "Bot token", "password": True,
|
||||
"help": "Paste the token from step 2 above."},
|
||||
{"name": "DISCORD_ALLOWED_USERS", "prompt": "Allowed user IDs or usernames (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 5 above."},
|
||||
{"name": "DISCORD_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
|
||||
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode)."},
|
||||
{
|
||||
"name": "DISCORD_BOT_TOKEN",
|
||||
"prompt": "Bot token",
|
||||
"password": True,
|
||||
"help": "Paste the token from step 2 above.",
|
||||
},
|
||||
{
|
||||
"name": "DISCORD_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs or usernames (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your user ID from step 5 above.",
|
||||
},
|
||||
{
|
||||
"name": "DISCORD_HOME_CHANNEL",
|
||||
"prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)",
|
||||
"password": False,
|
||||
"help": "Right-click a channel → Copy Channel ID (requires Developer Mode).",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -497,13 +531,25 @@ _PLATFORMS = [
|
||||
"8. Invite the bot to channels: /invite @YourBot",
|
||||
],
|
||||
"vars": [
|
||||
{"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True,
|
||||
"help": "Paste the bot token from step 3 above."},
|
||||
{"name": "SLACK_APP_TOKEN", "prompt": "App Token (xapp-...)", "password": True,
|
||||
"help": "Paste the app-level token from step 4 above."},
|
||||
{"name": "SLACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your member ID from step 7 above."},
|
||||
{
|
||||
"name": "SLACK_BOT_TOKEN",
|
||||
"prompt": "Bot Token (xoxb-...)",
|
||||
"password": True,
|
||||
"help": "Paste the bot token from step 3 above.",
|
||||
},
|
||||
{
|
||||
"name": "SLACK_APP_TOKEN",
|
||||
"prompt": "App Token (xapp-...)",
|
||||
"password": True,
|
||||
"help": "Paste the app-level token from step 4 above.",
|
||||
},
|
||||
{
|
||||
"name": "SLACK_ALLOWED_USERS",
|
||||
"prompt": "Allowed user IDs (comma-separated)",
|
||||
"password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Paste your member ID from step 7 above.",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
@@ -582,14 +628,14 @@ def _setup_standard_platform(platform: dict):
|
||||
|
||||
# Allowlist fields get special handling for the deny-by-default security model
|
||||
if var.get("is_allowlist"):
|
||||
print_info(f" The gateway DENIES all users by default for security.")
|
||||
print_info(f" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(f" and you'll be asked about open access next.")
|
||||
print_info(" The gateway DENIES all users by default for security.")
|
||||
print_info(" Enter user IDs to create an allowlist, or leave empty")
|
||||
print_info(" and you'll be asked about open access next.")
|
||||
value = prompt(f" {var['prompt']}", password=False)
|
||||
if value:
|
||||
cleaned = value.replace(" ", "")
|
||||
save_env_value(var["name"], cleaned)
|
||||
print_success(f" Saved — only these users can interact with the bot.")
|
||||
print_success(" Saved — only these users can interact with the bot.")
|
||||
allowed_val_set = cleaned
|
||||
else:
|
||||
# No allowlist — ask about open access vs DM pairing
|
||||
@@ -618,7 +664,7 @@ def _setup_standard_platform(platform: dict):
|
||||
print_warning(f" Skipped — {label} won't work without this.")
|
||||
return
|
||||
else:
|
||||
print_info(f" Skipped (can configure later)")
|
||||
print_info(" Skipped (can configure later)")
|
||||
|
||||
# If an allowlist was set and home channel wasn't, offer to reuse
|
||||
# the first user ID (common for Telegram DMs).
|
||||
@@ -636,8 +682,10 @@ def _setup_standard_platform(platform: dict):
|
||||
|
||||
def _setup_whatsapp():
|
||||
"""Delegate to the existing WhatsApp setup flow."""
|
||||
from hermes_cli.main import cmd_whatsapp
|
||||
import argparse
|
||||
|
||||
from hermes_cli.main import cmd_whatsapp
|
||||
|
||||
cmd_whatsapp(argparse.Namespace())
|
||||
|
||||
|
||||
@@ -653,16 +701,10 @@ def _is_service_installed() -> bool:
|
||||
def _is_service_running() -> bool:
|
||||
"""Check if the gateway service is currently running."""
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
result = subprocess.run(["systemctl", "--user", "is-active", SERVICE_NAME], capture_output=True, text=True)
|
||||
return result.stdout.strip() == "active"
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
|
||||
return result.returncode == 0
|
||||
# Check for manual processes
|
||||
return len(find_gateway_pids()) > 0
|
||||
@@ -697,7 +739,7 @@ def _setup_signal():
|
||||
print_info(" Docker: bbernhard/signal-cli-rest-api")
|
||||
print()
|
||||
print_info(" After installing, link your account and start the daemon:")
|
||||
print_info(" signal-cli link -n \"HermesAgent\"")
|
||||
print_info(' signal-cli link -n "HermesAgent"')
|
||||
print_info(" signal-cli --account +YOURNUMBER daemon --http 127.0.0.1:8080")
|
||||
print()
|
||||
|
||||
@@ -715,6 +757,7 @@ def _setup_signal():
|
||||
print_info(" Testing connection...")
|
||||
try:
|
||||
import httpx
|
||||
|
||||
resp = httpx.get(f"{url.rstrip('/')}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
print_success(" signal-cli daemon is reachable!")
|
||||
@@ -779,7 +822,7 @@ def _setup_signal():
|
||||
print_success("Signal configured!")
|
||||
print_info(f" URL: {url}")
|
||||
print_info(f" Account: {account}")
|
||||
print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}")
|
||||
|
||||
|
||||
@@ -841,11 +884,10 @@ def gateway_setup():
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
# ── Post-setup: offer to install/restart gateway ──
|
||||
any_configured = any(
|
||||
bool(get_env_value(p["token_var"]))
|
||||
for p in _PLATFORMS
|
||||
if p["key"] != "whatsapp"
|
||||
) or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
|
||||
any_configured = (
|
||||
any(bool(get_env_value(p["token_var"])) for p in _PLATFORMS if p["key"] != "whatsapp")
|
||||
or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true"
|
||||
)
|
||||
|
||||
if any_configured:
|
||||
print()
|
||||
@@ -878,7 +920,9 @@ def gateway_setup():
|
||||
print()
|
||||
if is_linux() or is_macos():
|
||||
platform_name = "systemd" if is_linux() else "launchd"
|
||||
if prompt_yes_no(f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True):
|
||||
if prompt_yes_no(
|
||||
f" Install the gateway as a {platform_name} service? (runs in background, starts on boot)", True
|
||||
):
|
||||
try:
|
||||
force = False
|
||||
if is_linux():
|
||||
@@ -914,14 +958,15 @@ def gateway_setup():
|
||||
# Main Command Handler
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def gateway_command(args):
|
||||
"""Handle gateway subcommands."""
|
||||
subcmd = getattr(args, 'gateway_command', None)
|
||||
|
||||
subcmd = getattr(args, "gateway_command", None)
|
||||
|
||||
# Default to run if no subcommand
|
||||
if subcmd is None or subcmd == "run":
|
||||
verbose = getattr(args, 'verbose', False)
|
||||
replace = getattr(args, 'replace', False)
|
||||
verbose = getattr(args, "verbose", False)
|
||||
replace = getattr(args, "replace", False)
|
||||
run_gateway(verbose, replace=replace)
|
||||
return
|
||||
|
||||
@@ -931,7 +976,7 @@ def gateway_command(args):
|
||||
|
||||
# Service management commands
|
||||
if subcmd == "install":
|
||||
force = getattr(args, 'force', False)
|
||||
force = getattr(args, "force", False)
|
||||
if is_linux():
|
||||
systemd_install(force)
|
||||
elif is_macos():
|
||||
@@ -940,7 +985,7 @@ def gateway_command(args):
|
||||
print("Service installation not supported on this platform.")
|
||||
print("Run manually: hermes gateway run")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "uninstall":
|
||||
if is_linux():
|
||||
systemd_uninstall()
|
||||
@@ -949,7 +994,7 @@ def gateway_command(args):
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "start":
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
@@ -958,11 +1003,11 @@ def gateway_command(args):
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
elif subcmd == "stop":
|
||||
# Try service first, fall back to killing processes directly
|
||||
service_available = False
|
||||
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_stop()
|
||||
@@ -975,7 +1020,7 @@ def gateway_command(args):
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
if not service_available:
|
||||
# Kill gateway processes directly
|
||||
killed = kill_gateway_processes()
|
||||
@@ -983,11 +1028,11 @@ def gateway_command(args):
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
else:
|
||||
print("✗ No gateway processes found")
|
||||
|
||||
|
||||
elif subcmd == "restart":
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_restart()
|
||||
@@ -1000,23 +1045,24 @@ def gateway_command(args):
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
|
||||
if not service_available:
|
||||
# Manual restart: kill existing processes
|
||||
killed = kill_gateway_processes()
|
||||
if killed:
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
|
||||
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
# Start fresh
|
||||
print("Starting gateway...")
|
||||
run_gateway(verbose=False)
|
||||
|
||||
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
deep = getattr(args, "deep", False)
|
||||
|
||||
# Check for service first
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
systemd_status(deep)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,26 +8,26 @@ Add, remove, or reorder entries here — both `hermes setup` and
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from difflib import get_close_matches
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
("anthropic/claude-sonnet-4.5", ""),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
]
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
@@ -93,9 +93,7 @@ def menu_labels() -> list[str]:
|
||||
|
||||
# All provider IDs and aliases that are valid for the provider:model syntax.
|
||||
_KNOWN_PROVIDER_NAMES: set[str] = (
|
||||
set(_PROVIDER_LABELS.keys())
|
||||
| set(_PROVIDER_ALIASES.keys())
|
||||
| {"openrouter", "custom"}
|
||||
set(_PROVIDER_LABELS.keys()) | set(_PROVIDER_ALIASES.keys()) | {"openrouter", "custom"}
|
||||
)
|
||||
|
||||
|
||||
@@ -107,8 +105,13 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
"""
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter", "nous", "openai-codex",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn",
|
||||
"openrouter",
|
||||
"nous",
|
||||
"openai-codex",
|
||||
"zai",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
]
|
||||
# Build reverse alias map
|
||||
aliases_for: dict[str, list[str]] = {}
|
||||
@@ -123,16 +126,19 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append({
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"aliases": alias_list,
|
||||
"authenticated": has_creds,
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"aliases": alias_list,
|
||||
"authenticated": has_creds,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -157,13 +163,13 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
|
||||
colon = stripped.find(":")
|
||||
if colon > 0:
|
||||
provider_part = stripped[:colon].strip().lower()
|
||||
model_part = stripped[colon + 1:].strip()
|
||||
model_part = stripped[colon + 1 :].strip()
|
||||
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
|
||||
return (normalize_provider(provider_part), model_part)
|
||||
return (current_provider, stripped)
|
||||
|
||||
|
||||
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
|
||||
def curated_models_for_provider(provider: str | None) -> list[tuple[str, str]]:
|
||||
"""Return ``(model_id, description)`` tuples for a provider's curated list."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
@@ -172,7 +178,7 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def normalize_provider(provider: Optional[str]) -> str:
|
||||
def normalize_provider(provider: str | None) -> str:
|
||||
"""Normalize provider aliases to Hermes' canonical provider ids.
|
||||
|
||||
Note: ``"auto"`` passes through unchanged — use
|
||||
@@ -183,7 +189,7 @@ def normalize_provider(provider: Optional[str]) -> str:
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
def provider_model_ids(provider: str | None) -> list[str]:
|
||||
"""Return the best known model catalog for a provider."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
@@ -196,10 +202,10 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
|
||||
|
||||
def fetch_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
api_key: str | None,
|
||||
base_url: str | None,
|
||||
timeout: float = 5.0,
|
||||
) -> Optional[list[str]]:
|
||||
) -> list[str] | None:
|
||||
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
|
||||
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
@@ -225,10 +231,10 @@ def fetch_api_models(
|
||||
|
||||
def validate_requested_model(
|
||||
model_name: str,
|
||||
provider: Optional[str],
|
||||
provider: str | None,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate a ``/model`` value for the active provider.
|
||||
@@ -286,10 +292,7 @@ def validate_requested_model(
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Error: `{requested}` is not a valid model for this provider."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
"message": (f"Error: `{requested}` is not a valid model for this provider.{suggestion_text}"),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API, fall back to catalog check
|
||||
|
||||
@@ -8,6 +8,7 @@ Usage:
|
||||
hermes pairing clear-pending # Clear all expired/pending codes
|
||||
"""
|
||||
|
||||
|
||||
def pairing_command(args):
|
||||
"""Handle hermes pairing subcommands."""
|
||||
from gateway.pairing import PairingStore
|
||||
@@ -72,10 +73,10 @@ def _cmd_approve(store, platform: str, code: str):
|
||||
name = result.get("user_name", "")
|
||||
display = f"{name} ({uid})" if name else uid
|
||||
print(f"\n Approved! User {display} on {platform} can now use the bot~")
|
||||
print(f" They'll be recognized automatically on their next message.\n")
|
||||
print(" They'll be recognized automatically on their next message.\n")
|
||||
else:
|
||||
print(f"\n Code '{code}' not found or expired for platform '{platform}'.")
|
||||
print(f" Run 'hermes pairing list' to see pending codes.\n")
|
||||
print(" Run 'hermes pairing list' to see pending codes.\n")
|
||||
|
||||
|
||||
def _cmd_revoke(store, platform: str, user_id: str):
|
||||
|
||||
@@ -3,22 +3,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
PROVIDER_REGISTRY,
|
||||
AuthError,
|
||||
format_auth_error,
|
||||
resolve_provider,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_provider,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
def _get_model_config() -> Dict[str, Any]:
|
||||
def _get_model_config() -> dict[str, Any]:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
@@ -28,7 +28,7 @@ def _get_model_config() -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||
def resolve_requested_provider(requested: str | None = None) -> str:
|
||||
"""Resolve provider request from explicit arg, env, then config."""
|
||||
if requested and requested.strip():
|
||||
return requested.strip().lower()
|
||||
@@ -48,9 +48,9 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||
def _resolve_openrouter_runtime(
|
||||
*,
|
||||
requested_provider: str,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = model_cfg.get("base_url") if isinstance(model_cfg.get("base_url"), str) else ""
|
||||
cfg_provider = model_cfg.get("provider") if isinstance(model_cfg.get("provider"), str) else ""
|
||||
@@ -81,19 +81,9 @@ def _resolve_openrouter_runtime(
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key = explicit_api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
||||
else:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key = explicit_api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") or ""
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
@@ -108,10 +98,10 @@ def _resolve_openrouter_runtime(
|
||||
|
||||
def resolve_runtime_provider(
|
||||
*,
|
||||
requested: Optional[str] = None,
|
||||
explicit_api_key: Optional[str] = None,
|
||||
explicit_base_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
requested: str | None = None,
|
||||
explicit_api_key: str | None = None,
|
||||
explicit_base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve runtime provider credentials for agent execution."""
|
||||
requested_provider = resolve_requested_provider(requested)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,6 @@ handler are thin wrappers that parse args and delegate.
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -29,6 +28,7 @@ _console = Console()
|
||||
# Shared do_* functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
"""
|
||||
Resolve a short skill name (e.g. 'pptx') to a full identifier by searching
|
||||
@@ -57,7 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
table.add_column("Trust", style="dim")
|
||||
table.add_column("Identifier", style="bold cyan")
|
||||
for r in exact:
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(
|
||||
r.trust_level, "dim"
|
||||
)
|
||||
trust_label = "official" if r.source == "official" else r.trust_level
|
||||
table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
|
||||
c.print(table)
|
||||
@@ -76,8 +78,7 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
console: Optional[Console] = None) -> None:
|
||||
def do_search(query: str, source: str = "all", limit: int = 10, console: Console | None = None) -> None:
|
||||
"""Search registries and display results as a Rich table."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router, unified_search
|
||||
|
||||
@@ -111,18 +112,19 @@ def do_search(query: str, source: str = "all", limit: int = 10,
|
||||
)
|
||||
|
||||
c.print(table)
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
c.print(
|
||||
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
|
||||
)
|
||||
|
||||
|
||||
def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
console: Optional[Console] = None) -> None:
|
||||
def do_browse(page: int = 1, page_size: int = 20, source: str = "all", console: Console | None = None) -> None:
|
||||
"""Browse all available skills across registries, paginated.
|
||||
|
||||
Official skills are always shown first, regardless of source filter.
|
||||
"""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
|
||||
GitHubAuth,
|
||||
create_source_router,
|
||||
)
|
||||
|
||||
# Clamp page_size to safe range
|
||||
@@ -136,8 +138,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
# Collect results from all (or filtered) sources
|
||||
# Use empty query to get everything; per-source limits prevent overload
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50, "claude-marketplace": 50, "lobehub": 50}
|
||||
|
||||
all_results: list = []
|
||||
source_counts: dict = {}
|
||||
@@ -168,11 +169,13 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
deduped = list(seen.values())
|
||||
|
||||
# Sort: official first, then by trust level (desc), then alphabetically
|
||||
deduped.sort(key=lambda r: (
|
||||
-_TRUST_RANK.get(r.trust_level, 0),
|
||||
r.source != "official",
|
||||
r.name.lower(),
|
||||
))
|
||||
deduped.sort(
|
||||
key=lambda r: (
|
||||
-_TRUST_RANK.get(r.trust_level, 0),
|
||||
r.source != "official",
|
||||
r.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
# Paginate
|
||||
total = len(deduped)
|
||||
@@ -187,8 +190,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
|
||||
# Build header
|
||||
source_label = f"— {source}" if source != "all" else "— all sources"
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
|
||||
f" [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/] [dim]({total} skills, page {page}/{total_pages})[/]")
|
||||
if official_count > 0 and page == 1:
|
||||
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
|
||||
c.print()
|
||||
@@ -202,8 +204,7 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
table.add_column("Trust", width=10)
|
||||
|
||||
for i, r in enumerate(page_items, start=start + 1):
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green",
|
||||
"community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
|
||||
trust_label = "★ official" if r.source == "official" else r.trust_level
|
||||
|
||||
desc = r.description[:50]
|
||||
@@ -235,18 +236,22 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
||||
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
|
||||
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
|
||||
|
||||
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
|
||||
"hermes skills install <identifier> to install[/]\n")
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
console: Optional[Console] = None) -> None:
|
||||
"""Fetch, quarantine, scan, confirm, and install a skill."""
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth, create_source_router, ensure_hub_dirs,
|
||||
quarantine_bundle, install_from_quarantine, HubLockFile,
|
||||
c.print(
|
||||
"[dim]Use: hermes skills inspect <identifier> to preview, hermes skills install <identifier> to install[/]\n"
|
||||
)
|
||||
|
||||
|
||||
def do_install(identifier: str, category: str = "", force: bool = False, console: Console | None = None) -> None:
|
||||
"""Fetch, quarantine, scan, confirm, and install a skill."""
|
||||
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
|
||||
from tools.skills_hub import (
|
||||
GitHubAuth,
|
||||
HubLockFile,
|
||||
create_source_router,
|
||||
ensure_hub_dirs,
|
||||
install_from_quarantine,
|
||||
quarantine_bundle,
|
||||
)
|
||||
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
|
||||
|
||||
c = console or _console
|
||||
ensure_hub_dirs()
|
||||
@@ -304,33 +309,43 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
# Clean up quarantine
|
||||
shutil.rmtree(q_path, ignore_errors=True)
|
||||
from tools.skills_hub import append_audit_log
|
||||
append_audit_log("BLOCKED", bundle.name, bundle.source,
|
||||
bundle.trust_level, result.verdict,
|
||||
f"{len(result.findings)}_findings")
|
||||
|
||||
append_audit_log(
|
||||
"BLOCKED",
|
||||
bundle.name,
|
||||
bundle.source,
|
||||
bundle.trust_level,
|
||||
result.verdict,
|
||||
f"{len(result.findings)}_findings",
|
||||
)
|
||||
return
|
||||
|
||||
# Confirm with user — show appropriate warning based on source
|
||||
if not force:
|
||||
c.print()
|
||||
if bundle.source == "official":
|
||||
c.print(Panel(
|
||||
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
|
||||
"It ships with hermes-agent but is not activated by default.\n"
|
||||
"Installing will copy it to your skills directory where the agent can use it.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Official Skill",
|
||||
border_style="bright_cyan",
|
||||
))
|
||||
c.print(
|
||||
Panel(
|
||||
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
|
||||
"It ships with hermes-agent but is not activated by default.\n"
|
||||
"Installing will copy it to your skills directory where the agent can use it.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Official Skill",
|
||||
border_style="bright_cyan",
|
||||
)
|
||||
)
|
||||
else:
|
||||
c.print(Panel(
|
||||
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
|
||||
"External skills can contain instructions that influence agent behavior,\n"
|
||||
"shell commands, and scripts. Even after automated scanning, you should\n"
|
||||
"review the installed files before use.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Disclaimer",
|
||||
border_style="yellow",
|
||||
))
|
||||
c.print(
|
||||
Panel(
|
||||
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
|
||||
"External skills can contain instructions that influence agent behavior,\n"
|
||||
"shell commands, and scripts. Even after automated scanning, you should\n"
|
||||
"review the installed files before use.\n\n"
|
||||
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
|
||||
title="Disclaimer",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
c.print(f"[bold]Install '{bundle.name}'?[/]")
|
||||
try:
|
||||
answer = input("Confirm [y/N]: ").strip().lower()
|
||||
@@ -344,11 +359,12 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
# Install
|
||||
install_dir = install_from_quarantine(q_path, bundle.name, category, bundle, result)
|
||||
from tools.skills_hub import SKILLS_DIR
|
||||
|
||||
c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}")
|
||||
c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n")
|
||||
|
||||
|
||||
def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
def do_inspect(identifier: str, console: Console | None = None) -> None:
|
||||
"""Preview a skill's SKILL.md content without installing."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
@@ -406,7 +422,7 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
c.print()
|
||||
|
||||
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
def do_list(source_filter: str = "all", console: Console | None = None) -> None:
|
||||
"""List installed skills, distinguishing builtins from hub-installed."""
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
from tools.skills_tool import _find_all_skills
|
||||
@@ -446,14 +462,13 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No
|
||||
table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")
|
||||
|
||||
c.print(table)
|
||||
c.print(f"[dim]{len(hub_installed)} hub-installed, "
|
||||
f"{len(all_skills) - len(hub_installed)} builtin[/]\n")
|
||||
c.print(f"[dim]{len(hub_installed)} hub-installed, {len(all_skills) - len(hub_installed)} builtin[/]\n")
|
||||
|
||||
|
||||
def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> None:
|
||||
def do_audit(name: str | None = None, console: Console | None = None) -> None:
|
||||
"""Re-run security scan on installed hub skills."""
|
||||
from tools.skills_hub import HubLockFile, SKILLS_DIR
|
||||
from tools.skills_guard import scan_skill, format_scan_report
|
||||
from tools.skills_guard import format_scan_report, scan_skill
|
||||
from tools.skills_hub import SKILLS_DIR, HubLockFile
|
||||
|
||||
c = console or _console
|
||||
lock = HubLockFile()
|
||||
@@ -483,7 +498,7 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N
|
||||
c.print()
|
||||
|
||||
|
||||
def do_uninstall(name: str, console: Optional[Console] = None) -> None:
|
||||
def do_uninstall(name: str, console: Console | None = None) -> None:
|
||||
"""Remove a hub-installed skill with confirmation."""
|
||||
from tools.skills_hub import uninstall_skill
|
||||
|
||||
@@ -505,7 +520,7 @@ def do_uninstall(name: str, console: Optional[Console] = None) -> None:
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
|
||||
def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> None:
|
||||
def do_tap(action: str, repo: str = "", console: Console | None = None) -> None:
|
||||
"""Manage taps (custom GitHub repo sources)."""
|
||||
from tools.skills_hub import TapsManager
|
||||
|
||||
@@ -547,11 +562,10 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No
|
||||
c.print(f"[bold red]Unknown tap action:[/] {action}. Use: list, add, remove\n")
|
||||
|
||||
|
||||
def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
console: Optional[Console] = None) -> None:
|
||||
def do_publish(skill_path: str, target: str = "github", repo: str = "", console: Console | None = None) -> None:
|
||||
"""Publish a local skill to a registry (GitHub PR or ClawHub submission)."""
|
||||
from tools.skills_hub import GitHubAuth, SKILLS_DIR
|
||||
from tools.skills_guard import scan_skill, format_scan_report
|
||||
from tools.skills_guard import format_scan_report, scan_skill
|
||||
from tools.skills_hub import SKILLS_DIR, GitHubAuth
|
||||
|
||||
c = console or _console
|
||||
path = Path(skill_path)
|
||||
@@ -565,14 +579,16 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
|
||||
# Validate the skill
|
||||
import yaml
|
||||
|
||||
skill_md = (path / "SKILL.md").read_text(encoding="utf-8")
|
||||
fm = {}
|
||||
if skill_md.startswith("---"):
|
||||
import re
|
||||
match = re.search(r'\n---\s*\n', skill_md[3:])
|
||||
|
||||
match = re.search(r"\n---\s*\n", skill_md[3:])
|
||||
if match:
|
||||
try:
|
||||
fm = yaml.safe_load(skill_md[3:match.start() + 3]) or {}
|
||||
fm = yaml.safe_load(skill_md[3 : match.start() + 3]) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
|
||||
@@ -592,14 +608,18 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
|
||||
if target == "github":
|
||||
if not repo:
|
||||
c.print("[bold red]Error:[/] --repo required for GitHub publish.\n"
|
||||
"Usage: hermes skills publish <path> --to github --repo owner/repo\n")
|
||||
c.print(
|
||||
"[bold red]Error:[/] --repo required for GitHub publish.\n"
|
||||
"Usage: hermes skills publish <path> --to github --repo owner/repo\n"
|
||||
)
|
||||
return
|
||||
|
||||
auth = GitHubAuth()
|
||||
if not auth.is_authenticated():
|
||||
c.print("[bold red]Error:[/] GitHub authentication required.\n"
|
||||
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n")
|
||||
c.print(
|
||||
"[bold red]Error:[/] GitHub authentication required.\n"
|
||||
"Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n"
|
||||
)
|
||||
return
|
||||
|
||||
c.print(f"[bold]Publishing '{name}' to {repo}...[/]")
|
||||
@@ -610,14 +630,12 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "",
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
elif target == "clawhub":
|
||||
c.print("[yellow]ClawHub publishing is not yet supported. "
|
||||
"Submit manually at https://clawhub.ai/submit[/]\n")
|
||||
c.print("[yellow]ClawHub publishing is not yet supported. Submit manually at https://clawhub.ai/submit[/]\n")
|
||||
else:
|
||||
c.print(f"[bold red]Unknown target:[/] {target}. Use 'github' or 'clawhub'.\n")
|
||||
|
||||
|
||||
def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
auth) -> tuple:
|
||||
def _github_publish(skill_path: Path, skill_name: str, target_repo: str, auth) -> tuple:
|
||||
"""Create a PR to a GitHub repo with the skill. Returns (success, message)."""
|
||||
import httpx
|
||||
|
||||
@@ -627,7 +645,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"https://api.github.com/repos/{target_repo}/forks",
|
||||
headers=headers, timeout=30,
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 202):
|
||||
fork = resp.json()
|
||||
@@ -643,7 +662,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"https://api.github.com/repos/{target_repo}",
|
||||
headers=headers, timeout=15,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
)
|
||||
default_branch = resp.json().get("default_branch", "main")
|
||||
except Exception:
|
||||
@@ -653,7 +673,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
try:
|
||||
resp = httpx.get(
|
||||
f"https://api.github.com/repos/{fork_repo}/git/refs/heads/{default_branch}",
|
||||
headers=headers, timeout=15,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
)
|
||||
base_sha = resp.json()["object"]["sha"]
|
||||
except Exception as e:
|
||||
@@ -664,7 +685,8 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
try:
|
||||
httpx.post(
|
||||
f"https://api.github.com/repos/{fork_repo}/git/refs",
|
||||
headers=headers, timeout=15,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
json={"ref": f"refs/heads/{branch_name}", "sha": base_sha},
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -678,10 +700,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
upload_path = f"skills/{skill_name}/{rel}"
|
||||
try:
|
||||
import base64
|
||||
|
||||
content_b64 = base64.b64encode(f.read_bytes()).decode()
|
||||
httpx.put(
|
||||
f"https://api.github.com/repos/{fork_repo}/contents/{upload_path}",
|
||||
headers=headers, timeout=15,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
json={
|
||||
"message": f"Add {skill_name} skill: {rel}",
|
||||
"content": content_b64,
|
||||
@@ -695,11 +719,12 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"https://api.github.com/repos/{target_repo}/pulls",
|
||||
headers=headers, timeout=15,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
json={
|
||||
"title": f"Add skill: {skill_name}",
|
||||
"body": f"Submitting the `{skill_name}` skill via Hermes Skills Hub.\n\n"
|
||||
f"This skill was scanned by the Hermes Skills Guard before submission.",
|
||||
f"This skill was scanned by the Hermes Skills Guard before submission.",
|
||||
"head": f"{fork_repo.split('/')[0]}:{branch_name}",
|
||||
"base": default_branch,
|
||||
},
|
||||
@@ -713,7 +738,7 @@ def _github_publish(skill_path: Path, skill_name: str, target_repo: str,
|
||||
return False, f"Network error creating PR: {e}"
|
||||
|
||||
|
||||
def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> None:
|
||||
def do_snapshot_export(output_path: str, console: Console | None = None) -> None:
|
||||
"""Export current hub skill configuration to a portable JSON file."""
|
||||
from tools.skills_hub import HubLockFile, TapsManager
|
||||
|
||||
@@ -726,16 +751,15 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N
|
||||
|
||||
snapshot = {
|
||||
"hermes_version": "0.1.0",
|
||||
"exported_at": __import__("datetime").datetime.now(
|
||||
__import__("datetime").timezone.utc
|
||||
).isoformat(),
|
||||
"exported_at": __import__("datetime").datetime.now(__import__("datetime").timezone.utc).isoformat(),
|
||||
"skills": [
|
||||
{
|
||||
"name": entry["name"],
|
||||
"source": entry.get("source", ""),
|
||||
"identifier": entry.get("identifier", ""),
|
||||
"category": str(Path(entry.get("install_path", "")).parent)
|
||||
if "/" in entry.get("install_path", "") else "",
|
||||
if "/" in entry.get("install_path", "")
|
||||
else "",
|
||||
}
|
||||
for entry in installed
|
||||
],
|
||||
@@ -748,8 +772,7 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N
|
||||
c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n")
|
||||
|
||||
|
||||
def do_snapshot_import(input_path: str, force: bool = False,
|
||||
console: Optional[Console] = None) -> None:
|
||||
def do_snapshot_import(input_path: str, force: bool = False, console: Console | None = None) -> None:
|
||||
"""Re-install skills from a snapshot file."""
|
||||
from tools.skills_hub import TapsManager
|
||||
|
||||
@@ -799,6 +822,7 @@ def do_snapshot_import(input_path: str, force: bool = False,
|
||||
# CLI argparse entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def skills_command(args) -> None:
|
||||
"""Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
|
||||
action = getattr(args, "skills_action", None)
|
||||
@@ -839,7 +863,9 @@ def skills_command(args) -> None:
|
||||
return
|
||||
do_tap(tap_action, repo=repo)
|
||||
else:
|
||||
_console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
|
||||
_console.print(
|
||||
"Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n"
|
||||
)
|
||||
_console.print("Run 'hermes skills <command> --help' for details.\n")
|
||||
|
||||
|
||||
@@ -847,7 +873,8 @@ def skills_command(args) -> None:
|
||||
# Slash command entry point (/skills in chat)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
def handle_skills_slash(cmd: str, console: Console | None = None) -> None:
|
||||
"""
|
||||
Parse and dispatch `/skills <subcommand> [args]` from the chat interface.
|
||||
|
||||
@@ -1008,17 +1035,19 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
def _print_skills_help(console: Console) -> None:
|
||||
"""Print help for the /skills slash command."""
|
||||
console.print(Panel(
|
||||
"[bold]Skills Hub Commands:[/]\n\n"
|
||||
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
|
||||
" [cyan]search[/] <query> Search registries for skills\n"
|
||||
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
|
||||
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
|
||||
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
|
||||
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
|
||||
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
|
||||
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
|
||||
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
|
||||
" [cyan]tap[/] list|add|remove Manage skill sources\n",
|
||||
title="/skills",
|
||||
))
|
||||
console.print(
|
||||
Panel(
|
||||
"[bold]Skills Hub Commands:[/]\n\n"
|
||||
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
|
||||
" [cyan]search[/] <query> Search registries for skills\n"
|
||||
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
|
||||
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"
|
||||
" [cyan]list[/] [--source hub|builtin] List installed skills\n"
|
||||
" [cyan]audit[/] [name] Re-scan hub skills for security\n"
|
||||
" [cyan]uninstall[/] <name> Remove a hub-installed skill\n"
|
||||
" [cyan]publish[/] <path> --repo <r> Publish a skill to GitHub via PR\n"
|
||||
" [cyan]snapshot[/] export|import Export/import skill configurations\n"
|
||||
" [cyan]tap[/] list|add|remove Manage skill sources\n",
|
||||
title="/skills",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,21 +5,25 @@ Shows the status of all Hermes Agent components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import get_env_path, get_env_value
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
|
||||
def check_mark(ok: bool) -> str:
|
||||
if ok:
|
||||
return color("✓", Colors.GREEN)
|
||||
return color("✗", Colors.RED)
|
||||
|
||||
|
||||
def redact_key(key: str) -> str:
|
||||
"""Redact an API key for display."""
|
||||
if not key:
|
||||
@@ -33,7 +37,8 @@ def _format_iso_timestamp(value) -> str:
|
||||
"""Format ISO timestamps for status output, converting to local timezone."""
|
||||
if not value or not isinstance(value, str):
|
||||
return "(unknown)"
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return "(unknown)"
|
||||
@@ -42,7 +47,7 @@ def _format_iso_timestamp(value) -> str:
|
||||
try:
|
||||
parsed = datetime.fromisoformat(text)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
parsed = parsed.replace(tzinfo=UTC)
|
||||
except Exception:
|
||||
return value
|
||||
return parsed.astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
|
||||
@@ -50,14 +55,14 @@ def _format_iso_timestamp(value) -> str:
|
||||
|
||||
def show_status(args):
|
||||
"""Show status of all Hermes Agent components."""
|
||||
show_all = getattr(args, 'all', False)
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
show_all = getattr(args, "all", False)
|
||||
deep = getattr(args, "deep", False)
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ ⚕ Hermes Agent Status │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Environment
|
||||
# =========================================================================
|
||||
@@ -65,19 +70,19 @@ def show_status(args):
|
||||
print(color("◆ Environment", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Project: {PROJECT_ROOT}")
|
||||
print(f" Python: {sys.version.split()[0]}")
|
||||
|
||||
|
||||
env_path = get_env_path()
|
||||
print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# API Keys
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Keys", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
keys = {
|
||||
"OpenRouter": "OPENROUTER_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"OpenAI": "OPENAI_API_KEY",
|
||||
"Z.AI/GLM": "GLM_API_KEY",
|
||||
"Kimi": "KIMI_API_KEY",
|
||||
@@ -91,7 +96,7 @@ def show_status(args):
|
||||
"ElevenLabs": "ELEVENLABS_API_KEY",
|
||||
"GitHub": "GITHUB_TOKEN",
|
||||
}
|
||||
|
||||
|
||||
for name, env_var in keys.items():
|
||||
value = get_env_value(env_var) or ""
|
||||
has_key = bool(value)
|
||||
@@ -105,7 +110,8 @@ def show_status(args):
|
||||
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
|
||||
from hermes_cli.auth import get_codex_auth_status, get_nous_auth_status
|
||||
|
||||
nous_status = get_nous_auth_status()
|
||||
codex_status = get_codex_auth_status()
|
||||
except Exception:
|
||||
@@ -148,10 +154,10 @@ def show_status(args):
|
||||
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
apikey_providers = {
|
||||
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"Kimi / Moonshot": ("KIMI_API_KEY",),
|
||||
"MiniMax": ("MINIMAX_API_KEY",),
|
||||
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
|
||||
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"Kimi / Moonshot": ("KIMI_API_KEY",),
|
||||
"MiniMax": ("MINIMAX_API_KEY",),
|
||||
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
|
||||
}
|
||||
for pname, env_vars in apikey_providers.items():
|
||||
key_val = ""
|
||||
@@ -168,19 +174,20 @@ def show_status(args):
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Terminal Backend", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "")
|
||||
if not terminal_env:
|
||||
# Fall back to config file value when env var isn't set
|
||||
# (hermes status doesn't go through cli.py's config loading)
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
_cfg = load_config()
|
||||
terminal_env = _cfg.get("terminal", {}).get("backend", "local")
|
||||
except Exception:
|
||||
terminal_env = "local"
|
||||
print(f" Backend: {terminal_env}")
|
||||
|
||||
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST", "")
|
||||
ssh_user = os.getenv("TERMINAL_SSH_USER", "")
|
||||
@@ -192,16 +199,16 @@ def show_status(args):
|
||||
elif terminal_env == "daytona":
|
||||
daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
|
||||
print(f" Daytona Image: {daytona_image}")
|
||||
|
||||
|
||||
sudo_password = os.getenv("SUDO_PASSWORD", "")
|
||||
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Messaging Platforms
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
platforms = {
|
||||
"Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"),
|
||||
"Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"),
|
||||
@@ -209,59 +216,52 @@ def show_status(args):
|
||||
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
|
||||
"Slack": ("SLACK_BOT_TOKEN", None),
|
||||
}
|
||||
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
token = os.getenv(token_var, "")
|
||||
has_token = bool(token)
|
||||
|
||||
|
||||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
status += f" (home: {home_channel})"
|
||||
|
||||
|
||||
print(f" {name:<12} {check_mark(has_token)} {status}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Gateway Status
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
result = subprocess.run(["systemctl", "--user", "is-active", "hermes-gateway"], capture_output=True, text=True)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(f" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == "darwin":
|
||||
result = subprocess.run(["launchctl", "list", "ai.hermes.gateway"], capture_output=True, text=True)
|
||||
is_loaded = result.returncode == 0
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(f" Manager: launchd")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(f" Manager: (not supported on this platform)")
|
||||
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
jobs_file = Path.home() / ".hermes" / "cron" / "jobs.json"
|
||||
if jobs_file.exists():
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(jobs_file) as f:
|
||||
data = json.load(f)
|
||||
@@ -269,56 +269,57 @@ def show_status(args):
|
||||
enabled_jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total")
|
||||
except Exception:
|
||||
print(f" Jobs: (error reading jobs file)")
|
||||
print(" Jobs: (error reading jobs file)")
|
||||
else:
|
||||
print(f" Jobs: 0")
|
||||
|
||||
print(" Jobs: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Sessions
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Sessions", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
sessions_file = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if sessions_file.exists():
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(sessions_file) as f:
|
||||
data = json.load(f)
|
||||
print(f" Active: {len(data)} session(s)")
|
||||
except Exception:
|
||||
print(f" Active: (error reading sessions file)")
|
||||
print(" Active: (error reading sessions file)")
|
||||
else:
|
||||
print(f" Active: 0")
|
||||
|
||||
print(" Active: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Deep checks
|
||||
# =========================================================================
|
||||
if deep:
|
||||
print()
|
||||
print(color("◆ Deep Checks", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
# Check OpenRouter connectivity
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
if openrouter_key:
|
||||
try:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(
|
||||
OPENROUTER_MODELS_URL,
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
OPENROUTER_MODELS_URL, headers={"Authorization": f"Bearer {openrouter_key}"}, timeout=10
|
||||
)
|
||||
ok = response.status_code == 200
|
||||
print(f" OpenRouter: {check_mark(ok)} {'reachable' if ok else f'error ({response.status_code})'}")
|
||||
except Exception as e:
|
||||
print(f" OpenRouter: {check_mark(False)} error: {e}")
|
||||
|
||||
|
||||
# Check gateway port
|
||||
try:
|
||||
import socket
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
result = sock.connect_ex(('127.0.0.1', 18789))
|
||||
result = sock.connect_ex(("127.0.0.1", 18789))
|
||||
sock.close()
|
||||
# Port in use = gateway likely running
|
||||
port_in_use = result == 0
|
||||
@@ -326,7 +327,7 @@ def show_status(args):
|
||||
print(f" Port 18789: {'in use' if port_in_use else 'available'}")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
print()
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM))
|
||||
|
||||
@@ -11,33 +11,37 @@ the `platform_toolsets` key.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import os
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config, save_config, get_env_value, save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.config import (
|
||||
get_env_value,
|
||||
load_config,
|
||||
save_config,
|
||||
save_env_value,
|
||||
)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
@@ -46,6 +50,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
@@ -54,6 +59,7 @@ def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
print()
|
||||
return default or ""
|
||||
|
||||
|
||||
def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
while True:
|
||||
@@ -64,9 +70,9 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
return default
|
||||
if not value:
|
||||
return default
|
||||
if value in ('y', 'yes'):
|
||||
if value in ("y", "yes"):
|
||||
return True
|
||||
if value in ('n', 'no'):
|
||||
if value in ("n", "no"):
|
||||
return False
|
||||
|
||||
|
||||
@@ -76,24 +82,24 @@ def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
# Each entry: (toolset_name, label, description)
|
||||
# These map to keys in toolsets.py TOOLSETS dict.
|
||||
CONFIGURABLE_TOOLSETS = [
|
||||
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
|
||||
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
|
||||
("terminal", "💻 Terminal & Processes", "terminal, process"),
|
||||
("file", "📁 File Operations", "read, write, patch, search"),
|
||||
("code_execution", "⚡ Code Execution", "execute_code"),
|
||||
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
|
||||
("image_gen", "🎨 Image Generation", "image_generate"),
|
||||
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
|
||||
("tts", "🔊 Text-to-Speech", "text_to_speech"),
|
||||
("skills", "📚 Skills", "list, view, manage"),
|
||||
("todo", "📋 Task Planning", "todo"),
|
||||
("memory", "💾 Memory", "persistent memory across sessions"),
|
||||
("session_search", "🔎 Session Search", "search past conversations"),
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
("web", "🔍 Web Search & Scraping", "web_search, web_extract"),
|
||||
("browser", "🌐 Browser Automation", "navigate, click, type, scroll"),
|
||||
("terminal", "💻 Terminal & Processes", "terminal, process"),
|
||||
("file", "📁 File Operations", "read, write, patch, search"),
|
||||
("code_execution", "⚡ Code Execution", "execute_code"),
|
||||
("vision", "👁️ Vision / Image Analysis", "vision_analyze"),
|
||||
("image_gen", "🎨 Image Generation", "image_generate"),
|
||||
("moa", "🧠 Mixture of Agents", "mixture_of_agents"),
|
||||
("tts", "🔊 Text-to-Speech", "text_to_speech"),
|
||||
("skills", "📚 Skills", "list, view, manage"),
|
||||
("todo", "📋 Task Planning", "todo"),
|
||||
("memory", "💾 Memory", "persistent memory across sessions"),
|
||||
("session_search", "🔎 Session Search", "search past conversations"),
|
||||
("clarify", "❓ Clarifying Questions", "clarify"),
|
||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
|
||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
]
|
||||
|
||||
# Toolsets that are OFF by default for new installs.
|
||||
@@ -103,11 +109,11 @@ _DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +137,11 @@ TOOL_CATEGORIES = {
|
||||
"name": "OpenAI TTS",
|
||||
"tag": "Premium - high quality voices",
|
||||
"env_vars": [
|
||||
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
|
||||
{
|
||||
"key": "VOICE_TOOLS_OPENAI_KEY",
|
||||
"prompt": "OpenAI API key",
|
||||
"url": "https://platform.openai.com/api-keys",
|
||||
},
|
||||
],
|
||||
"tts_provider": "openai",
|
||||
},
|
||||
@@ -139,7 +149,11 @@ TOOL_CATEGORIES = {
|
||||
"name": "ElevenLabs",
|
||||
"tag": "Premium - most natural voices",
|
||||
"env_vars": [
|
||||
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
|
||||
{
|
||||
"key": "ELEVENLABS_API_KEY",
|
||||
"prompt": "ElevenLabs API key",
|
||||
"url": "https://elevenlabs.io/app/settings/api-keys",
|
||||
},
|
||||
],
|
||||
"tts_provider": "elevenlabs",
|
||||
},
|
||||
@@ -224,7 +238,11 @@ TOOL_CATEGORIES = {
|
||||
"name": "Tinker / Atropos",
|
||||
"tag": "RL training platform",
|
||||
"env_vars": [
|
||||
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
|
||||
{
|
||||
"key": "TINKER_API_KEY",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
},
|
||||
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
|
||||
],
|
||||
"post_setup": "rl_training",
|
||||
@@ -236,24 +254,26 @@ TOOL_CATEGORIES = {
|
||||
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
|
||||
# Used as a fallback for tools like vision/moa that just need an API key.
|
||||
TOOLSET_ENV_REQUIREMENTS = {
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
}
|
||||
|
||||
|
||||
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _run_post_setup(post_setup_key: str):
|
||||
"""Run post-setup hooks for tools that need extra installation steps."""
|
||||
import shutil
|
||||
|
||||
if post_setup_key == "browserbase":
|
||||
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
|
||||
if not node_modules.exists() and shutil.which("npm"):
|
||||
_print_info(" Installing Node.js dependencies for browser tools...")
|
||||
import subprocess
|
||||
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--silent"],
|
||||
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
["npm", "install", "--silent"], capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" Node.js dependencies installed")
|
||||
@@ -270,16 +290,17 @@ def _run_post_setup(post_setup_key: str):
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
_print_info(" Installing tinker-atropos submodule...")
|
||||
import subprocess
|
||||
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)], capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" tinker-atropos installed")
|
||||
@@ -294,7 +315,8 @@ def _run_post_setup(post_setup_key: str):
|
||||
|
||||
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
|
||||
|
||||
def _get_enabled_platforms() -> List[str]:
|
||||
|
||||
def _get_enabled_platforms() -> list[str]:
|
||||
"""Return platform keys that are configured (have tokens or are CLI)."""
|
||||
enabled = ["cli"]
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
@@ -308,9 +330,9 @@ def _get_enabled_platforms() -> List[str]:
|
||||
return enabled
|
||||
|
||||
|
||||
def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
def _get_platform_tools(config: dict, platform: str) -> set[str]:
|
||||
"""Resolve which individual toolset names are enabled for a platform."""
|
||||
from toolsets import resolve_toolset, TOOLSETS
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
@@ -335,7 +357,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
return enabled_toolsets
|
||||
|
||||
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: set[str]):
|
||||
"""Save the selected toolset keys for a platform to config."""
|
||||
config.setdefault("platform_toolsets", {})
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys)
|
||||
@@ -364,6 +386,7 @@ def _toolset_has_keys(ts_key: str) -> bool:
|
||||
|
||||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
@@ -371,6 +394,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
@@ -386,8 +410,9 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
stdscr.addnstr(
|
||||
0, 0, question, max_x - 1, curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
@@ -410,14 +435,14 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
elif key in (27, ord("q")):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
@@ -431,7 +456,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
print(color(f" {marker} {i + 1}. {c}", style) if style else f" {marker} {i + 1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
@@ -445,7 +470,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
return default
|
||||
|
||||
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: set[str]) -> set[str]:
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
|
||||
labels = []
|
||||
@@ -455,15 +480,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected_indices = [
|
||||
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
|
||||
if ts_key in enabled
|
||||
]
|
||||
pre_selected_indices = [i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS) if ts_key in enabled]
|
||||
|
||||
# Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
|
||||
# simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
|
||||
try:
|
||||
import curses
|
||||
|
||||
selected = set(pre_selected_indices)
|
||||
result_holder = [None]
|
||||
|
||||
@@ -483,7 +506,13 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
header = f"Tools for {platform_label} — ↑↓ navigate, SPACE toggle, ENTER confirm"
|
||||
try:
|
||||
stdscr.addnstr(0, 0, header, max_x - 1, curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD)
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
header,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | curses.color_pair(2) if curses.has_colors() else curses.A_BOLD,
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
@@ -514,11 +543,11 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(labels)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(labels)
|
||||
elif key == ord(' '):
|
||||
elif key == ord(" "):
|
||||
if cursor in selected:
|
||||
selected.discard(cursor)
|
||||
else:
|
||||
@@ -526,7 +555,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
|
||||
return
|
||||
elif key in (27, ord('q')): # ESC or q
|
||||
elif key in (27, ord("q")): # ESC or q
|
||||
result_holder[0] = enabled
|
||||
return
|
||||
|
||||
@@ -565,9 +594,10 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
|
||||
|
||||
def _configure_toolset(ts_key: str, config: dict):
|
||||
"""Configure a toolset - provider selection + API keys.
|
||||
|
||||
|
||||
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
|
||||
env var prompts for toolsets not in TOOL_CATEGORIES.
|
||||
"""
|
||||
@@ -591,7 +621,9 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
req = cat["requires_python"]
|
||||
if sys.version_info < req:
|
||||
print()
|
||||
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
|
||||
_print_error(
|
||||
f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})"
|
||||
)
|
||||
_print_info(" Upgrade Python and reinstall to enable this tool.")
|
||||
return
|
||||
|
||||
@@ -610,7 +642,7 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
# Multiple providers - let user choose
|
||||
print()
|
||||
# Use custom title if provided (e.g. "Select Search Provider")
|
||||
title = cat.get("setup_title", f"Choose a provider")
|
||||
title = cat.get("setup_title", "Choose a provider")
|
||||
print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN))
|
||||
if cat.get("setup_note"):
|
||||
_print_info(f" {cat['setup_note']}")
|
||||
@@ -626,7 +658,11 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
|
||||
configured = (
|
||||
" [active]"
|
||||
if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "")
|
||||
else ""
|
||||
)
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
@@ -688,9 +724,9 @@ def _configure_provider(provider: dict, config: dict):
|
||||
|
||||
if value:
|
||||
save_env_value(var["key"], value)
|
||||
_print_success(f" Saved")
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
_print_warning(" Skipped")
|
||||
all_configured = False
|
||||
|
||||
# Run post-setup hooks if needed
|
||||
@@ -721,9 +757,9 @@ def _configure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var}", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Saved")
|
||||
_print_success(" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
_print_warning(" Skipped")
|
||||
|
||||
|
||||
def _reconfigure_tool(config: dict):
|
||||
@@ -827,9 +863,9 @@ def _reconfigure_provider(provider: dict, config: dict):
|
||||
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
|
||||
if value and value.strip():
|
||||
save_env_value(var["key"], value.strip())
|
||||
_print_success(f" Updated")
|
||||
_print_success(" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
_print_info(" Kept current")
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
@@ -851,13 +887,14 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
value = _prompt(f" {var} (Enter to keep current)", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Updated")
|
||||
_print_success(" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
_print_info(" Kept current")
|
||||
|
||||
|
||||
# ─── Main Entry Point ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
"""Entry point for `hermes tools` and `hermes setup tools`.
|
||||
|
||||
@@ -907,7 +944,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
# TTS (Edge vs OpenAI vs ElevenLabs), etc. are shown even when
|
||||
# a free provider exists.
|
||||
to_configure = [
|
||||
ts_key for ts_key in sorted(new_enabled)
|
||||
ts_key
|
||||
for ts_key in sorted(new_enabled)
|
||||
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
]
|
||||
|
||||
@@ -981,7 +1019,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
for ts_key in sorted(added):
|
||||
if (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
|
||||
if not _toolset_has_keys(ts_key):
|
||||
_configure_toolset(ts_key, config)
|
||||
|
||||
|
||||
@@ -7,23 +7,25 @@ Provides options for:
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
def log_info(msg: str):
|
||||
print(f"{color('→', Colors.CYAN)} {msg}")
|
||||
|
||||
|
||||
def log_success(msg: str):
|
||||
print(f"{color('✓', Colors.GREEN)} {msg}")
|
||||
|
||||
|
||||
def log_warn(msg: str):
|
||||
print(f"{color('⚠', Colors.YELLOW)} {msg}")
|
||||
|
||||
|
||||
def log_error(msg: str):
|
||||
print(f"{color('✗', Colors.RED)} {msg}")
|
||||
|
||||
@@ -42,7 +44,7 @@ def find_shell_configs() -> list:
|
||||
"""Find shell configuration files that might have PATH entries."""
|
||||
home = Path.home()
|
||||
configs = []
|
||||
|
||||
|
||||
candidates = [
|
||||
home / ".bashrc",
|
||||
home / ".bash_profile",
|
||||
@@ -50,11 +52,11 @@ def find_shell_configs() -> list:
|
||||
home / ".zshrc",
|
||||
home / ".zprofile",
|
||||
]
|
||||
|
||||
|
||||
for config in candidates:
|
||||
if config.exists():
|
||||
configs.append(config)
|
||||
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@@ -62,45 +64,45 @@ def remove_path_from_shell_configs():
|
||||
"""Remove Hermes PATH entries from shell configuration files."""
|
||||
configs = find_shell_configs()
|
||||
removed_from = []
|
||||
|
||||
|
||||
for config_path in configs:
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
original_content = content
|
||||
|
||||
|
||||
# Remove lines containing hermes-agent or hermes PATH entries
|
||||
new_lines = []
|
||||
skip_next = False
|
||||
|
||||
for line in content.split('\n'):
|
||||
|
||||
for line in content.split("\n"):
|
||||
# Skip the "# Hermes Agent" comment and following line
|
||||
if '# Hermes Agent' in line or '# hermes-agent' in line:
|
||||
if "# Hermes Agent" in line or "# hermes-agent" in line:
|
||||
skip_next = True
|
||||
continue
|
||||
if skip_next and ('hermes' in line.lower() and 'PATH' in line):
|
||||
if skip_next and ("hermes" in line.lower() and "PATH" in line):
|
||||
skip_next = False
|
||||
continue
|
||||
skip_next = False
|
||||
|
||||
|
||||
# Remove any PATH line containing hermes
|
||||
if 'hermes' in line.lower() and ('PATH=' in line or 'path=' in line.lower()):
|
||||
if "hermes" in line.lower() and ("PATH=" in line or "path=" in line.lower()):
|
||||
continue
|
||||
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
new_content = '\n'.join(new_lines)
|
||||
|
||||
|
||||
new_content = "\n".join(new_lines)
|
||||
|
||||
# Clean up multiple blank lines
|
||||
while '\n\n\n' in new_content:
|
||||
new_content = new_content.replace('\n\n\n', '\n\n')
|
||||
|
||||
while "\n\n\n" in new_content:
|
||||
new_content = new_content.replace("\n\n\n", "\n\n")
|
||||
|
||||
if new_content != original_content:
|
||||
config_path.write_text(new_content)
|
||||
removed_from.append(config_path)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not update {config_path}: {e}")
|
||||
|
||||
|
||||
return removed_from
|
||||
|
||||
|
||||
@@ -110,61 +112,49 @@ def remove_wrapper_script():
|
||||
Path.home() / ".local" / "bin" / "hermes",
|
||||
Path("/usr/local/bin/hermes"),
|
||||
]
|
||||
|
||||
|
||||
removed = []
|
||||
for wrapper in wrapper_paths:
|
||||
if wrapper.exists():
|
||||
try:
|
||||
# Check if it's our wrapper (contains hermes_cli reference)
|
||||
content = wrapper.read_text()
|
||||
if 'hermes_cli' in content or 'hermes-agent' in content:
|
||||
if "hermes_cli" in content or "hermes-agent" in content:
|
||||
wrapper.unlink()
|
||||
removed.append(wrapper)
|
||||
except Exception as e:
|
||||
log_warn(f"Could not remove {wrapper}: {e}")
|
||||
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def uninstall_gateway_service():
|
||||
"""Stop and uninstall the gateway service if running."""
|
||||
import platform
|
||||
|
||||
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
|
||||
if not service_file.exists():
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Stop the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
subprocess.run(["systemctl", "--user", "stop", "hermes-gateway"], capture_output=True, check=False)
|
||||
|
||||
# Disable the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
subprocess.run(["systemctl", "--user", "disable", "hermes-gateway"], capture_output=True, check=False)
|
||||
|
||||
# Remove service file
|
||||
service_file.unlink()
|
||||
|
||||
|
||||
# Reload systemd
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], capture_output=True, check=False)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove gateway service: {e}")
|
||||
return False
|
||||
@@ -173,20 +163,20 @@ def uninstall_gateway_service():
|
||||
def run_uninstall(args):
|
||||
"""
|
||||
Run the uninstall process.
|
||||
|
||||
|
||||
Options:
|
||||
- Full uninstall: removes code + ~/.hermes/ (configs, data, logs)
|
||||
- Keep data: removes code but keeps ~/.hermes/ for future reinstall
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("│ ⚕ Hermes Agent Uninstaller │", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
# Show what will be affected
|
||||
print(color("Current Installation:", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Code: {project_root}")
|
||||
@@ -194,7 +184,7 @@ def run_uninstall(args):
|
||||
print(f" Secrets: {hermes_home / '.env'}")
|
||||
print(f" Data: {hermes_home / 'cron/'}, {hermes_home / 'sessions/'}, {hermes_home / 'logs/'}")
|
||||
print()
|
||||
|
||||
|
||||
# Ask for confirmation
|
||||
print(color("Uninstall Options:", Colors.YELLOW, Colors.BOLD))
|
||||
print()
|
||||
@@ -206,21 +196,21 @@ def run_uninstall(args):
|
||||
print()
|
||||
print(" 3) " + color("Cancel", Colors.CYAN) + " - Don't uninstall")
|
||||
print()
|
||||
|
||||
|
||||
try:
|
||||
choice = input(color("Select option [1/2/3]: ", Colors.BOLD)).strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
|
||||
if choice == "3" or choice.lower() in ("c", "cancel", "q", "quit", "n", "no"):
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
full_uninstall = (choice == "2")
|
||||
|
||||
|
||||
full_uninstall = choice == "2"
|
||||
|
||||
# Final confirmation
|
||||
print()
|
||||
if full_uninstall:
|
||||
@@ -228,7 +218,7 @@ def run_uninstall(args):
|
||||
print(color(" Including: configs, API keys, sessions, scheduled jobs, logs", Colors.RED))
|
||||
else:
|
||||
print("This will remove the Hermes code but keep your configuration and data.")
|
||||
|
||||
|
||||
print()
|
||||
try:
|
||||
confirm = input(f"Type '{color('yes', Colors.YELLOW)}' to confirm: ").strip().lower()
|
||||
@@ -236,23 +226,23 @@ def run_uninstall(args):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
|
||||
if confirm != "yes":
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
|
||||
print()
|
||||
print(color("Uninstalling...", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
# 1. Stop and uninstall gateway service
|
||||
log_info("Checking for gateway service...")
|
||||
if uninstall_gateway_service():
|
||||
log_success("Gateway service stopped and removed")
|
||||
else:
|
||||
log_info("No gateway service found")
|
||||
|
||||
|
||||
# 2. Remove PATH entries from shell configs
|
||||
log_info("Removing PATH entries from shell configs...")
|
||||
removed_configs = remove_path_from_shell_configs()
|
||||
@@ -261,7 +251,7 @@ def run_uninstall(args):
|
||||
log_success(f"Updated {config}")
|
||||
else:
|
||||
log_info("No PATH entries found to remove")
|
||||
|
||||
|
||||
# 3. Remove wrapper script
|
||||
log_info("Removing hermes command...")
|
||||
removed_wrappers = remove_wrapper_script()
|
||||
@@ -270,10 +260,10 @@ def run_uninstall(args):
|
||||
log_success(f"Removed {wrapper}")
|
||||
else:
|
||||
log_info("No wrapper script found")
|
||||
|
||||
|
||||
# 4. Remove installation directory (code)
|
||||
log_info(f"Removing installation directory...")
|
||||
|
||||
log_info("Removing installation directory...")
|
||||
|
||||
# Check if we're running from within the install dir
|
||||
# We need to be careful here
|
||||
try:
|
||||
@@ -289,7 +279,7 @@ def run_uninstall(args):
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove {project_root}: {e}")
|
||||
log_info("You may need to manually remove it")
|
||||
|
||||
|
||||
# 5. Optionally remove ~/.hermes/ data directory
|
||||
if full_uninstall:
|
||||
log_info("Removing configuration and data...")
|
||||
@@ -302,22 +292,27 @@ def run_uninstall(args):
|
||||
log_info("You may need to manually remove it")
|
||||
else:
|
||||
log_info(f"Keeping configuration and data in {hermes_home}")
|
||||
|
||||
|
||||
# Done
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN, Colors.BOLD))
|
||||
print(color("│ ✓ Uninstall Complete! │", Colors.GREEN, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
|
||||
if not full_uninstall:
|
||||
print(color("Your configuration and data have been preserved:", Colors.CYAN))
|
||||
print(f" {hermes_home}/")
|
||||
print()
|
||||
print("To reinstall later with your existing settings:")
|
||||
print(color(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", Colors.DIM))
|
||||
print(
|
||||
color(
|
||||
" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash",
|
||||
Colors.DIM,
|
||||
)
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
print(color("Reload your shell to complete the process:", Colors.YELLOW))
|
||||
print(" source ~/.bashrc # or ~/.zshrc")
|
||||
print()
|
||||
|
||||
@@ -19,8 +19,7 @@ import os
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from typing import Any
|
||||
|
||||
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
@@ -156,8 +155,7 @@ class SessionDB:
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
|
||||
"ON sessions(title) WHERE title IS NOT NULL"
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique ON sessions(title) WHERE title IS NOT NULL"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
@@ -185,7 +183,7 @@ class SessionDB:
|
||||
session_id: str,
|
||||
source: str,
|
||||
model: str = None,
|
||||
model_config: Dict[str, Any] = None,
|
||||
model_config: dict[str, Any] = None,
|
||||
system_prompt: str = None,
|
||||
user_id: str = None,
|
||||
parent_session_id: str = None,
|
||||
@@ -225,9 +223,7 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def update_token_counts(
|
||||
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0
|
||||
) -> None:
|
||||
def update_token_counts(self, session_id: str, input_tokens: int = 0, output_tokens: int = 0) -> None:
|
||||
"""Increment token counters on a session."""
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
@@ -238,11 +234,9 @@ class SessionDB:
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""Get a session by ID."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
@@ -250,7 +244,7 @@ class SessionDB:
|
||||
MAX_TITLE_LENGTH = 100
|
||||
|
||||
@staticmethod
|
||||
def sanitize_title(title: Optional[str]) -> Optional[str]:
|
||||
def sanitize_title(title: str | None) -> str | None:
|
||||
"""Validate and sanitize a session title.
|
||||
|
||||
- Strips leading/trailing whitespace
|
||||
@@ -271,27 +265,26 @@ class SessionDB:
|
||||
# Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
|
||||
# whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
|
||||
# normalized to spaces by the whitespace collapsing step below
|
||||
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
|
||||
cleaned = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", title)
|
||||
|
||||
# Remove problematic Unicode control characters:
|
||||
# - Zero-width chars (U+200B-U+200F, U+FEFF)
|
||||
# - Directional overrides (U+202A-U+202E, U+2066-U+2069)
|
||||
# - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
|
||||
cleaned = re.sub(
|
||||
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
|
||||
'', cleaned,
|
||||
r"[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]",
|
||||
"",
|
||||
cleaned,
|
||||
)
|
||||
|
||||
# Collapse internal whitespace runs and strip
|
||||
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
|
||||
raise ValueError(
|
||||
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
|
||||
)
|
||||
raise ValueError(f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})")
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -312,9 +305,7 @@ class SessionDB:
|
||||
)
|
||||
conflict = cursor.fetchone()
|
||||
if conflict:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
raise ValueError(f"Title '{title}' is already in use by session {conflict['id']}")
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
@@ -322,23 +313,19 @@ class SessionDB:
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
def get_session_title(self, session_id: str) -> str | None:
|
||||
"""Get the title for a session, or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT title FROM sessions WHERE id = ?", (session_id,))
|
||||
row = cursor.fetchone()
|
||||
return row["title"] if row else None
|
||||
|
||||
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
||||
def get_session_by_title(self, title: str) -> dict[str, Any] | None:
|
||||
"""Look up a session by exact title. Returns session dict or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT * FROM sessions WHERE title = ?", (title,))
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
||||
def resolve_session_by_title(self, title: str) -> str | None:
|
||||
"""Resolve a title to a session ID, preferring the latest in a lineage.
|
||||
|
||||
If the exact title exists, returns that session's ID.
|
||||
@@ -353,8 +340,7 @@ class SessionDB:
|
||||
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, title, started_at FROM sessions "
|
||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
"SELECT id, title, started_at FROM sessions WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
(f"{escaped} #%",),
|
||||
)
|
||||
numbered = cursor.fetchall()
|
||||
@@ -373,8 +359,9 @@ class SessionDB:
|
||||
the highest existing number and increments.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Strip existing #N suffix to find the true base
|
||||
match = re.match(r'^(.*?) #(\d+)$', base_title)
|
||||
match = re.match(r"^(.*?) #(\d+)$", base_title)
|
||||
if match:
|
||||
base = match.group(1)
|
||||
else:
|
||||
@@ -395,7 +382,7 @@ class SessionDB:
|
||||
# Find the highest number
|
||||
max_num = 1 # The unnumbered original counts as #1
|
||||
for t in existing:
|
||||
m = re.match(r'^.* #(\d+)$', t)
|
||||
m = re.match(r"^.* #(\d+)$", t)
|
||||
if m:
|
||||
max_num = max(max_num, int(m.group(1)))
|
||||
|
||||
@@ -406,7 +393,7 @@ class SessionDB:
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List sessions with preview (first user message) and last active timestamp.
|
||||
|
||||
Returns dicts with keys: id, source, model, title, started_at, ended_at,
|
||||
@@ -506,7 +493,7 @@ class SessionDB:
|
||||
self._conn.commit()
|
||||
return msg_id
|
||||
|
||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
def get_messages(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""Load all messages for a session, ordered by timestamp."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
|
||||
@@ -524,7 +511,7 @@ class SessionDB:
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load messages in the OpenAI conversation format (role + content dicts).
|
||||
Used by the gateway to restore conversation history.
|
||||
@@ -556,11 +543,11 @@ class SessionDB:
|
||||
def search_messages(
|
||||
self,
|
||||
query: str,
|
||||
source_filter: List[str] = None,
|
||||
role_filter: List[str] = None,
|
||||
source_filter: list[str] = None,
|
||||
role_filter: list[str] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Full-text search across session messages using FTS5.
|
||||
|
||||
@@ -628,8 +615,7 @@ class SessionDB:
|
||||
(match["session_id"], match["id"], match["id"]),
|
||||
)
|
||||
context_msgs = [
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]}
|
||||
for r in ctx_cursor.fetchall()
|
||||
{"role": r["role"], "content": (r["content"] or "")[:200]} for r in ctx_cursor.fetchall()
|
||||
]
|
||||
match["context"] = context_msgs
|
||||
except Exception:
|
||||
@@ -645,7 +631,7 @@ class SessionDB:
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
@@ -666,9 +652,7 @@ class SessionDB:
|
||||
def session_count(self, source: str = None) -> int:
|
||||
"""Count sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE source = ?", (source,))
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
@@ -676,9 +660,7 @@ class SessionDB:
|
||||
def message_count(self, session_id: str = None) -> int:
|
||||
"""Count messages, optionally for a specific session."""
|
||||
if session_id:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,))
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
@@ -687,7 +669,7 @@ class SessionDB:
|
||||
# Export and cleanup
|
||||
# =========================================================================
|
||||
|
||||
def export_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
def export_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""Export a single session with all its messages as a dict."""
|
||||
session = self.get_session(session_id)
|
||||
if not session:
|
||||
@@ -695,7 +677,7 @@ class SessionDB:
|
||||
messages = self.get_messages(session_id)
|
||||
return {**session, "messages": messages}
|
||||
|
||||
def export_all(self, source: str = None) -> List[Dict[str, Any]]:
|
||||
def export_all(self, source: str = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Export all sessions (with messages) as a list of dicts.
|
||||
Suitable for writing to a JSONL file for backup/analysis.
|
||||
@@ -709,9 +691,7 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
@@ -720,9 +700,7 @@ class SessionDB:
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,))
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
@@ -736,6 +714,7 @@ class SessionDB:
|
||||
Only prunes ended sessions (not active ones).
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
|
||||
if source:
|
||||
|
||||
@@ -20,11 +20,10 @@ Public API (signatures preserved from the original 2,400-line version):
|
||||
check_tool_availability(quiet) -> tuple
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from tools.registry import registry
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
@@ -36,6 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
# Async Bridging (single source of truth -- used by registry.dispatch too)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync context.
|
||||
|
||||
@@ -56,6 +56,7 @@ def _run_async(coro):
|
||||
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=300)
|
||||
@@ -66,6 +67,7 @@ def _run_async(coro):
|
||||
# Tool Discovery (importing each module triggers its registry.register calls)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _discover_tools():
|
||||
"""Import all tool modules to trigger their registry.register() calls.
|
||||
|
||||
@@ -97,6 +99,7 @@ def _discover_tools():
|
||||
"tools.homeassistant_tool",
|
||||
]
|
||||
import importlib
|
||||
|
||||
for mod_name in _modules:
|
||||
try:
|
||||
importlib.import_module(mod_name)
|
||||
@@ -109,6 +112,7 @@ _discover_tools()
|
||||
# MCP tool discovery (external MCP servers from config)
|
||||
try:
|
||||
from tools.mcp_tool import discover_mcp_tools
|
||||
|
||||
discover_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.debug("MCP tool discovery failed: %s", e)
|
||||
@@ -118,13 +122,13 @@ except Exception as e:
|
||||
# Backward-compat constants (built once after discovery)
|
||||
# =============================================================================
|
||||
|
||||
TOOL_TO_TOOLSET_MAP: Dict[str, str] = registry.get_tool_to_toolset_map()
|
||||
TOOL_TO_TOOLSET_MAP: dict[str, str] = registry.get_tool_to_toolset_map()
|
||||
|
||||
TOOLSET_REQUIREMENTS: Dict[str, dict] = registry.get_toolset_requirements()
|
||||
TOOLSET_REQUIREMENTS: dict[str, dict] = registry.get_toolset_requirements()
|
||||
|
||||
# Resolved tool names from the last get_tool_definitions() call.
|
||||
# Used by code_execution_tool to know which tools are available in this session.
|
||||
_last_resolved_tool_names: List[str] = []
|
||||
_last_resolved_tool_names: list[str] = []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -139,18 +143,29 @@ _LEGACY_TOOLSET_MAP = {
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view", "skill_manage"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
"browser_navigate",
|
||||
"browser_snapshot",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_press",
|
||||
"browser_close",
|
||||
"browser_get_images",
|
||||
"browser_vision",
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
"rl_list_environments",
|
||||
"rl_select_environment",
|
||||
"rl_get_current_config",
|
||||
"rl_edit_config",
|
||||
"rl_start_training",
|
||||
"rl_check_status",
|
||||
"rl_stop_training",
|
||||
"rl_get_results",
|
||||
"rl_list_runs",
|
||||
"rl_test_inference",
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search_files"],
|
||||
"tts_tools": ["text_to_speech"],
|
||||
@@ -161,11 +176,12 @@ _LEGACY_TOOLSET_MAP = {
|
||||
# get_tool_definitions (the main schema provider)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_tool_definitions(
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None,
|
||||
enabled_toolsets: list[str] = None,
|
||||
disabled_toolsets: list[str] = None,
|
||||
quiet_mode: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for model API calls with toolset-based filtering.
|
||||
|
||||
@@ -200,6 +216,7 @@ def get_tool_definitions(
|
||||
|
||||
elif disabled_toolsets:
|
||||
from toolsets import get_all_toolsets
|
||||
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
@@ -219,6 +236,7 @@ def get_tool_definitions(
|
||||
print(f"⚠️ Unknown toolset: {toolset_name}")
|
||||
else:
|
||||
from toolsets import get_all_toolsets
|
||||
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
@@ -230,6 +248,7 @@ def get_tool_definitions(
|
||||
# execute_code" even when the user disabled the web toolset (#560-discord).
|
||||
if "execute_code" in tools_to_include:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
for i, td in enumerate(filtered_tools):
|
||||
@@ -263,9 +282,9 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None,
|
||||
function_args: dict[str, Any],
|
||||
task_id: str | None = None,
|
||||
user_task: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to the tool registry.
|
||||
@@ -285,13 +304,15 @@ def handle_function_call(
|
||||
|
||||
if function_name == "execute_code":
|
||||
return registry.dispatch(
|
||||
function_name, function_args,
|
||||
function_name,
|
||||
function_args,
|
||||
task_id=task_id,
|
||||
enabled_tools=_last_resolved_tool_names,
|
||||
)
|
||||
|
||||
return registry.dispatch(
|
||||
function_name, function_args,
|
||||
function_name,
|
||||
function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
@@ -306,26 +327,27 @@ def handle_function_call(
|
||||
# Backward-compat wrapper functions
|
||||
# =============================================================================
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
|
||||
def get_all_tool_names() -> list[str]:
|
||||
"""Return all registered tool names."""
|
||||
return registry.get_all_tool_names()
|
||||
|
||||
|
||||
def get_toolset_for_tool(tool_name: str) -> Optional[str]:
|
||||
def get_toolset_for_tool(tool_name: str) -> str | None:
|
||||
"""Return the toolset a tool belongs to."""
|
||||
return registry.get_toolset_for_tool(tool_name)
|
||||
|
||||
|
||||
def get_available_toolsets() -> Dict[str, dict]:
|
||||
def get_available_toolsets() -> dict[str, dict]:
|
||||
"""Return toolset availability info for UI display."""
|
||||
return registry.get_available_toolsets()
|
||||
|
||||
|
||||
def check_toolset_requirements() -> Dict[str, bool]:
|
||||
def check_toolset_requirements() -> dict[str, bool]:
|
||||
"""Return {toolset: available_bool} for every registered toolset."""
|
||||
return registry.check_toolset_requirements()
|
||||
|
||||
|
||||
def check_tool_availability(quiet: bool = False) -> Tuple[List[str], List[dict]]:
|
||||
def check_tool_availability(quiet: bool = False) -> tuple[list[str], list[dict]]:
|
||||
"""Return (available_toolsets, unavailable_info)."""
|
||||
return registry.check_tool_availability(quiet=quiet)
|
||||
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
daytona = ["daytona>=0.148.0"]
|
||||
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0"]
|
||||
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0", "ruff", "pre-commit", "watchfiles"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cron = ["croniter"]
|
||||
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
@@ -76,6 +76,46 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron", "honcho_integration"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "UP", "B", "SIM"]
|
||||
ignore = [
|
||||
"E402", # late imports — intentional throughout codebase
|
||||
"E501", # line too long — handled by formatter where it can
|
||||
"E731", # lambda assignments — used in registry pattern
|
||||
"E741", # ambiguous variable name — existing patterns
|
||||
"F811", # redefined unused — intentional overrides
|
||||
"F841", # unused variable — cleanup separately
|
||||
"B007", # unused loop variable — cleanup separately
|
||||
"B904", # raise from — too noisy to gate on
|
||||
"B905", # zip strict — cleanup separately
|
||||
"B027", # empty method without abstract decorator
|
||||
"SIM102", # collapsible if — readability preference
|
||||
"SIM103", # needless bool — readability preference
|
||||
"SIM105", # suppressible exception — existing pattern
|
||||
"SIM108", # ternary — readability preference
|
||||
"SIM110", # reimplemented builtin
|
||||
"SIM112", # uncapitalized env var
|
||||
"SIM115", # open file with context handler
|
||||
"SIM117", # multiple with statements
|
||||
"SIM118", # in-dict-keys — cleanup separately
|
||||
"SIM212", # if-expr twisted arms
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"batch_runner.py" = ["F821"]
|
||||
"tools/patch_parser.py" = ["F821"]
|
||||
"gateway/run.py" = ["F821"]
|
||||
"gateway/channel_directory.py" = ["F401"]
|
||||
"hermes_cli/doctor.py" = ["F401"]
|
||||
"tools/image_generation_tool.py" = ["F401"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["tools", "hermes_cli", "gateway", "agent", "cron"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
|
||||
1550
run_agent.py
1550
run_agent.py
File diff suppressed because it is too large
Load Diff
@@ -16,249 +16,222 @@ for the AI agent to access all capabilities.
|
||||
"""
|
||||
|
||||
# Export all tools for easy importing
|
||||
from .web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key
|
||||
)
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
terminal_tool,
|
||||
check_terminal_requirements,
|
||||
cleanup_vm,
|
||||
cleanup_all_environments,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
TERMINAL_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .vision_tools import (
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements
|
||||
)
|
||||
|
||||
from .mixture_of_agents_tool import (
|
||||
mixture_of_agents_tool,
|
||||
check_moa_requirements
|
||||
)
|
||||
|
||||
from .image_generation_tool import (
|
||||
image_generate_tool,
|
||||
check_image_generation_requirements
|
||||
)
|
||||
|
||||
from .skills_tool import (
|
||||
skills_list,
|
||||
skill_view,
|
||||
check_skills_requirements,
|
||||
SKILLS_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .skill_manager_tool import (
|
||||
skill_manage,
|
||||
check_skill_manage_requirements,
|
||||
SKILL_MANAGE_SCHEMA
|
||||
)
|
||||
|
||||
# Browser automation tools (agent-browser + Browserbase)
|
||||
from .browser_tool import (
|
||||
browser_navigate,
|
||||
browser_snapshot,
|
||||
browser_click,
|
||||
browser_type,
|
||||
browser_scroll,
|
||||
BROWSER_TOOL_SCHEMAS,
|
||||
browser_back,
|
||||
browser_press,
|
||||
browser_click,
|
||||
browser_close,
|
||||
browser_get_images,
|
||||
browser_navigate,
|
||||
browser_press,
|
||||
browser_scroll,
|
||||
browser_snapshot,
|
||||
browser_type,
|
||||
browser_vision,
|
||||
cleanup_browser,
|
||||
cleanup_all_browsers,
|
||||
get_active_browser_sessions,
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
schedule_cronjob,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
)
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
search_tool,
|
||||
get_file_tools,
|
||||
clear_file_ops_cache,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
text_to_speech_tool,
|
||||
check_tts_requirements,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
todo_tool,
|
||||
check_todo_requirements,
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
cleanup_all_browsers,
|
||||
cleanup_browser,
|
||||
get_active_browser_sessions,
|
||||
)
|
||||
|
||||
# Clarifying questions tool (interactive Q&A with the user)
|
||||
from .clarify_tool import (
|
||||
clarify_tool,
|
||||
check_clarify_requirements,
|
||||
CLARIFY_SCHEMA,
|
||||
check_clarify_requirements,
|
||||
clarify_tool,
|
||||
)
|
||||
|
||||
# Code execution sandbox (programmatic tool calling)
|
||||
from .code_execution_tool import (
|
||||
execute_code,
|
||||
check_sandbox_requirements,
|
||||
EXECUTE_CODE_SCHEMA,
|
||||
check_sandbox_requirements,
|
||||
execute_code,
|
||||
)
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA,
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
schedule_cronjob,
|
||||
)
|
||||
|
||||
# Subagent delegation (spawn child agents with isolated context)
|
||||
from .delegate_tool import (
|
||||
delegate_task,
|
||||
check_delegate_requirements,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
clear_file_ops_cache,
|
||||
get_file_tools,
|
||||
patch_tool,
|
||||
read_file_tool,
|
||||
search_tool,
|
||||
write_file_tool,
|
||||
)
|
||||
from .image_generation_tool import check_image_generation_requirements, image_generate_tool
|
||||
from .mixture_of_agents_tool import check_moa_requirements, mixture_of_agents_tool
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
rl_check_status,
|
||||
rl_edit_config,
|
||||
rl_get_current_config,
|
||||
rl_get_results,
|
||||
rl_list_environments,
|
||||
rl_list_runs,
|
||||
rl_select_environment,
|
||||
rl_start_training,
|
||||
rl_stop_training,
|
||||
rl_test_inference,
|
||||
)
|
||||
from .skill_manager_tool import SKILL_MANAGE_SCHEMA, check_skill_manage_requirements, skill_manage
|
||||
from .skills_tool import SKILLS_TOOL_DESCRIPTION, check_skills_requirements, skill_view, skills_list
|
||||
|
||||
# Primary terminal tool (mini-swe-agent backend: local/docker/singularity/modal/daytona)
|
||||
from .terminal_tool import (
|
||||
TERMINAL_TOOL_DESCRIPTION,
|
||||
check_terminal_requirements,
|
||||
cleanup_all_environments,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
terminal_tool,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
check_todo_requirements,
|
||||
todo_tool,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
check_tts_requirements,
|
||||
text_to_speech_tool,
|
||||
)
|
||||
from .vision_tools import check_vision_requirements, vision_analyze_tool
|
||||
from .web_tools import check_firecrawl_api_key, web_crawl_tool, web_extract_tool, web_search_tool
|
||||
|
||||
|
||||
# File tools have no external requirements - they use the terminal backend
|
||||
def check_file_requirements():
|
||||
"""File tools only require terminal backend to be available."""
|
||||
from .terminal_tool import check_terminal_requirements
|
||||
|
||||
return check_terminal_requirements()
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
'web_extract_tool',
|
||||
'web_crawl_tool',
|
||||
'check_firecrawl_api_key',
|
||||
"web_search_tool",
|
||||
"web_extract_tool",
|
||||
"web_crawl_tool",
|
||||
"check_firecrawl_api_key",
|
||||
# Terminal tools (mini-swe-agent backend)
|
||||
'terminal_tool',
|
||||
'check_terminal_requirements',
|
||||
'cleanup_vm',
|
||||
'cleanup_all_environments',
|
||||
'get_active_environments_info',
|
||||
'register_task_env_overrides',
|
||||
'clear_task_env_overrides',
|
||||
'TERMINAL_TOOL_DESCRIPTION',
|
||||
"terminal_tool",
|
||||
"check_terminal_requirements",
|
||||
"cleanup_vm",
|
||||
"cleanup_all_environments",
|
||||
"get_active_environments_info",
|
||||
"register_task_env_overrides",
|
||||
"clear_task_env_overrides",
|
||||
"TERMINAL_TOOL_DESCRIPTION",
|
||||
# Vision tools
|
||||
'vision_analyze_tool',
|
||||
'check_vision_requirements',
|
||||
"vision_analyze_tool",
|
||||
"check_vision_requirements",
|
||||
# MoA tools
|
||||
'mixture_of_agents_tool',
|
||||
'check_moa_requirements',
|
||||
"mixture_of_agents_tool",
|
||||
"check_moa_requirements",
|
||||
# Image generation tools
|
||||
'image_generate_tool',
|
||||
'check_image_generation_requirements',
|
||||
"image_generate_tool",
|
||||
"check_image_generation_requirements",
|
||||
# Skills tools
|
||||
'skills_list',
|
||||
'skill_view',
|
||||
'check_skills_requirements',
|
||||
'SKILLS_TOOL_DESCRIPTION',
|
||||
"skills_list",
|
||||
"skill_view",
|
||||
"check_skills_requirements",
|
||||
"SKILLS_TOOL_DESCRIPTION",
|
||||
# Skill management
|
||||
'skill_manage',
|
||||
'check_skill_manage_requirements',
|
||||
'SKILL_MANAGE_SCHEMA',
|
||||
"skill_manage",
|
||||
"check_skill_manage_requirements",
|
||||
"SKILL_MANAGE_SCHEMA",
|
||||
# Browser automation tools
|
||||
'browser_navigate',
|
||||
'browser_snapshot',
|
||||
'browser_click',
|
||||
'browser_type',
|
||||
'browser_scroll',
|
||||
'browser_back',
|
||||
'browser_press',
|
||||
'browser_close',
|
||||
'browser_get_images',
|
||||
'browser_vision',
|
||||
'cleanup_browser',
|
||||
'cleanup_all_browsers',
|
||||
'get_active_browser_sessions',
|
||||
'check_browser_requirements',
|
||||
'BROWSER_TOOL_SCHEMAS',
|
||||
"browser_navigate",
|
||||
"browser_snapshot",
|
||||
"browser_click",
|
||||
"browser_type",
|
||||
"browser_scroll",
|
||||
"browser_back",
|
||||
"browser_press",
|
||||
"browser_close",
|
||||
"browser_get_images",
|
||||
"browser_vision",
|
||||
"cleanup_browser",
|
||||
"cleanup_all_browsers",
|
||||
"get_active_browser_sessions",
|
||||
"check_browser_requirements",
|
||||
"BROWSER_TOOL_SCHEMAS",
|
||||
# Cronjob management tools (CLI-only)
|
||||
'schedule_cronjob',
|
||||
'list_cronjobs',
|
||||
'remove_cronjob',
|
||||
'check_cronjob_requirements',
|
||||
'get_cronjob_tool_definitions',
|
||||
'SCHEDULE_CRONJOB_SCHEMA',
|
||||
'LIST_CRONJOBS_SCHEMA',
|
||||
'REMOVE_CRONJOB_SCHEMA',
|
||||
"schedule_cronjob",
|
||||
"list_cronjobs",
|
||||
"remove_cronjob",
|
||||
"check_cronjob_requirements",
|
||||
"get_cronjob_tool_definitions",
|
||||
"SCHEDULE_CRONJOB_SCHEMA",
|
||||
"LIST_CRONJOBS_SCHEMA",
|
||||
"REMOVE_CRONJOB_SCHEMA",
|
||||
# RL Training tools
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_list_runs',
|
||||
'rl_test_inference',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
"rl_list_environments",
|
||||
"rl_select_environment",
|
||||
"rl_get_current_config",
|
||||
"rl_edit_config",
|
||||
"rl_start_training",
|
||||
"rl_check_status",
|
||||
"rl_stop_training",
|
||||
"rl_get_results",
|
||||
"rl_list_runs",
|
||||
"rl_test_inference",
|
||||
"check_rl_api_keys",
|
||||
"get_missing_keys",
|
||||
# File manipulation tools
|
||||
'read_file_tool',
|
||||
'write_file_tool',
|
||||
'patch_tool',
|
||||
'search_tool',
|
||||
'get_file_tools',
|
||||
'clear_file_ops_cache',
|
||||
'check_file_requirements',
|
||||
"read_file_tool",
|
||||
"write_file_tool",
|
||||
"patch_tool",
|
||||
"search_tool",
|
||||
"get_file_tools",
|
||||
"clear_file_ops_cache",
|
||||
"check_file_requirements",
|
||||
# Text-to-speech tools
|
||||
'text_to_speech_tool',
|
||||
'check_tts_requirements',
|
||||
"text_to_speech_tool",
|
||||
"check_tts_requirements",
|
||||
# Planning & task management tool
|
||||
'todo_tool',
|
||||
'check_todo_requirements',
|
||||
'TODO_SCHEMA',
|
||||
'TodoStore',
|
||||
"todo_tool",
|
||||
"check_todo_requirements",
|
||||
"TODO_SCHEMA",
|
||||
"TodoStore",
|
||||
# Clarifying questions tool
|
||||
'clarify_tool',
|
||||
'check_clarify_requirements',
|
||||
'CLARIFY_SCHEMA',
|
||||
"clarify_tool",
|
||||
"check_clarify_requirements",
|
||||
"CLARIFY_SCHEMA",
|
||||
# Code execution sandbox
|
||||
'execute_code',
|
||||
'check_sandbox_requirements',
|
||||
'EXECUTE_CODE_SCHEMA',
|
||||
"execute_code",
|
||||
"check_sandbox_requirements",
|
||||
"EXECUTE_CODE_SCHEMA",
|
||||
# Subagent delegation
|
||||
'delegate_task',
|
||||
'check_delegate_requirements',
|
||||
'DELEGATE_TASK_SCHEMA',
|
||||
"delegate_task",
|
||||
"check_delegate_requirements",
|
||||
"DELEGATE_TASK_SCHEMA",
|
||||
]
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,32 +20,32 @@ logger = logging.getLogger(__name__)
|
||||
# =========================================================================
|
||||
|
||||
DANGEROUS_PATTERNS = [
|
||||
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
|
||||
(r'\brm\s+-[^\s]*r', "recursive delete"),
|
||||
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
|
||||
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
|
||||
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
|
||||
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
|
||||
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
|
||||
(r'\bmkfs\b', "format filesystem"),
|
||||
(r'\bdd\s+.*if=', "disk copy"),
|
||||
(r'>\s*/dev/sd', "write to block device"),
|
||||
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
|
||||
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
|
||||
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
|
||||
(r'>\s*/etc/', "overwrite system config"),
|
||||
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
|
||||
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
|
||||
(r'\bpkill\s+-9\b', "force kill processes"),
|
||||
(r':()\s*{\s*:\s*\|\s*:&\s*}\s*;:', "fork bomb"),
|
||||
(r'\b(bash|sh|zsh)\s+-c\s+', "shell command via -c flag"),
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
|
||||
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
|
||||
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
|
||||
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
|
||||
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
|
||||
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
|
||||
(r'\bfind\b.*-delete\b', "find -delete"),
|
||||
(r"\brm\s+(-[^\s]*\s+)*/", "delete in root path"),
|
||||
(r"\brm\s+-[^\s]*r", "recursive delete"),
|
||||
(r"\brm\s+--recursive\b", "recursive delete (long flag)"),
|
||||
(r"\bchmod\s+(-[^\s]*\s+)*777\b", "world-writable permissions"),
|
||||
(r"\bchmod\s+--recursive\b.*777", "recursive world-writable (long flag)"),
|
||||
(r"\bchown\s+(-[^\s]*)?R\s+root", "recursive chown to root"),
|
||||
(r"\bchown\s+--recursive\b.*root", "recursive chown to root (long flag)"),
|
||||
(r"\bmkfs\b", "format filesystem"),
|
||||
(r"\bdd\s+.*if=", "disk copy"),
|
||||
(r">\s*/dev/sd", "write to block device"),
|
||||
(r"\bDROP\s+(TABLE|DATABASE)\b", "SQL DROP"),
|
||||
(r"\bDELETE\s+FROM\b(?!.*\bWHERE\b)", "SQL DELETE without WHERE"),
|
||||
(r"\bTRUNCATE\s+(TABLE)?\s*\w", "SQL TRUNCATE"),
|
||||
(r">\s*/etc/", "overwrite system config"),
|
||||
(r"\bsystemctl\s+(stop|disable|mask)\b", "stop/disable system service"),
|
||||
(r"\bkill\s+-9\s+-1\b", "kill all processes"),
|
||||
(r"\bpkill\s+-9\b", "force kill processes"),
|
||||
(r":()\s*{\s*:\s*\|\s*:&\s*}\s*;:", "fork bomb"),
|
||||
(r"\b(bash|sh|zsh)\s+-c\s+", "shell command via -c flag"),
|
||||
(r"\b(python[23]?|perl|ruby|node)\s+-[ec]\s+", "script execution via -e/-c flag"),
|
||||
(r"\b(curl|wget)\b.*\|\s*(ba)?sh\b", "pipe remote content to shell"),
|
||||
(r"\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b", "execute remote script via process substitution"),
|
||||
(r"\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)", "overwrite system file via tee"),
|
||||
(r"\bxargs\s+.*\brm\b", "xargs with rm"),
|
||||
(r"\bfind\b.*-exec\s+(/\S*/)?rm\b", "find -exec rm"),
|
||||
(r"\bfind\b.*-delete\b", "find -delete"),
|
||||
]
|
||||
|
||||
|
||||
@@ -54,6 +53,7 @@ DANGEROUS_PATTERNS = [
|
||||
# Detection
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def detect_dangerous_command(command: str) -> tuple:
|
||||
"""Check if a command matches any dangerous patterns.
|
||||
|
||||
@@ -63,7 +63,7 @@ def detect_dangerous_command(command: str) -> tuple:
|
||||
command_lower = command.lower()
|
||||
for pattern, description in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
|
||||
pattern_key = pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
|
||||
pattern_key = pattern.split(r"\b")[1] if r"\b" in pattern else pattern[:20]
|
||||
return (True, pattern_key, description)
|
||||
return (False, None, None)
|
||||
|
||||
@@ -84,7 +84,7 @@ def submit_pending(session_key: str, approval: dict):
|
||||
_pending[session_key] = approval
|
||||
|
||||
|
||||
def pop_pending(session_key: str) -> Optional[dict]:
|
||||
def pop_pending(session_key: str) -> dict | None:
|
||||
"""Retrieve and remove a pending approval for a session."""
|
||||
with _lock:
|
||||
return _pending.pop(session_key, None)
|
||||
@@ -133,6 +133,7 @@ def clear_session(session_key: str):
|
||||
# Config persistence for permanent allowlist
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def load_permanent_allowlist() -> set:
|
||||
"""Load permanently allowed command patterns from config.
|
||||
|
||||
@@ -141,6 +142,7 @@ def load_permanent_allowlist() -> set:
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
patterns = set(config.get("command_allowlist", []) or [])
|
||||
if patterns:
|
||||
@@ -154,6 +156,7 @@ def save_permanent_allowlist(patterns: set):
|
||||
"""Save permanently allowed command patterns to config."""
|
||||
try:
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
config = load_config()
|
||||
config["command_allowlist"] = list(patterns)
|
||||
save_config(config)
|
||||
@@ -165,9 +168,8 @@ def save_permanent_allowlist(patterns: set):
|
||||
# Approval prompting + orchestration
|
||||
# =========================================================================
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str,
|
||||
timeout_seconds: int = 60,
|
||||
approval_callback=None) -> str:
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str, timeout_seconds: int = 60, approval_callback=None) -> str:
|
||||
"""Prompt the user to approve a dangerous command (CLI only).
|
||||
|
||||
Args:
|
||||
@@ -188,7 +190,7 @@ def prompt_dangerous_approval(command: str, description: str,
|
||||
print(f" ⚠️ DANGEROUS COMMAND: {description}")
|
||||
print(f" {command[:80]}{'...' if len(command) > 80 else ''}")
|
||||
print()
|
||||
print(f" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print(" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
print()
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -209,13 +211,13 @@ def prompt_dangerous_approval(command: str, description: str,
|
||||
return "deny"
|
||||
|
||||
choice = result["choice"]
|
||||
if choice in ('o', 'once'):
|
||||
if choice in ("o", "once"):
|
||||
print(" ✓ Allowed once")
|
||||
return "once"
|
||||
elif choice in ('s', 'session'):
|
||||
elif choice in ("s", "session"):
|
||||
print(" ✓ Allowed for this session")
|
||||
return "session"
|
||||
elif choice in ('a', 'always'):
|
||||
elif choice in ("a", "always"):
|
||||
print(" ✓ Added to permanent allowlist")
|
||||
return "always"
|
||||
else:
|
||||
@@ -232,8 +234,7 @@ def prompt_dangerous_approval(command: str, description: str,
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def check_dangerous_command(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
def check_dangerous_command(command: str, env_type: str, approval_callback=None) -> dict:
|
||||
"""Check if a command is dangerous and handle approval.
|
||||
|
||||
This is the main entry point called by terminal_tool before executing
|
||||
@@ -265,11 +266,14 @@ def check_dangerous_command(command: str, env_type: str,
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
})
|
||||
submit_pending(
|
||||
session_key,
|
||||
{
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
return {
|
||||
"approved": False,
|
||||
"pattern_key": pattern_key,
|
||||
@@ -279,8 +283,7 @@ def check_dangerous_command(command: str, env_type: str,
|
||||
"message": f"⚠️ This command is potentially dangerous ({description}). Asking the user for approval...",
|
||||
}
|
||||
|
||||
choice = prompt_dangerous_approval(command, description,
|
||||
approval_callback=approval_callback)
|
||||
choice = prompt_dangerous_approval(command, description, approval_callback=approval_callback)
|
||||
|
||||
if choice == "deny":
|
||||
return {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,8 +12,7 @@ a thin dispatcher that delegates to a platform-provided callback.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
# Maximum number of predefined choices the agent can offer.
|
||||
# A 5th "Other (type your answer)" option is always appended by the UI.
|
||||
@@ -22,8 +21,8 @@ MAX_CHOICES = 4
|
||||
|
||||
def clarify_tool(
|
||||
question: str,
|
||||
choices: Optional[List[str]] = None,
|
||||
callback: Optional[Callable] = None,
|
||||
choices: list[str] | None = None,
|
||||
callback: Callable | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Ask the user a question, optionally with multiple-choice options.
|
||||
@@ -68,11 +67,14 @@ def clarify_tool(
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def check_clarify_requirements() -> bool:
|
||||
@@ -133,8 +135,7 @@ registry.register(
|
||||
toolset="clarify",
|
||||
schema=CLARIFY_SCHEMA,
|
||||
handler=lambda args, **kw: clarify_tool(
|
||||
question=args.get("question", ""),
|
||||
choices=args.get("choices"),
|
||||
callback=kw.get("callback")),
|
||||
question=args.get("question", ""), choices=args.get("choices"), callback=kw.get("callback")
|
||||
),
|
||||
check_fn=check_clarify_requirements,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
# Availability gate: UDS requires a POSIX OS
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,21 +40,23 @@ SANDBOX_AVAILABLE = sys.platform != "win32"
|
||||
|
||||
# The 7 tools allowed inside the sandbox. The intersection of this list
|
||||
# and the session's enabled tools determines which stubs are generated.
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset([
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
])
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset(
|
||||
[
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
]
|
||||
)
|
||||
|
||||
# Resource limit defaults (overridable via config.yaml → code_execution.*)
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_MAX_TOOL_CALLS = 50
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
|
||||
|
||||
def check_sandbox_requirements() -> bool:
|
||||
@@ -114,7 +116,7 @@ _TOOL_STUBS = {
|
||||
}
|
||||
|
||||
|
||||
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
||||
def generate_hermes_tools_module(enabled_tools: list[str]) -> str:
|
||||
"""
|
||||
Build the source code for the hermes_tools.py stub module.
|
||||
|
||||
@@ -128,11 +130,7 @@ def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
||||
if tool_name not in _TOOL_STUBS:
|
||||
continue
|
||||
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
|
||||
stub_functions.append(
|
||||
f"def {func_name}({sig}):\n"
|
||||
f" {doc}\n"
|
||||
f" return _call({func_name!r}, {args_expr})\n"
|
||||
)
|
||||
stub_functions.append(f"def {func_name}({sig}):\n {doc}\n return _call({func_name!r}, {args_expr})\n")
|
||||
export_names.append(func_name)
|
||||
|
||||
header = '''\
|
||||
@@ -223,7 +221,7 @@ def _rpc_server_loop(
|
||||
server_sock: socket.socket,
|
||||
task_id: str,
|
||||
tool_call_log: list,
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
max_tool_calls: int,
|
||||
allowed_tools: frozenset,
|
||||
):
|
||||
@@ -243,7 +241,7 @@ def _rpc_server_loop(
|
||||
while True:
|
||||
try:
|
||||
chunk = conn.recv(65536)
|
||||
except socket.timeout:
|
||||
except TimeoutError:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
@@ -270,23 +268,22 @@ def _rpc_server_loop(
|
||||
# Enforce the allow-list
|
||||
if tool_name not in allowed_tools:
|
||||
available = ", ".join(sorted(allowed_tools))
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool '{tool_name}' is not available in execute_code. "
|
||||
f"Available: {available}"
|
||||
)
|
||||
})
|
||||
resp = json.dumps(
|
||||
{"error": (f"Tool '{tool_name}' is not available in execute_code. Available: {available}")}
|
||||
)
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
# Enforce tool call limit
|
||||
if tool_call_counter[0] >= max_tool_calls:
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
})
|
||||
resp = json.dumps(
|
||||
{
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
}
|
||||
)
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
@@ -303,9 +300,7 @@ def _rpc_server_loop(
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
try:
|
||||
result = handle_function_call(
|
||||
tool_name, tool_args, task_id=task_id
|
||||
)
|
||||
result = handle_function_call(tool_name, tool_args, task_id=task_id)
|
||||
finally:
|
||||
sys.stdout.close()
|
||||
sys.stderr.close()
|
||||
@@ -318,15 +313,17 @@ def _rpc_server_loop(
|
||||
|
||||
# Log for observability
|
||||
args_preview = str(tool_args)[:80]
|
||||
tool_call_log.append({
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
})
|
||||
tool_call_log.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
}
|
||||
)
|
||||
|
||||
conn.sendall((result + "\n").encode())
|
||||
|
||||
except socket.timeout:
|
||||
except TimeoutError:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
@@ -342,10 +339,11 @@ def _rpc_server_loop(
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_code(
|
||||
code: str,
|
||||
task_id: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
task_id: str | None = None,
|
||||
enabled_tools: list[str] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Run a Python script in a sandboxed child process with RPC access
|
||||
@@ -361,9 +359,7 @@ def execute_code(
|
||||
JSON string with execution results.
|
||||
"""
|
||||
if not SANDBOX_AVAILABLE:
|
||||
return json.dumps({
|
||||
"error": "execute_code is not available on Windows. Use normal tool calls instead."
|
||||
})
|
||||
return json.dumps({"error": "execute_code is not available on Windows. Use normal tool calls instead."})
|
||||
|
||||
if not code or not code.strip():
|
||||
return json.dumps({"error": "No code provided."})
|
||||
@@ -397,9 +393,7 @@ def execute_code(
|
||||
|
||||
try:
|
||||
# Write the auto-generated hermes_tools module
|
||||
tools_src = generate_hermes_tools_module(
|
||||
list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS)
|
||||
)
|
||||
tools_src = generate_hermes_tools_module(list(sandbox_tools) if enabled_tools else list(SANDBOX_ALLOWED_TOOLS))
|
||||
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
|
||||
f.write(tools_src)
|
||||
|
||||
@@ -415,8 +409,12 @@ def execute_code(
|
||||
rpc_thread = threading.Thread(
|
||||
target=_rpc_server_loop,
|
||||
args=(
|
||||
server_sock, task_id, tool_call_log,
|
||||
tool_call_counter, max_tool_calls, sandbox_tools,
|
||||
server_sock,
|
||||
task_id,
|
||||
tool_call_log,
|
||||
tool_call_counter,
|
||||
max_tool_calls,
|
||||
sandbox_tools,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
@@ -426,11 +424,24 @@ def execute_code(
|
||||
# Build a minimal environment for the child. We intentionally exclude
|
||||
# API keys and tokens to prevent credential exfiltration from LLM-
|
||||
# generated scripts. The child accesses tools via RPC, not direct API.
|
||||
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
|
||||
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
|
||||
"PASSWD", "AUTH")
|
||||
_SAFE_ENV_PREFIXES = (
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_",
|
||||
"TERM",
|
||||
"TMPDIR",
|
||||
"TMP",
|
||||
"TEMP",
|
||||
"SHELL",
|
||||
"LOGNAME",
|
||||
"XDG_",
|
||||
"PYTHONPATH",
|
||||
"VIRTUAL_ENV",
|
||||
"CONDA",
|
||||
)
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL", "PASSWD", "AUTH")
|
||||
child_env = {}
|
||||
for k, v in os.environ.items():
|
||||
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
|
||||
@@ -515,7 +526,7 @@ def execute_code(
|
||||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Build response
|
||||
result: Dict[str, Any] = {
|
||||
result: dict[str, Any] = {
|
||||
"status": status,
|
||||
"output": stdout_text,
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
@@ -538,17 +549,21 @@ def execute_code(
|
||||
except Exception as exc:
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
logging.exception("execute_code failed")
|
||||
return json.dumps({
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup temp dir and socket
|
||||
try:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
logger.debug("Could not clean temp dir: %s", e)
|
||||
@@ -592,6 +607,7 @@ def _load_config() -> dict:
|
||||
"""Load code_execution config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("code_execution", {})
|
||||
except Exception:
|
||||
return {}
|
||||
@@ -604,27 +620,37 @@ def _load_config() -> dict:
|
||||
# Per-tool documentation lines for the execute_code description.
|
||||
# Ordered to match the canonical display order.
|
||||
_TOOL_DOC_LINES = [
|
||||
("web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
|
||||
("web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
|
||||
("read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
|
||||
("write_file",
|
||||
" write_file(path: str, content: str) -> dict\n"
|
||||
" Always overwrites the entire file."),
|
||||
("search_files",
|
||||
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
|
||||
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
|
||||
("patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file."),
|
||||
("terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
|
||||
(
|
||||
"web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
' Returns {"data": {"web": [{"url", "title", "description"}, ...]}}',
|
||||
),
|
||||
(
|
||||
"web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
' Returns {"results": [{"url", "title", "content", "error"}, ...]} where content is markdown',
|
||||
),
|
||||
(
|
||||
"read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
' Lines are 1-indexed. Returns {"content": "...", "total_lines": N}',
|
||||
),
|
||||
("write_file", " write_file(path: str, content: str) -> dict\n Always overwrites the entire file."),
|
||||
(
|
||||
"search_files",
|
||||
' search_files(pattern: str, target="content", path=".", file_glob=None, limit=50) -> dict\n'
|
||||
' target: "content" (search inside files) or "files" (find files by name). Returns {"matches": [...]}',
|
||||
),
|
||||
(
|
||||
"patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file.",
|
||||
),
|
||||
(
|
||||
"terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
' Foreground only (no background/pty). Returns {"output": "...", "exit_code": N}',
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -639,9 +665,7 @@ def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
|
||||
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
|
||||
|
||||
# Build tool documentation lines for only the enabled tools
|
||||
tool_lines = "\n".join(
|
||||
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
|
||||
)
|
||||
tool_lines = "\n".join(doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools)
|
||||
|
||||
# Build example import list from enabled tools
|
||||
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
|
||||
@@ -702,8 +726,7 @@ registry.register(
|
||||
toolset="code_execution",
|
||||
schema=EXECUTE_CODE_SCHEMA,
|
||||
handler=lambda args, **kw: execute_code(
|
||||
code=args.get("code", ""),
|
||||
task_id=kw.get("task_id"),
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
code=args.get("code", ""), task_id=kw.get("task_id"), enabled_tools=kw.get("enabled_tools")
|
||||
),
|
||||
check_fn=check_sandbox_requirements,
|
||||
)
|
||||
|
||||
@@ -11,37 +11,44 @@ The prompt must contain ALL necessary information.
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
# Import from cron module (will be available when properly installed)
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import create_job, get_job, list_jobs, remove_job
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cron prompt scanning — critical-severity patterns only, since cron prompts
|
||||
# run in fresh sessions with full tool access.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CRON_THREAT_PATTERNS = [
|
||||
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'/etc/sudoers|visudo', "sudoers_mod"),
|
||||
(r'rm\s+-rf\s+/', "destructive_root_rm"),
|
||||
(r"ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions", "prompt_injection"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)", "read_secrets"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"/etc/sudoers|visudo", "sudoers_mod"),
|
||||
(r"rm\s+-rf\s+/", "destructive_root_rm"),
|
||||
]
|
||||
|
||||
_CRON_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
@@ -60,17 +67,18 @@ def _scan_cron_prompt(prompt: str) -> str:
|
||||
# Tool: schedule_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def schedule_cronjob(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
task_id: str = None
|
||||
name: str | None = None,
|
||||
repeat: int | None = None,
|
||||
deliver: str | None = None,
|
||||
task_id: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Schedule an automated task to run the agent on a schedule.
|
||||
|
||||
|
||||
IMPORTANT: When the cronjob runs, it starts a COMPLETELY FRESH session.
|
||||
The agent will have NO memory of this conversation or any prior context.
|
||||
Therefore, the prompt MUST contain ALL necessary information:
|
||||
@@ -78,12 +86,12 @@ def schedule_cronjob(
|
||||
- Specific file paths, URLs, or identifiers
|
||||
- Clear success criteria
|
||||
- Any relevant background information
|
||||
|
||||
|
||||
BAD prompt: "Check on that server issue"
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
GOOD prompt: "SSH into server 192.168.1.100 as user 'deploy', check if nginx
|
||||
is running with 'systemctl status nginx', and verify the site
|
||||
https://example.com returns HTTP 200. Report any issues found."
|
||||
|
||||
|
||||
Args:
|
||||
prompt: Complete, self-contained instructions for the future agent.
|
||||
Must include ALL context needed - the agent won't remember anything.
|
||||
@@ -105,7 +113,7 @@ def schedule_cronjob(
|
||||
- "signal": Send to Signal home channel
|
||||
- "telegram:123456": Send to specific chat ID
|
||||
- "signal:+15551234567": Send to specific Signal number
|
||||
|
||||
|
||||
Returns:
|
||||
JSON with job_id, next_run time, and confirmation
|
||||
"""
|
||||
@@ -124,17 +132,10 @@ def schedule_cronjob(
|
||||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
job = create_job(
|
||||
prompt=prompt,
|
||||
schedule=schedule,
|
||||
name=name,
|
||||
repeat=repeat,
|
||||
deliver=deliver,
|
||||
origin=origin
|
||||
)
|
||||
|
||||
job = create_job(prompt=prompt, schedule=schedule, name=name, repeat=repeat, deliver=deliver, origin=origin)
|
||||
|
||||
# Format repeat info for display
|
||||
times = job["repeat"].get("times")
|
||||
if times is None:
|
||||
@@ -143,23 +144,23 @@ def schedule_cronjob(
|
||||
repeat_display = "once"
|
||||
else:
|
||||
repeat_display = f"{times} times"
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}."
|
||||
}, indent=2)
|
||||
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_display,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"message": f"Cronjob '{job['name']}' created. It will run {repeat_display}, deliver to {job.get('deliver', 'local')}, next at {job['next_run_at']}.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
SCHEDULE_CRONJOB_SCHEMA = {
|
||||
@@ -177,7 +178,7 @@ The future agent will NOT remember anything from the current conversation.
|
||||
|
||||
SCHEDULE FORMATS:
|
||||
- One-shot: "30m", "2h", "1d" (runs once after delay)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Interval: "every 30m", "every 2h" (recurring)
|
||||
- Cron: "0 9 * * *" (cron expression for precise scheduling)
|
||||
- Timestamp: "2026-02-03T14:00:00" (specific date/time)
|
||||
|
||||
@@ -202,27 +203,24 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation."
|
||||
"description": "Complete, self-contained instructions. Must include ALL context - the future agent will have NO memory of this conversation.",
|
||||
},
|
||||
"schedule": {
|
||||
"type": "string",
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-friendly name for the job"
|
||||
"description": "When to run: '30m' (once in 30min), 'every 30m' (recurring), '0 9 * * *' (cron), or ISO timestamp",
|
||||
},
|
||||
"name": {"type": "string", "description": "Optional human-friendly name for the job"},
|
||||
"repeat": {
|
||||
"type": "integer",
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs."
|
||||
"description": "How many times to run. Omit for default (once for one-shot, forever for recurring). Set to N for exactly N runs.",
|
||||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'"
|
||||
}
|
||||
"description": "Where to send output: 'origin' (back to this chat), 'local' (files only), 'telegram', 'discord', 'signal', or 'platform:chat_id'",
|
||||
},
|
||||
},
|
||||
"required": ["prompt", "schedule"]
|
||||
}
|
||||
"required": ["prompt", "schedule"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -230,10 +228,11 @@ Use for: reminders, periodic checks, scheduled reports, automated maintenance.""
|
||||
# Tool: list_cronjobs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
"""
|
||||
List all scheduled cronjobs.
|
||||
|
||||
|
||||
Returns information about each job including:
|
||||
- Job ID (needed for removal)
|
||||
- Name
|
||||
@@ -241,16 +240,16 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
- Repeat status (completed/total or 'forever')
|
||||
- Next scheduled run time
|
||||
- Last run time and status (if any)
|
||||
|
||||
|
||||
Args:
|
||||
include_disabled: Whether to include disabled/completed jobs
|
||||
|
||||
|
||||
Returns:
|
||||
JSON array of all scheduled jobs
|
||||
"""
|
||||
try:
|
||||
jobs = list_jobs(include_disabled=include_disabled)
|
||||
|
||||
|
||||
formatted_jobs = []
|
||||
for job in jobs:
|
||||
# Format repeat status
|
||||
@@ -260,31 +259,26 @@ def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
repeat_status = "forever"
|
||||
else:
|
||||
repeat_status = f"{completed}/{times}"
|
||||
|
||||
formatted_jobs.append({
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True)
|
||||
})
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"count": len(formatted_jobs),
|
||||
"jobs": formatted_jobs
|
||||
}, indent=2)
|
||||
|
||||
|
||||
formatted_jobs.append(
|
||||
{
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"prompt_preview": job["prompt"][:100] + "..." if len(job["prompt"]) > 100 else job["prompt"],
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": repeat_status,
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True),
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps({"success": True, "count": len(formatted_jobs), "jobs": formatted_jobs}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
LIST_CRONJOBS_SCHEMA = {
|
||||
@@ -302,11 +296,11 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
||||
"properties": {
|
||||
"include_disabled": {
|
||||
"type": "boolean",
|
||||
"description": "Include disabled/completed jobs in the list (default: false)"
|
||||
"description": "Include disabled/completed jobs in the list (default: false)",
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -314,48 +308,45 @@ Returns job_id, name, schedule, repeat status, next/last run times.""",
|
||||
# Tool: remove_cronjob
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def remove_cronjob(job_id: str, task_id: str = None) -> str:
|
||||
"""
|
||||
Remove a scheduled cronjob by its ID.
|
||||
|
||||
|
||||
Use list_cronjobs first to find the job_id of the job you want to remove.
|
||||
|
||||
|
||||
Args:
|
||||
job_id: The ID of the job to remove (from list_cronjobs output)
|
||||
|
||||
|
||||
Returns:
|
||||
JSON confirmation of removal
|
||||
"""
|
||||
try:
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs."
|
||||
}, indent=2)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Job with ID '{job_id}' not found. Use list_cronjobs to see available jobs.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
removed = remove_job(job_id)
|
||||
if removed:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {
|
||||
"id": job_id,
|
||||
"name": job["name"],
|
||||
"schedule": job["schedule_display"]
|
||||
}
|
||||
}, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Cronjob '{job['name']}' (ID: {job_id}) has been removed.",
|
||||
"removed_job": {"id": job_id, "name": job["name"], "schedule": job["schedule_display"]},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Failed to remove job '{job_id}'"
|
||||
}, indent=2)
|
||||
|
||||
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, indent=2)
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
REMOVE_CRONJOB_SCHEMA = {
|
||||
@@ -368,13 +359,10 @@ use this to cancel a job before it completes.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the cronjob to remove (from list_cronjobs output)"
|
||||
}
|
||||
"job_id": {"type": "string", "description": "The ID of the cronjob to remove (from list_cronjobs output)"}
|
||||
},
|
||||
"required": ["job_id"]
|
||||
}
|
||||
"required": ["job_id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -382,44 +370,34 @@ use this to cancel a job before it completes.""",
|
||||
# Requirements check
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def check_cronjob_requirements() -> bool:
|
||||
"""
|
||||
Check if cronjob tools can be used.
|
||||
|
||||
|
||||
Available in interactive CLI mode and gateway/messaging platforms.
|
||||
Cronjobs are server-side scheduled tasks so they work from any interface.
|
||||
"""
|
||||
return bool(
|
||||
os.getenv("HERMES_INTERACTIVE")
|
||||
or os.getenv("HERMES_GATEWAY_SESSION")
|
||||
or os.getenv("HERMES_EXEC_ASK")
|
||||
)
|
||||
return bool(os.getenv("HERMES_INTERACTIVE") or os.getenv("HERMES_GATEWAY_SESSION") or os.getenv("HERMES_EXEC_ASK"))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_cronjob_tool_definitions():
|
||||
"""Return tool definitions for cronjob management."""
|
||||
return [
|
||||
SCHEDULE_CRONJOB_SCHEMA,
|
||||
LIST_CRONJOBS_SCHEMA,
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
]
|
||||
return [SCHEDULE_CRONJOB_SCHEMA, LIST_CRONJOBS_SCHEMA, REMOVE_CRONJOB_SCHEMA]
|
||||
|
||||
|
||||
# For direct testing
|
||||
if __name__ == "__main__":
|
||||
# Test the tools
|
||||
print("Testing schedule_cronjob:")
|
||||
result = schedule_cronjob(
|
||||
prompt="Test prompt for cron job",
|
||||
schedule="5m",
|
||||
name="Test Job"
|
||||
)
|
||||
result = schedule_cronjob(prompt="Test prompt for cron job", schedule="5m", name="Test Job")
|
||||
print(result)
|
||||
|
||||
|
||||
print("\nTesting list_cronjobs:")
|
||||
result = list_cronjobs()
|
||||
print(result)
|
||||
@@ -438,7 +416,8 @@ registry.register(
|
||||
name=args.get("name"),
|
||||
repeat=args.get("repeat"),
|
||||
deliver=args.get("deliver"),
|
||||
task_id=kw.get("task_id")),
|
||||
task_id=kw.get("task_id"),
|
||||
),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
@@ -446,16 +425,14 @@ registry.register(
|
||||
toolset="cronjob",
|
||||
schema=LIST_CRONJOBS_SCHEMA,
|
||||
handler=lambda args, **kw: list_cronjobs(
|
||||
include_disabled=args.get("include_disabled", False),
|
||||
task_id=kw.get("task_id")),
|
||||
include_disabled=args.get("include_disabled", False), task_id=kw.get("task_id")
|
||||
),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
registry.register(
|
||||
name="remove_cronjob",
|
||||
toolset="cronjob",
|
||||
schema=REMOVE_CRONJOB_SCHEMA,
|
||||
handler=lambda args, **kw: remove_cronjob(
|
||||
job_id=args.get("job_id", ""),
|
||||
task_id=kw.get("task_id")),
|
||||
handler=lambda args, **kw: remove_cronjob(job_id=args.get("job_id", ""), task_id=kw.get("task_id")),
|
||||
check_fn=check_cronjob_requirements,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,27 +44,28 @@ class DebugSession:
|
||||
self.enabled = os.getenv(env_var, "false").lower() == "true"
|
||||
self.session_id = str(uuid.uuid4()) if self.enabled else ""
|
||||
self.log_dir = Path("./logs")
|
||||
self._calls: list[Dict[str, Any]] = []
|
||||
self._calls: list[dict[str, Any]] = []
|
||||
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
|
||||
|
||||
if self.enabled:
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s",
|
||||
tool_name, self.session_id)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s", tool_name, self.session_id)
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
|
||||
def log_call(self, call_name: str, call_data: dict[str, Any]) -> None:
|
||||
"""Append a tool-call entry to the in-memory log."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self._calls.append({
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
})
|
||||
self._calls.append(
|
||||
{
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
}
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
"""Flush the in-memory log to a JSON file in the logs directory."""
|
||||
@@ -87,7 +88,7 @@ class DebugSession:
|
||||
except Exception as e:
|
||||
logger.error("Error saving %s debug log: %s", self.tool_name, e)
|
||||
|
||||
def get_session_info(self) -> Dict[str, Any]:
|
||||
def get_session_info(self) -> dict[str, Any]:
|
||||
"""Return a summary dict suitable for returning from get_debug_session_info()."""
|
||||
if not self.enabled:
|
||||
return {
|
||||
|
||||
@@ -20,21 +20,22 @@ import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Tools that children must never have access to
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset([
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
])
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset(
|
||||
[
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
]
|
||||
)
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = 3
|
||||
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
|
||||
@@ -47,7 +48,7 @@ def check_delegate_requirements() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
||||
def _build_child_system_prompt(goal: str, context: str | None = None) -> str:
|
||||
"""Build a focused system prompt for a child agent."""
|
||||
parts = [
|
||||
"You are a focused subagent working on a specific delegated task.",
|
||||
@@ -69,15 +70,18 @@ def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
|
||||
def _strip_blocked_tools(toolsets: list[str]) -> list[str]:
|
||||
"""Remove toolsets that contain only blocked tools."""
|
||||
blocked_toolset_names = {
|
||||
"delegation", "clarify", "memory", "code_execution",
|
||||
"delegation",
|
||||
"clarify",
|
||||
"memory",
|
||||
"code_execution",
|
||||
}
|
||||
return [t for t in toolsets if t not in blocked_toolset_names]
|
||||
|
||||
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Callable | None:
|
||||
"""Build a callback that relays child agent tool calls to the parent display.
|
||||
|
||||
Two display paths:
|
||||
@@ -87,8 +91,8 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
Returns None if no display mechanism is available, in which case the
|
||||
child agent runs with no progress callback (identical to current behavior).
|
||||
"""
|
||||
spinner = getattr(parent_agent, '_delegate_spinner', None)
|
||||
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
|
||||
spinner = getattr(parent_agent, "_delegate_spinner", None)
|
||||
parent_cb = getattr(parent_agent, "tool_progress_callback", None)
|
||||
|
||||
if not spinner and not parent_cb:
|
||||
return None # No display → no callback → zero behavior change
|
||||
@@ -98,7 +102,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
|
||||
# Gateway: batch tool names, flush periodically
|
||||
_BATCH_SIZE = 5
|
||||
_batch: List[str] = []
|
||||
_batch: list[str] = []
|
||||
|
||||
def _callback(tool_name: str, preview: str = None):
|
||||
# Special "_thinking" event: model produced text content (reasoning)
|
||||
@@ -106,7 +110,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
if spinner:
|
||||
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
|
||||
try:
|
||||
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
|
||||
spinner.print_above(f' {prefix}├─ 💭 "{short}"')
|
||||
except Exception:
|
||||
pass
|
||||
# Don't relay thinking to gateway (too noisy for chat)
|
||||
@@ -116,17 +120,25 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
if spinner:
|
||||
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
|
||||
tool_emojis = {
|
||||
"terminal": "💻", "web_search": "🔍", "web_extract": "📄",
|
||||
"read_file": "📖", "write_file": "✍️", "patch": "🔧",
|
||||
"search_files": "🔎", "list_directory": "📂",
|
||||
"browser_navigate": "🌐", "browser_click": "👆",
|
||||
"text_to_speech": "🔊", "image_generate": "🎨",
|
||||
"vision_analyze": "👁️", "process": "⚙️",
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search_files": "🔎",
|
||||
"list_directory": "📂",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"text_to_speech": "🔊",
|
||||
"image_generate": "🎨",
|
||||
"vision_analyze": "👁️",
|
||||
"process": "⚙️",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚡")
|
||||
line = f" {prefix}├─ {emoji} {tool_name}"
|
||||
if short:
|
||||
line += f" \"{short}\""
|
||||
line += f' "{short}"'
|
||||
try:
|
||||
spinner.print_above(line)
|
||||
except Exception:
|
||||
@@ -159,13 +171,13 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
||||
def _run_single_child(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
context: Optional[str],
|
||||
toolsets: Optional[List[str]],
|
||||
model: Optional[str],
|
||||
context: str | None,
|
||||
toolsets: list[str] | None,
|
||||
model: str | None,
|
||||
max_iterations: int,
|
||||
parent_agent,
|
||||
task_count: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Spawn and run a single child agent. Called from within a thread.
|
||||
Returns a structured result dict.
|
||||
@@ -216,7 +228,7 @@ def _run_single_child(
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
session_db=getattr(parent_agent, "_session_db", None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
@@ -226,10 +238,10 @@ def _run_single_child(
|
||||
)
|
||||
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
child._delegate_depth = getattr(parent_agent, "_delegate_depth", 0) + 1
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
parent_agent._active_children.append(child)
|
||||
|
||||
# Run with stdout/stderr suppressed to prevent interleaved output
|
||||
@@ -238,7 +250,7 @@ def _run_single_child(
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
# Flush any remaining batched progress to gateway
|
||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||
if child_progress_cb and hasattr(child_progress_cb, "_flush"):
|
||||
try:
|
||||
child_progress_cb._flush()
|
||||
except Exception:
|
||||
@@ -258,7 +270,7 @@ def _run_single_child(
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
entry: Dict[str, Any] = {
|
||||
entry: dict[str, Any] = {
|
||||
"task_index": task_index,
|
||||
"status": status,
|
||||
"summary": summary,
|
||||
@@ -284,7 +296,7 @@ def _run_single_child(
|
||||
|
||||
finally:
|
||||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
if hasattr(parent_agent, "_active_children"):
|
||||
try:
|
||||
parent_agent._active_children.remove(child)
|
||||
except (ValueError, UnboundLocalError):
|
||||
@@ -292,11 +304,11 @@ def _run_single_child(
|
||||
|
||||
|
||||
def delegate_task(
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
toolsets: Optional[List[str]] = None,
|
||||
tasks: Optional[List[Dict[str, Any]]] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
goal: str | None = None,
|
||||
context: str | None = None,
|
||||
toolsets: list[str] | None = None,
|
||||
tasks: list[dict[str, Any]] | None = None,
|
||||
max_iterations: int | None = None,
|
||||
parent_agent=None,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -312,14 +324,11 @@ def delegate_task(
|
||||
return json.dumps({"error": "delegate_task requires a parent agent context."})
|
||||
|
||||
# Depth limit
|
||||
depth = getattr(parent_agent, '_delegate_depth', 0)
|
||||
depth = getattr(parent_agent, "_delegate_depth", 0)
|
||||
if depth >= MAX_DEPTH:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Delegation depth limit reached ({MAX_DEPTH}). "
|
||||
"Subagents cannot spawn further subagents."
|
||||
)
|
||||
})
|
||||
return json.dumps(
|
||||
{"error": (f"Delegation depth limit reached ({MAX_DEPTH}). Subagents cannot spawn further subagents.")}
|
||||
)
|
||||
|
||||
# Load config
|
||||
cfg = _load_config()
|
||||
@@ -366,7 +375,7 @@ def delegate_task(
|
||||
else:
|
||||
# Batch -- run in parallel with per-task progress lines
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
spinner_ref = getattr(parent_agent, "_delegate_spinner", None)
|
||||
|
||||
# Save stdout/stderr before the executor — redirect_stdout in child
|
||||
# threads races on sys.stdout and can leave it as devnull permanently.
|
||||
@@ -412,7 +421,7 @@ def delegate_task(
|
||||
status = entry.get("status", "?")
|
||||
icon = "✓" if status == "completed" else "✗"
|
||||
remaining = n_tasks - completed_count
|
||||
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
|
||||
completion_line = f"{icon} [{idx + 1}/{n_tasks}] {label} ({dur}s)"
|
||||
if spinner_ref:
|
||||
try:
|
||||
spinner_ref.print_above(completion_line)
|
||||
@@ -437,16 +446,20 @@ def delegate_task(
|
||||
|
||||
total_duration = round(time.monotonic() - overall_start, 2)
|
||||
|
||||
return json.dumps({
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load delegation config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
|
||||
return CLI_CONFIG.get("delegation", {})
|
||||
except Exception:
|
||||
return {}
|
||||
@@ -537,10 +550,7 @@ DELEGATE_TASK_SCHEMA = {
|
||||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max tool-calling turns per subagent (default: 50). "
|
||||
"Only set lower for simple tasks."
|
||||
),
|
||||
"description": ("Max tool-calling turns per subagent (default: 50). Only set lower for simple tasks."),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
@@ -561,6 +571,7 @@ registry.register(
|
||||
toolsets=args.get("toolsets"),
|
||||
tasks=args.get("tasks"),
|
||||
max_iterations=args.get("max_iterations"),
|
||||
parent_agent=kw.get("parent_agent")),
|
||||
parent_agent=kw.get("parent_agent"),
|
||||
),
|
||||
check_fn=check_delegate_requirements,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Base class for all Hermes execution environment backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -34,9 +34,9 @@ class BaseEnvironment(ABC):
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
|
||||
@@ -62,10 +62,10 @@ class BaseEnvironment(ABC):
|
||||
def _prepare_command(self, command: str) -> str:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def _build_run_kwargs(self, timeout: int | None, stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
|
||||
@@ -11,7 +11,6 @@ import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -32,8 +31,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
@@ -41,8 +40,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
Daytona,
|
||||
CreateSandboxFromImageParams,
|
||||
Daytona,
|
||||
DaytonaError,
|
||||
Resources,
|
||||
SandboxState,
|
||||
@@ -73,13 +72,11 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
try:
|
||||
self._sandbox = self._daytona.find_one(labels=labels)
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
logger.info("Daytona: resumed sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
except DaytonaError:
|
||||
self._sandbox = None
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
|
||||
task_id, e)
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s", task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
@@ -92,8 +89,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
resources=resources,
|
||||
)
|
||||
)
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
logger.info("Daytona: created sandbox %s for task %s", self._sandbox.id, task_id)
|
||||
|
||||
# Resolve cwd: detect actual home dir inside the sandbox
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
@@ -112,7 +108,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
def _exec_in_thread(self, exec_command: str, cwd: str | None, timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
@@ -130,7 +126,8 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command, cwd=cwd,
|
||||
timed_command,
|
||||
cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
@@ -169,9 +166,9 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
|
||||
@@ -189,6 +186,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
@@ -210,8 +208,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
try:
|
||||
if self._persistent:
|
||||
self._sandbox.stop()
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
|
||||
self._sandbox.id)
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)", self._sandbox.id)
|
||||
else:
|
||||
self._daytona.delete(self._sandbox)
|
||||
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)
|
||||
|
||||
@@ -11,7 +11,6 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -19,7 +18,6 @@ from tools.interrupt import is_interrupted
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
# Security flags applied to every container.
|
||||
# The container itself is the security boundary (isolated from host).
|
||||
# We drop all capabilities then add back the minimum needed:
|
||||
@@ -28,19 +26,28 @@ logger = logging.getLogger(__name__)
|
||||
# Block privilege escalation and limit PIDs.
|
||||
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
|
||||
_SECURITY_ARGS = [
|
||||
"--cap-drop", "ALL",
|
||||
"--cap-add", "DAC_OVERRIDE",
|
||||
"--cap-add", "CHOWN",
|
||||
"--cap-add", "FOWNER",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--pids-limit", "256",
|
||||
"--tmpfs", "/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
|
||||
"--cap-drop",
|
||||
"ALL",
|
||||
"--cap-add",
|
||||
"DAC_OVERRIDE",
|
||||
"--cap-add",
|
||||
"CHOWN",
|
||||
"--cap-add",
|
||||
"FOWNER",
|
||||
"--security-opt",
|
||||
"no-new-privileges",
|
||||
"--pids-limit",
|
||||
"256",
|
||||
"--tmpfs",
|
||||
"/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs",
|
||||
"/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs",
|
||||
"/run:rw,noexec,nosuid,size=64m",
|
||||
]
|
||||
|
||||
|
||||
_storage_opt_ok: Optional[bool] = None # cached result across instances
|
||||
_storage_opt_ok: bool | None = None # cached result across instances
|
||||
|
||||
|
||||
class DockerEnvironment(BaseEnvironment):
|
||||
@@ -74,7 +81,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
self._base_image = image
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._container_id: Optional[str] = None
|
||||
self._container_id: str | None = None
|
||||
logger.info(f"DockerEnvironment volumes: {volumes}")
|
||||
# Ensure volumes is a list (config.yaml could be malformed)
|
||||
if volumes is not None and not isinstance(volumes, list):
|
||||
@@ -105,8 +112,8 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
self._workspace_dir: Optional[str] = None
|
||||
self._home_dir: Optional[str] = None
|
||||
self._workspace_dir: str | None = None
|
||||
self._home_dir: str | None = None
|
||||
if self._persistent:
|
||||
sandbox = get_sandbox_dir() / "docker" / task_id
|
||||
self._workspace_dir = str(sandbox / "workspace")
|
||||
@@ -114,14 +121,19 @@ class DockerEnvironment(BaseEnvironment):
|
||||
os.makedirs(self._workspace_dir, exist_ok=True)
|
||||
os.makedirs(self._home_dir, exist_ok=True)
|
||||
writable_args = [
|
||||
"-v", f"{self._workspace_dir}:/workspace",
|
||||
"-v", f"{self._home_dir}:/root",
|
||||
"-v",
|
||||
f"{self._workspace_dir}:/workspace",
|
||||
"-v",
|
||||
f"{self._home_dir}:/root",
|
||||
]
|
||||
else:
|
||||
writable_args = [
|
||||
"--tmpfs", "/workspace:rw,exec,size=10g",
|
||||
"--tmpfs", "/home:rw,exec,size=1g",
|
||||
"--tmpfs", "/root:rw,exec,size=1g",
|
||||
"--tmpfs",
|
||||
"/workspace:rw,exec,size=10g",
|
||||
"--tmpfs",
|
||||
"/home:rw,exec,size=1g",
|
||||
"--tmpfs",
|
||||
"/root:rw,exec,size=1g",
|
||||
]
|
||||
|
||||
# All containers get security hardening (capabilities dropped, no privilege
|
||||
@@ -129,7 +141,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# can install packages as needed.
|
||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||
volume_args = []
|
||||
for vol in (volumes or []):
|
||||
for vol in volumes or []:
|
||||
if not isinstance(vol, str):
|
||||
logger.warning(f"Docker volume entry is not a string: {vol!r}")
|
||||
continue
|
||||
@@ -146,7 +158,9 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.info(f"Docker run_args: {all_run_args}")
|
||||
|
||||
self._inner = _Docker(
|
||||
image=image, cwd=cwd, timeout=timeout,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
run_args=all_run_args,
|
||||
)
|
||||
self._container_id = self._inner.container_id
|
||||
@@ -154,7 +168,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
|
||||
|
||||
Only overlay2 on XFS with pquota supports per-container disk quotas.
|
||||
Ubuntu (and most distros) default to ext4, where this flag errors out.
|
||||
"""
|
||||
@@ -164,7 +178,9 @@ class DockerEnvironment(BaseEnvironment):
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info", "--format", "{{.Driver}}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
driver = result.stdout.strip().lower()
|
||||
if driver != "overlay2":
|
||||
@@ -174,14 +190,15 @@ class DockerEnvironment(BaseEnvironment):
|
||||
# Probe by attempting a dry-ish run — the fastest reliable check.
|
||||
probe = subprocess.run(
|
||||
["docker", "create", "--storage-opt", "size=1m", "hello-world"],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if probe.returncode == 0:
|
||||
# Clean up the created container
|
||||
container_id = probe.stdout.strip()
|
||||
if container_id:
|
||||
subprocess.run(["docker", "rm", container_id],
|
||||
capture_output=True, timeout=5)
|
||||
subprocess.run(["docker", "rm", container_id], capture_output=True, timeout=5)
|
||||
_storage_opt_ok = True
|
||||
else:
|
||||
_storage_opt_ok = False
|
||||
@@ -190,9 +207,9 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
exec_command = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
@@ -218,7 +235,8 @@ class DockerEnvironment(BaseEnvironment):
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
@@ -269,6 +287,7 @@ class DockerEnvironment(BaseEnvironment):
|
||||
|
||||
if not self._persistent:
|
||||
import shutil
|
||||
|
||||
for d in (self._workspace_dir, self._home_dir):
|
||||
if d:
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
||||
@@ -154,9 +154,9 @@ class LocalEnvironment(BaseEnvironment):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
@@ -172,11 +172,7 @@ class LocalEnvironment(BaseEnvironment):
|
||||
# Wrap with output fences so we can later extract the real
|
||||
# command output and discard shell init/exit noise.
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
f"printf '{_OUTPUT_FENCE}'; {exec_command}; __hermes_rc=$?; printf '{_OUTPUT_FENCE}'; exit $__hermes_rc"
|
||||
)
|
||||
# Ensure PATH always includes standard dirs — systemd services
|
||||
# and some terminal multiplexers inherit a minimal PATH.
|
||||
@@ -200,12 +196,14 @@ class LocalEnvironment(BaseEnvironment):
|
||||
)
|
||||
|
||||
if stdin_data is not None:
|
||||
|
||||
def _write_stdin():
|
||||
try:
|
||||
proc.stdin.write(stdin_data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
@@ -8,10 +8,9 @@ project files, and config changes survive across sessions.
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -21,7 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "modal_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
@@ -31,7 +30,7 @@ def _load_snapshots() -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
@@ -52,7 +51,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
image: str,
|
||||
cwd: str = "~",
|
||||
timeout: int = 60,
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
modal_sandbox_kwargs: dict[str, Any] | None = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
@@ -61,6 +60,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if not ModalEnvironment._patches_applied:
|
||||
try:
|
||||
from environments.patches import apply_patches
|
||||
|
||||
apply_patches()
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -79,6 +79,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if snapshot_id:
|
||||
try:
|
||||
import modal
|
||||
|
||||
restored_image = modal.Image.from_id(snapshot_id)
|
||||
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
|
||||
except Exception as e:
|
||||
@@ -88,6 +89,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
effective_image = restored_image if restored_image else image
|
||||
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
|
||||
self._inner = SwerexModalEnvironment(
|
||||
image=effective_image,
|
||||
cwd=cwd,
|
||||
@@ -97,9 +99,9 @@ class ModalEnvironment(BaseEnvironment):
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
)
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
@@ -139,29 +141,29 @@ class ModalEnvironment(BaseEnvironment):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._persistent:
|
||||
try:
|
||||
sandbox = getattr(self._inner, 'deployment', None)
|
||||
sandbox = getattr(sandbox, '_sandbox', None) if sandbox else None
|
||||
sandbox = getattr(self._inner, "deployment", None)
|
||||
sandbox = getattr(sandbox, "_sandbox", None) if sandbox else None
|
||||
if sandbox:
|
||||
import asyncio
|
||||
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
|
||||
try:
|
||||
snapshot_id = asyncio.run(_snapshot())
|
||||
except RuntimeError:
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
snapshot_id = pool.submit(
|
||||
asyncio.run, _snapshot()
|
||||
).result(timeout=60)
|
||||
snapshot_id = pool.submit(asyncio.run, _snapshot()).result(timeout=60)
|
||||
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s", snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
if hasattr(self._inner, 'stop'):
|
||||
if hasattr(self._inner, "stop"):
|
||||
self._inner.stop()
|
||||
|
||||
@@ -10,11 +10,9 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
@@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
_SNAPSHOT_STORE = Path.home() / ".hermes" / "singularity_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
def _load_snapshots() -> dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
@@ -33,7 +31,7 @@ def _load_snapshots() -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
def _save_snapshots(data: dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
@@ -42,6 +40,7 @@ def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
@@ -58,6 +57,7 @@ def _get_scratch_dir() -> Path:
|
||||
return scratch_path
|
||||
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
sandbox = get_sandbox_dir() / "singularity"
|
||||
|
||||
scratch = Path("/scratch")
|
||||
@@ -93,12 +93,12 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
if image.endswith(".sif") and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith('docker://'):
|
||||
if not image.startswith("docker://"):
|
||||
return image
|
||||
|
||||
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
|
||||
image_name = image.replace("docker://", "").replace("/", "-").replace(":", "-")
|
||||
cache_dir = _get_apptainer_cache_dir()
|
||||
sif_path = cache_dir / f"{image_name}.sif"
|
||||
|
||||
@@ -123,7 +123,10 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[executable, "build", str(sif_path), image],
|
||||
capture_output=True, text=True, timeout=600, env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("SIF build failed, falling back to docker:// URL")
|
||||
@@ -145,6 +148,7 @@ def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
@@ -174,7 +178,7 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
self._instance_started = False
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
self._overlay_dir: Path | None = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
@@ -215,14 +219,13 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
logger.info("Singularity instance %s started (persistent=%s)", self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
|
||||
@@ -235,16 +238,16 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
exec_command = f"cd {work_dir} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir, f"instance://{self.instance_id}", "bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
@@ -295,7 +298,9 @@ class SingularityEnvironment(BaseEnvironment):
|
||||
try:
|
||||
subprocess.run(
|
||||
[self.executable, "instance", "stop", self.instance_id],
|
||||
capture_output=True, text=True, timeout=30,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
logger.info("Singularity instance %s stopped", self.instance_id)
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,8 +24,7 @@ class SSHEnvironment(BaseEnvironment):
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
def __init__(self, host: str, user: str, cwd: str = "~", timeout: int = 60, port: int = 22, key_path: str = ""):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
@@ -65,12 +64,12 @@ class SSHEnvironment(BaseEnvironment):
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
def execute(
|
||||
self, command: str, cwd: str = "", *, timeout: int | None = None, stdin_data: str | None = None
|
||||
) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command = self._prepare_command(command)
|
||||
wrapped = f'cd {work_dir} && {exec_command}'
|
||||
wrapped = f"cd {work_dir} && {exec_command}"
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
@@ -136,8 +135,7 @@ class SSHEnvironment(BaseEnvironment):
|
||||
def cleanup(self):
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
"-O", "exit", f"{self.user}@{self.host}"]
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}", "-O", "exit", f"{self.user}@{self.host}"]
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,11 +3,10 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
from agent.redact import redact_sensitive_text
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,14 +24,19 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
Thread-safe: uses the same per-task creation locks as terminal_tool to
|
||||
prevent duplicate sandbox creation from concurrent tool calls.
|
||||
"""
|
||||
from tools.terminal_tool import (
|
||||
_active_environments, _env_lock, _create_environment,
|
||||
_get_env_config, _last_activity, _start_cleanup_thread,
|
||||
_check_disk_usage_warning,
|
||||
_creation_locks, _creation_locks_lock,
|
||||
)
|
||||
import time
|
||||
|
||||
from tools.terminal_tool import (
|
||||
_active_environments,
|
||||
_create_environment,
|
||||
_creation_locks,
|
||||
_creation_locks_lock,
|
||||
_env_lock,
|
||||
_get_env_config,
|
||||
_last_activity,
|
||||
_start_cleanup_thread,
|
||||
)
|
||||
|
||||
# Fast path: check cache -- but also verify the underlying environment
|
||||
# is still alive (it may have been killed by the cleanup thread).
|
||||
with _file_ops_lock:
|
||||
@@ -143,17 +147,23 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
print(f"[FileTools] write_file error: {type(e).__name__}: {e}", flush=True)
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
new_string: str = None, replace_all: bool = False, patch: str = None,
|
||||
task_id: str = "default") -> str:
|
||||
def patch_tool(
|
||||
mode: str = "replace",
|
||||
path: str = None,
|
||||
old_string: str = None,
|
||||
new_string: str = None,
|
||||
replace_all: bool = False,
|
||||
patch: str = None,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
"""Patch a file using replace mode or V4A patch format."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
|
||||
|
||||
if mode == "replace":
|
||||
if not path:
|
||||
return json.dumps({"error": "path required"})
|
||||
@@ -166,7 +176,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when old_string not found — saves iterations where the agent
|
||||
@@ -178,20 +188,33 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
file_glob: str = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0,
|
||||
task_id: str = "default") -> str:
|
||||
def search_tool(
|
||||
pattern: str,
|
||||
target: str = "content",
|
||||
path: str = ".",
|
||||
file_glob: str = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
output_mode: str = "content",
|
||||
context: int = 0,
|
||||
task_id: str = "default",
|
||||
) -> str:
|
||||
"""Search for content or files."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.search(
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
pattern=pattern,
|
||||
path=path,
|
||||
target=target,
|
||||
file_glob=file_glob,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
output_mode=output_mode,
|
||||
context=context,
|
||||
)
|
||||
if hasattr(result, 'matches'):
|
||||
if hasattr(result, "matches"):
|
||||
for m in result.matches:
|
||||
if hasattr(m, 'content') and m.content:
|
||||
if hasattr(m, "content") and m.content:
|
||||
m.content = redact_sensitive_text(m.content)
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
@@ -209,7 +232,7 @@ FILE_TOOLS = [
|
||||
{"name": "read_file", "function": read_file_tool},
|
||||
{"name": "write_file", "function": write_file_tool},
|
||||
{"name": "patch", "function": patch_tool},
|
||||
{"name": "search_files", "function": search_tool}
|
||||
{"name": "search_files", "function": search_tool},
|
||||
]
|
||||
|
||||
|
||||
@@ -227,8 +250,10 @@ from tools.registry import registry
|
||||
def _check_file_reqs():
|
||||
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
|
||||
from tools import check_file_requirements
|
||||
|
||||
return check_file_requirements()
|
||||
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
@@ -236,11 +261,21 @@ READ_FILE_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
|
||||
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
|
||||
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default: 1)",
|
||||
"default": 1,
|
||||
"minimum": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default: 500, max: 2000)",
|
||||
"default": 500,
|
||||
"maximum": 2000,
|
||||
},
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
"required": ["path"],
|
||||
},
|
||||
}
|
||||
|
||||
WRITE_FILE_SCHEMA = {
|
||||
@@ -249,11 +284,14 @@ WRITE_FILE_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"}
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)",
|
||||
},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"},
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
"required": ["path", "content"],
|
||||
},
|
||||
}
|
||||
|
||||
PATCH_SCHEMA = {
|
||||
@@ -262,15 +300,33 @@ PATCH_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "patch"],
|
||||
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
|
||||
"default": "replace",
|
||||
},
|
||||
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
|
||||
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
|
||||
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
|
||||
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
|
||||
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness.",
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text.",
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
|
||||
"default": False,
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch",
|
||||
},
|
||||
},
|
||||
"required": ["mode"]
|
||||
}
|
||||
"required": ["mode"],
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_FILES_SCHEMA = {
|
||||
@@ -279,23 +335,57 @@ SEARCH_FILES_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
|
||||
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
|
||||
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
|
||||
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
|
||||
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
|
||||
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
|
||||
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
|
||||
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files"],
|
||||
"description": "'content' searches inside file contents, 'files' searches for files by name",
|
||||
"default": "content",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default: current working directory)",
|
||||
"default": ".",
|
||||
},
|
||||
"file_glob": {
|
||||
"type": "string",
|
||||
"description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 50)",
|
||||
"default": 50,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Skip first N results for pagination (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_only", "count"],
|
||||
"description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
|
||||
"default": "content",
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"description": "Number of context lines before and after each match (grep mode only)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
"required": ["pattern"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_read_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
|
||||
return read_file_tool(
|
||||
path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid
|
||||
)
|
||||
|
||||
|
||||
def _handle_write_file(args, **kw):
|
||||
@@ -306,9 +396,14 @@ def _handle_write_file(args, **kw):
|
||||
def _handle_patch(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return patch_tool(
|
||||
mode=args.get("mode", "replace"), path=args.get("path"),
|
||||
old_string=args.get("old_string"), new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
|
||||
mode=args.get("mode", "replace"),
|
||||
path=args.get("path"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False),
|
||||
patch=args.get("patch"),
|
||||
task_id=tid,
|
||||
)
|
||||
|
||||
|
||||
def _handle_search_files(args, **kw):
|
||||
@@ -317,12 +412,29 @@ def _handle_search_files(args, **kw):
|
||||
raw_target = args.get("target", "content")
|
||||
target = target_map.get(raw_target, raw_target)
|
||||
return search_tool(
|
||||
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
|
||||
pattern=args.get("pattern", ""),
|
||||
target=target,
|
||||
path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"),
|
||||
limit=args.get("limit", 50),
|
||||
offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"),
|
||||
context=args.get("context", 0),
|
||||
task_id=tid,
|
||||
)
|
||||
|
||||
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs)
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs)
|
||||
registry.register(
|
||||
name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(
|
||||
name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs
|
||||
)
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs)
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs)
|
||||
registry.register(
|
||||
name="search_files",
|
||||
toolset="file",
|
||||
schema=SEARCH_FILES_SCHEMA,
|
||||
handler=_handle_search_files,
|
||||
check_fn=_check_file_reqs,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ The 9-strategy chain (inspired by OpenCode):
|
||||
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
@@ -29,21 +29,22 @@ Usage:
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from collections.abc import Callable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
def fuzzy_find_and_replace(
|
||||
content: str, old_string: str, new_string: str, replace_all: bool = False
|
||||
) -> tuple[str, int, str | None]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
@@ -51,12 +52,12 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
strategies: list[tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
@@ -66,46 +67,50 @@ def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return content, 0, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
return (
|
||||
content,
|
||||
0,
|
||||
(
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
def _apply_replacements(content: str, matches: list[tuple[int, int]], new_string: str) -> str:
|
||||
"""
|
||||
Apply replacements at the given positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content
|
||||
matches: List of (start, end) positions to replace
|
||||
new_string: Replacement text
|
||||
|
||||
|
||||
Returns:
|
||||
Content with replacements applied
|
||||
"""
|
||||
# Sort matches by position (descending) to replace from end to start
|
||||
# This preserves positions of earlier matches
|
||||
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
|
||||
|
||||
|
||||
result = content
|
||||
for start, end in sorted_matches:
|
||||
result = result[:start] + new_string + result[end:]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -113,7 +118,8 @@ def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string
|
||||
# Matching Strategies
|
||||
# =============================================================================
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""Strategy 1: Exact string match."""
|
||||
matches = []
|
||||
start = 0
|
||||
@@ -126,206 +132,201 @@ def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 2: Match with line-by-line whitespace trimming.
|
||||
|
||||
|
||||
Strips leading/trailing whitespace from each line before matching.
|
||||
"""
|
||||
# Normalize pattern and content by trimming each line
|
||||
pattern_lines = [line.strip() for line in pattern.split('\n')]
|
||||
pattern_normalized = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
pattern_lines = [line.strip() for line in pattern.split("\n")]
|
||||
pattern_normalized = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
content_normalized_lines = [line.strip() for line in content_lines]
|
||||
|
||||
|
||||
# Build mapping from normalized positions back to original positions
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_normalized_lines,
|
||||
pattern, pattern_normalized
|
||||
)
|
||||
return _find_normalized_matches(content, content_lines, content_normalized_lines, pattern, pattern_normalized)
|
||||
|
||||
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 3: Collapse multiple whitespace to single space.
|
||||
"""
|
||||
|
||||
def normalize(s):
|
||||
# Collapse multiple spaces/tabs to single space, preserve newlines
|
||||
return re.sub(r'[ \t]+', ' ', s)
|
||||
|
||||
return re.sub(r"[ \t]+", " ", s)
|
||||
|
||||
pattern_normalized = normalize(pattern)
|
||||
content_normalized = normalize(content)
|
||||
|
||||
|
||||
# Find in normalized, map back to original
|
||||
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
|
||||
|
||||
|
||||
if not matches_in_normalized:
|
||||
return []
|
||||
|
||||
|
||||
# Map positions back to original content
|
||||
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
|
||||
|
||||
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 4: Ignore indentation differences entirely.
|
||||
|
||||
|
||||
Strips all leading whitespace from lines before matching.
|
||||
"""
|
||||
|
||||
def strip_indent(s):
|
||||
return '\n'.join(line.lstrip() for line in s.split('\n'))
|
||||
|
||||
return "\n".join(line.lstrip() for line in s.split("\n"))
|
||||
|
||||
pattern_stripped = strip_indent(pattern)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
content_lines = content.split("\n")
|
||||
content_stripped_lines = [line.lstrip() for line in content_lines]
|
||||
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
|
||||
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_stripped_lines,
|
||||
pattern, '\n'.join(pattern_lines)
|
||||
)
|
||||
pattern_lines = [line.lstrip() for line in pattern.split("\n")]
|
||||
|
||||
return _find_normalized_matches(content, content_lines, content_stripped_lines, pattern, "\n".join(pattern_lines))
|
||||
|
||||
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 5: Convert escape sequences to actual characters.
|
||||
|
||||
|
||||
Handles \\n -> newline, \\t -> tab, etc.
|
||||
"""
|
||||
|
||||
def unescape(s):
|
||||
# Convert common escape sequences
|
||||
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
|
||||
|
||||
return s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
|
||||
|
||||
pattern_unescaped = unescape(pattern)
|
||||
|
||||
|
||||
if pattern_unescaped == pattern:
|
||||
# No escapes to convert, skip this strategy
|
||||
return []
|
||||
|
||||
|
||||
return _strategy_exact(content, pattern_unescaped)
|
||||
|
||||
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 6: Trim whitespace from first and last lines only.
|
||||
|
||||
|
||||
Useful when the pattern boundaries have whitespace differences.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
pattern_lines = pattern.split("\n")
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
# Trim only first and last lines
|
||||
pattern_lines[0] = pattern_lines[0].strip()
|
||||
if len(pattern_lines) > 1:
|
||||
pattern_lines[-1] = pattern_lines[-1].strip()
|
||||
|
||||
modified_pattern = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
|
||||
modified_pattern = "\n".join(pattern_lines)
|
||||
|
||||
content_lines = content.split("\n")
|
||||
|
||||
# Search through content for matching block
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
# Trim first and last of this block
|
||||
check_lines = block_lines.copy()
|
||||
check_lines[0] = check_lines[0].strip()
|
||||
if len(check_lines) > 1:
|
||||
check_lines[-1] = check_lines[-1].strip()
|
||||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
|
||||
if "\n".join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
|
||||
|
||||
If first and last lines match exactly, accept middle with 70% similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
pattern_lines = pattern.split("\n")
|
||||
if len(pattern_lines) < 2:
|
||||
return [] # Need at least 2 lines for anchoring
|
||||
|
||||
|
||||
first_line = pattern_lines[0].strip()
|
||||
last_line = pattern_lines[-1].strip()
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
content_lines = content.split("\n")
|
||||
matches = []
|
||||
|
||||
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
# Check if first and last lines match
|
||||
if (content_lines[i].strip() == first_line and
|
||||
content_lines[i + pattern_line_count - 1].strip() == last_line):
|
||||
|
||||
if content_lines[i].strip() == first_line and content_lines[i + pattern_line_count - 1].strip() == last_line:
|
||||
# Check middle similarity
|
||||
if pattern_line_count <= 2:
|
||||
# Only first and last, they match
|
||||
similarity = 1.0
|
||||
else:
|
||||
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
|
||||
pattern_middle = '\n'.join(pattern_lines[1:-1])
|
||||
content_middle = "\n".join(content_lines[i + 1 : i + pattern_line_count - 1])
|
||||
pattern_middle = "\n".join(pattern_lines[1:-1])
|
||||
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
|
||||
|
||||
|
||||
if similarity >= 0.70:
|
||||
# Calculate positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
def _strategy_context_aware(content: str, pattern: str) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
content_lines = content.split('\n')
|
||||
|
||||
pattern_lines = pattern.split("\n")
|
||||
content_lines = content.split("\n")
|
||||
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
block_lines = content_lines[i : i + pattern_line_count]
|
||||
|
||||
# Calculate line-by-line similarity
|
||||
high_similarity_count = 0
|
||||
for p_line, c_line in zip(pattern_lines, block_lines):
|
||||
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
|
||||
if sim >= 0.80:
|
||||
high_similarity_count += 1
|
||||
|
||||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
@@ -333,74 +334,76 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def _find_normalized_matches(content: str, content_lines: List[str],
|
||||
content_normalized_lines: List[str],
|
||||
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
|
||||
|
||||
def _find_normalized_matches(
|
||||
content: str, content_lines: list[str], content_normalized_lines: list[str], pattern: str, pattern_normalized: str
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Find matches in normalized content and map back to original positions.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original content string
|
||||
content_lines: Original content split by lines
|
||||
content_normalized_lines: Normalized content lines
|
||||
pattern: Original pattern
|
||||
pattern_normalized: Normalized pattern
|
||||
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions in the original content
|
||||
"""
|
||||
pattern_norm_lines = pattern_normalized.split('\n')
|
||||
pattern_norm_lines = pattern_normalized.split("\n")
|
||||
num_pattern_lines = len(pattern_norm_lines)
|
||||
|
||||
|
||||
matches = []
|
||||
|
||||
|
||||
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
|
||||
# Check if this block matches
|
||||
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
|
||||
|
||||
block = "\n".join(content_normalized_lines[i : i + num_pattern_lines])
|
||||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
|
||||
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[: i + num_pattern_lines]) - 1
|
||||
|
||||
# Handle case where end is past content
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
|
||||
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _map_normalized_positions(original: str, normalized: str,
|
||||
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
def _map_normalized_positions(
|
||||
original: str, normalized: str, normalized_matches: list[tuple[int, int]]
|
||||
) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Map positions from normalized string back to original.
|
||||
|
||||
|
||||
This is a best-effort mapping that works for whitespace normalization.
|
||||
"""
|
||||
if not normalized_matches:
|
||||
return []
|
||||
|
||||
|
||||
# Build character mapping from normalized to original
|
||||
orig_to_norm = [] # orig_to_norm[i] = position in normalized
|
||||
|
||||
|
||||
orig_idx = 0
|
||||
norm_idx = 0
|
||||
|
||||
|
||||
while orig_idx < len(original) and norm_idx < len(normalized):
|
||||
if original[orig_idx] == normalized[norm_idx]:
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
|
||||
elif original[orig_idx] in " \t" and normalized[norm_idx] == " ":
|
||||
# Original has space/tab, normalized collapsed to space
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
# Don't advance norm_idx yet - wait until all whitespace consumed
|
||||
if orig_idx < len(original) and original[orig_idx] not in ' \t':
|
||||
if orig_idx < len(original) and original[orig_idx] not in " \t":
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t':
|
||||
elif original[orig_idx] in " \t":
|
||||
# Extra whitespace in original
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
@@ -408,21 +411,21 @@ def _map_normalized_positions(original: str, normalized: str,
|
||||
# Mismatch - shouldn't happen with our normalization
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Fill remaining
|
||||
while orig_idx < len(original):
|
||||
orig_to_norm.append(len(normalized))
|
||||
orig_idx += 1
|
||||
|
||||
|
||||
# Reverse mapping: for each normalized position, find original range
|
||||
norm_to_orig_start = {}
|
||||
norm_to_orig_end = {}
|
||||
|
||||
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
norm_to_orig_end[norm_pos] = orig_pos
|
||||
|
||||
|
||||
# Map matches
|
||||
original_matches = []
|
||||
for norm_start, norm_end in normalized_matches:
|
||||
@@ -432,17 +435,17 @@ def _map_normalized_positions(original: str, normalized: str,
|
||||
else:
|
||||
# Find nearest
|
||||
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
|
||||
|
||||
|
||||
# Find original end
|
||||
if norm_end - 1 in norm_to_orig_end:
|
||||
orig_end = norm_to_orig_end[norm_end - 1] + 1
|
||||
else:
|
||||
orig_end = orig_start + (norm_end - norm_start)
|
||||
|
||||
|
||||
# Expand to include trailing whitespace that was normalized
|
||||
while orig_end < len(original) and original[orig_end] in ' \t':
|
||||
while orig_end < len(original) and original[orig_end] in " \t":
|
||||
orig_end += 1
|
||||
|
||||
|
||||
original_matches.append((orig_start, min(orig_end, len(original))))
|
||||
|
||||
|
||||
return original_matches
|
||||
|
||||
@@ -15,7 +15,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,23 +35,26 @@ def _get_config():
|
||||
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
|
||||
)
|
||||
|
||||
|
||||
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
|
||||
|
||||
# Service domains blocked for security -- these allow arbitrary code/command
|
||||
# execution on the HA host or enable SSRF attacks on the local network.
|
||||
# HA provides zero service-level access control; all safety must be in our layer.
|
||||
_BLOCKED_DOMAINS = frozenset({
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
})
|
||||
_BLOCKED_DOMAINS = frozenset(
|
||||
{
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_headers(token: str = "") -> Dict[str, str]:
|
||||
def _get_headers(token: str = "") -> dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
if not token:
|
||||
_, token = _get_config()
|
||||
@@ -65,11 +68,12 @@ def _get_headers(token: str = "") -> Dict[str, str]:
|
||||
# Async helpers (called from sync handlers via run_until_complete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _filter_and_summarize(
|
||||
states: list,
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Filter raw HA states by domain/area and return a compact summary."""
|
||||
if domain:
|
||||
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
|
||||
@@ -77,26 +81,29 @@ def _filter_and_summarize(
|
||||
if area:
|
||||
area_lower = area.lower()
|
||||
states = [
|
||||
s for s in states
|
||||
s
|
||||
for s in states
|
||||
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
|
||||
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
|
||||
]
|
||||
|
||||
entities = []
|
||||
for s in states:
|
||||
entities.append({
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
})
|
||||
entities.append(
|
||||
{
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"count": len(entities), "entities": entities}
|
||||
|
||||
|
||||
async def _async_list_entities(
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
domain: str | None = None,
|
||||
area: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch entity states from HA and optionally filter by domain/area."""
|
||||
import aiohttp
|
||||
|
||||
@@ -110,7 +117,7 @@ async def _async_list_entities(
|
||||
return _filter_and_summarize(states, domain, area)
|
||||
|
||||
|
||||
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
async def _async_get_state(entity_id: str) -> dict[str, Any]:
|
||||
"""Fetch detailed state of a single entity."""
|
||||
import aiohttp
|
||||
|
||||
@@ -131,11 +138,11 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _build_service_payload(
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the JSON payload for a HA service call."""
|
||||
payload: Dict[str, Any] = {}
|
||||
payload: dict[str, Any] = {}
|
||||
if data:
|
||||
payload.update(data)
|
||||
# entity_id parameter takes precedence over data["entity_id"]
|
||||
@@ -148,15 +155,17 @@ def _parse_service_response(
|
||||
domain: str,
|
||||
service: str,
|
||||
result: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Parse HA service call response into a structured result."""
|
||||
affected = []
|
||||
if isinstance(result, list):
|
||||
for s in result:
|
||||
affected.append({
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
})
|
||||
affected.append(
|
||||
{
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -168,9 +177,9 @@ def _parse_service_response(
|
||||
async def _async_call_service(
|
||||
domain: str,
|
||||
service: str,
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
entity_id: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Call a Home Assistant service."""
|
||||
import aiohttp
|
||||
|
||||
@@ -178,15 +187,17 @@ async def _async_call_service(
|
||||
url = f"{hass_url}/api/services/{domain}/{service}"
|
||||
payload = _build_service_payload(entity_id, data)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
url,
|
||||
headers=_get_headers(hass_token),
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
) as resp,
|
||||
):
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
|
||||
return _parse_service_response(domain, service, result)
|
||||
|
||||
@@ -195,6 +206,7 @@ async def _async_call_service(
|
||||
# Sync wrappers (handler signature: (args, **kw) -> str)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync handler."""
|
||||
try:
|
||||
@@ -205,6 +217,7 @@ def _run_async(coro):
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop -- create a new thread
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=30)
|
||||
@@ -247,10 +260,12 @@ def _handle_call_service(args: dict, **kw) -> str:
|
||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
if domain in _BLOCKED_DOMAINS:
|
||||
return json.dumps({
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
}
|
||||
)
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||
@@ -269,7 +284,8 @@ def _handle_call_service(args: dict, **kw) -> str:
|
||||
# List services
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def _async_list_services(domain: str | None = None) -> dict[str, Any]:
|
||||
"""Fetch available services from HA and optionally filter by domain."""
|
||||
import aiohttp
|
||||
|
||||
@@ -290,13 +306,10 @@ async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
||||
d = svc_domain.get("domain", "")
|
||||
domain_services = {}
|
||||
for svc_name, svc_info in svc_domain.get("services", {}).items():
|
||||
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
svc_entry: dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
fields = svc_info.get("fields", {})
|
||||
if fields:
|
||||
svc_entry["fields"] = {
|
||||
k: v.get("description", "") for k, v in fields.items()
|
||||
if isinstance(v, dict)
|
||||
}
|
||||
svc_entry["fields"] = {k: v.get("description", "") for k, v in fields.items() if isinstance(v, dict)}
|
||||
domain_services[svc_name] = svc_entry
|
||||
result.append({"domain": d, "services": domain_services})
|
||||
|
||||
@@ -318,6 +331,7 @@ def _handle_list_services(args: dict, **kw) -> str:
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ha_available() -> bool:
|
||||
"""Tool is only available when HASS_TOKEN is set."""
|
||||
return bool(os.getenv("HASS_TOKEN"))
|
||||
@@ -369,8 +383,7 @@ HA_GET_STATE_SCHEMA = {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The entity ID to query (e.g. 'light.living_room', "
|
||||
"'climate.thermostat', 'sensor.temperature')."
|
||||
"The entity ID to query (e.g. 'light.living_room', 'climate.thermostat', 'sensor.temperature')."
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -392,8 +405,7 @@ HA_LIST_SERVICES_SCHEMA = {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
|
||||
"Omit to list services for all domains."
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). Omit to list services for all domains."
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -428,8 +440,7 @@ HA_CALL_SERVICE_SCHEMA = {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Target entity ID (e.g. 'light.living_room'). "
|
||||
"Some services (like scene.turn_on) may not need this."
|
||||
"Target entity ID (e.g. 'light.living_room'). Some services (like scene.turn_on) may not need this."
|
||||
),
|
||||
},
|
||||
"data": {
|
||||
|
||||
@@ -65,6 +65,7 @@ HONCHO_TOOL_SCHEMA = {
|
||||
|
||||
# ── Tool handler ──
|
||||
|
||||
|
||||
def _handle_query_user_context(args: dict, **kw) -> str:
|
||||
"""Execute the Honcho context query."""
|
||||
query = args.get("query", "")
|
||||
@@ -84,6 +85,7 @@ def _handle_query_user_context(args: dict, **kw) -> str:
|
||||
|
||||
# ── Availability check ──
|
||||
|
||||
|
||||
def _check_honcho_available() -> bool:
|
||||
"""Tool is only available when Honcho is active."""
|
||||
return _session_manager is not None and _session_key is not None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
Image Generation Tools Module
|
||||
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
||||
|
||||
Available tools:
|
||||
@@ -19,7 +19,7 @@ Features:
|
||||
Usage:
|
||||
from image_generation_tool import image_generate_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Generate and automatically upscale an image
|
||||
result = await image_generate_tool(
|
||||
prompt="A serene mountain landscape with cherry blossoms",
|
||||
@@ -28,12 +28,14 @@ Usage:
|
||||
)
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import fal_client
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -51,11 +53,7 @@ ENABLE_SAFETY_CHECKER = False
|
||||
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
|
||||
|
||||
# Aspect ratio mapping - simplified choices for model to select
|
||||
ASPECT_RATIO_MAP = {
|
||||
"landscape": "landscape_16_9",
|
||||
"square": "square_hd",
|
||||
"portrait": "portrait_16_9"
|
||||
}
|
||||
ASPECT_RATIO_MAP = {"landscape": "landscape_16_9", "square": "square_hd", "portrait": "portrait_16_9"}
|
||||
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
|
||||
|
||||
# Configuration for automatic upscaling
|
||||
@@ -70,9 +68,7 @@ UPSCALER_GUIDANCE_SCALE = 4
|
||||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||
|
||||
# Valid parameter values for validation based on FLUX 2 Pro documentation
|
||||
VALID_IMAGE_SIZES = [
|
||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
||||
]
|
||||
VALID_IMAGE_SIZES = ["square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"]
|
||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
||||
|
||||
@@ -80,16 +76,16 @@ _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _validate_parameters(
|
||||
image_size: Union[str, Dict[str, int]],
|
||||
image_size: str | dict[str, int],
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_images: int,
|
||||
output_format: str,
|
||||
acceleration: str = "none"
|
||||
) -> Dict[str, Any]:
|
||||
acceleration: str = "none",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize image generation parameters for FLUX 2 Pro model.
|
||||
|
||||
|
||||
Args:
|
||||
image_size: Either a preset string or custom size dict
|
||||
num_inference_steps: Number of inference steps
|
||||
@@ -97,15 +93,15 @@ def _validate_parameters(
|
||||
num_images: Number of images to generate
|
||||
output_format: Output format for images
|
||||
acceleration: Acceleration mode for generation speed
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Validated and normalized parameters
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If any parameter is invalid
|
||||
"""
|
||||
validated = {}
|
||||
|
||||
|
||||
# Validate image_size
|
||||
if isinstance(image_size, str):
|
||||
if image_size not in VALID_IMAGE_SIZES:
|
||||
@@ -123,52 +119,52 @@ def _validate_parameters(
|
||||
validated["image_size"] = image_size
|
||||
else:
|
||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
||||
|
||||
|
||||
# Validate num_inference_steps
|
||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
||||
validated["num_inference_steps"] = num_inference_steps
|
||||
|
||||
|
||||
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
|
||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
||||
validated["guidance_scale"] = float(guidance_scale)
|
||||
|
||||
|
||||
# Validate num_images
|
||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
||||
raise ValueError("num_images must be an integer between 1 and 4")
|
||||
validated["num_images"] = num_images
|
||||
|
||||
|
||||
# Validate output_format
|
||||
if output_format not in VALID_OUTPUT_FORMATS:
|
||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
||||
validated["output_format"] = output_format
|
||||
|
||||
|
||||
# Validate acceleration
|
||||
if acceleration not in VALID_ACCELERATION_MODES:
|
||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
||||
validated["acceleration"] = acceleration
|
||||
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> dict[str, Any]:
|
||||
"""
|
||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues
|
||||
when called from threaded contexts (e.g. gateway thread pool).
|
||||
|
||||
|
||||
Args:
|
||||
image_url (str): URL of the image to upscale
|
||||
original_prompt (str): Original prompt used to generate the image
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Upscaling image with Clarity Upscaler...")
|
||||
|
||||
|
||||
# Prepare arguments for upscaler
|
||||
upscaler_arguments = {
|
||||
"image_url": image_url,
|
||||
@@ -179,35 +175,36 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
"resemblance": UPSCALER_RESEMBLANCE,
|
||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER,
|
||||
}
|
||||
|
||||
|
||||
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
|
||||
# The async API (submit_async) caches a global httpx.AsyncClient via
|
||||
# @cached_property, which breaks when asyncio.run() destroys the loop
|
||||
# between calls (gateway thread-pool pattern).
|
||||
handler = fal_client.submit(
|
||||
UPSCALER_MODEL,
|
||||
arguments=upscaler_arguments
|
||||
)
|
||||
|
||||
handler = fal_client.submit(UPSCALER_MODEL, arguments=upscaler_arguments)
|
||||
|
||||
# Get the upscaled result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
if result and "image" in result:
|
||||
upscaled_image = result["image"]
|
||||
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
|
||||
logger.info(
|
||||
"Image upscaled successfully to %sx%s",
|
||||
upscaled_image.get("width", "unknown"),
|
||||
upscaled_image.get("height", "unknown"),
|
||||
)
|
||||
return {
|
||||
"url": upscaled_image["url"],
|
||||
"width": upscaled_image.get("width", 0),
|
||||
"height": upscaled_image.get("height", 0),
|
||||
"upscaled": True,
|
||||
"upscale_factor": UPSCALER_FACTOR
|
||||
"upscale_factor": UPSCALER_FACTOR,
|
||||
}
|
||||
else:
|
||||
logger.error("Upscaler returned invalid response")
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error upscaling image: %s", e)
|
||||
return None
|
||||
@@ -220,16 +217,16 @@ def image_generate_tool(
|
||||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
||||
num_images: int = DEFAULT_NUM_IMAGES,
|
||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
|
||||
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
|
||||
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
|
||||
when asyncio.run() destroys and recreates event loops between calls, which
|
||||
happens in the gateway's thread-pool pattern.
|
||||
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the desired image
|
||||
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
|
||||
@@ -238,7 +235,7 @@ def image_generate_tool(
|
||||
num_images (int): Number of images to generate (1-4, default: 1)
|
||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing minimal generation results:
|
||||
{
|
||||
@@ -252,7 +249,7 @@ def image_generate_tool(
|
||||
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
|
||||
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
|
||||
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"prompt": prompt,
|
||||
@@ -262,32 +259,32 @@ def image_generate_tool(
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_images": num_images,
|
||||
"output_format": output_format,
|
||||
"seed": seed
|
||||
"seed": seed,
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"images_generated": 0,
|
||||
"generation_time": 0
|
||||
"generation_time": 0,
|
||||
}
|
||||
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
|
||||
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
|
||||
|
||||
# Validate other parameters
|
||||
validated_params = _validate_parameters(
|
||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
|
||||
)
|
||||
|
||||
|
||||
# Prepare arguments for FAL.ai FLUX 2 Pro API
|
||||
arguments = {
|
||||
"prompt": prompt.strip(),
|
||||
@@ -298,51 +295,44 @@ def image_generate_tool(
|
||||
"output_format": validated_params["output_format"],
|
||||
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
|
||||
"safety_tolerance": SAFETY_TOLERANCE,
|
||||
"sync_mode": True # Use sync mode for immediate results
|
||||
"sync_mode": True, # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
|
||||
# Add seed if provided
|
||||
if seed is not None and isinstance(seed, int):
|
||||
arguments["seed"] = seed
|
||||
|
||||
|
||||
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
|
||||
logger.info(" Model: %s", DEFAULT_MODEL)
|
||||
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
|
||||
logger.info(" Steps: %s", validated_params['num_inference_steps'])
|
||||
logger.info(" Guidance: %s", validated_params['guidance_scale'])
|
||||
|
||||
logger.info(" Steps: %s", validated_params["num_inference_steps"])
|
||||
logger.info(" Guidance: %s", validated_params["guidance_scale"])
|
||||
|
||||
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
|
||||
handler = fal_client.submit(
|
||||
DEFAULT_MODEL,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
handler = fal_client.submit(DEFAULT_MODEL, arguments=arguments)
|
||||
|
||||
# Get the result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
|
||||
|
||||
# Process the response
|
||||
if not result or "images" not in result:
|
||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
||||
|
||||
|
||||
images = result.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("No images were generated")
|
||||
|
||||
|
||||
# Format image data and upscale images
|
||||
formatted_images = []
|
||||
for img in images:
|
||||
if isinstance(img, dict) and "url" in img:
|
||||
original_image = {
|
||||
"url": img["url"],
|
||||
"width": img.get("width", 0),
|
||||
"height": img.get("height", 0)
|
||||
}
|
||||
|
||||
original_image = {"url": img["url"], "width": img.get("width", 0), "height": img.get("height", 0)}
|
||||
|
||||
# Attempt to upscale the image
|
||||
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
||||
|
||||
|
||||
if upscaled_image:
|
||||
# Use upscaled image if successful
|
||||
formatted_images.append(upscaled_image)
|
||||
@@ -351,52 +341,48 @@ def image_generate_tool(
|
||||
logger.warning("Using original image as fallback")
|
||||
original_image["upscaled"] = False
|
||||
formatted_images.append(original_image)
|
||||
|
||||
|
||||
if not formatted_images:
|
||||
raise ValueError("No valid image URLs returned from API")
|
||||
|
||||
|
||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
||||
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
|
||||
|
||||
logger.info(
|
||||
"Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count
|
||||
)
|
||||
|
||||
# Prepare successful response - minimal format
|
||||
response_data = {
|
||||
"success": True,
|
||||
"image": formatted_images[0]["url"] if formatted_images else None
|
||||
}
|
||||
|
||||
response_data = {"success": True, "image": formatted_images[0]["url"] if formatted_images else None}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["images_generated"] = len(formatted_images)
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
error_msg = f"Error generating image: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Prepare error response - minimal format
|
||||
response_data = {
|
||||
"success": False,
|
||||
"image": None
|
||||
}
|
||||
|
||||
response_data = {"success": False, "image": None}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
"""
|
||||
Check if the FAL.ai API key is available in environment variables.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
@@ -406,7 +392,7 @@ def check_fal_api_key() -> bool:
|
||||
def check_image_generation_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for image generation tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
@@ -414,19 +400,20 @@ def check_image_generation_requirements() -> bool:
|
||||
# Check API key
|
||||
if not check_fal_api_key():
|
||||
return False
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
import fal_client
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
@@ -439,10 +426,10 @@ if __name__ == "__main__":
|
||||
"""
|
||||
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_fal_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ FAL_KEY environment variable not set")
|
||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
||||
@@ -450,27 +437,28 @@ if __name__ == "__main__":
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ FAL.ai API key found")
|
||||
|
||||
|
||||
# Check if fal_client is available
|
||||
try:
|
||||
import fal_client
|
||||
|
||||
print("✅ fal_client library available")
|
||||
except ImportError:
|
||||
print("❌ fal_client library not found")
|
||||
print("Please install: pip install fal-client")
|
||||
exit(1)
|
||||
|
||||
|
||||
print("🛠️ Image generation tools ready for use!")
|
||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from image_generation_tool import image_generate_tool")
|
||||
print(" import asyncio")
|
||||
@@ -484,23 +472,23 @@ if __name__ == "__main__":
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nSupported image sizes:")
|
||||
for size in VALID_IMAGE_SIZES:
|
||||
print(f" - {size}")
|
||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||
|
||||
|
||||
print("\nAcceleration modes:")
|
||||
for mode in VALID_ACCELERATION_MODES:
|
||||
print(f" - {mode}")
|
||||
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||
@@ -521,17 +509,17 @@ IMAGE_GENERATE_SCHEMA = {
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive.",
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["landscape", "square", "portrait"],
|
||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||
"default": "landscape"
|
||||
}
|
||||
"default": "landscape",
|
||||
},
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
"required": ["prompt"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -91,9 +91,11 @@ _MCP_SAMPLING_TYPES = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
_MCP_AVAILABLE = True
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
@@ -108,6 +110,7 @@ try:
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
@@ -118,27 +121,36 @@ except ImportError:
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_DEFAULT_TOOL_TIMEOUT = 120 # seconds for tool calls
|
||||
_DEFAULT_CONNECT_TIMEOUT = 60 # seconds for initial connection per server
|
||||
_MAX_RECONNECT_RETRIES = 5
|
||||
_MAX_BACKOFF_SECONDS = 60
|
||||
|
||||
# Environment variables that are safe to pass to stdio subprocesses
|
||||
_SAFE_ENV_KEYS = frozenset({
|
||||
"PATH", "HOME", "USER", "LANG", "LC_ALL", "TERM", "SHELL", "TMPDIR",
|
||||
})
|
||||
_SAFE_ENV_KEYS = frozenset(
|
||||
{
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"TERM",
|
||||
"SHELL",
|
||||
"TMPDIR",
|
||||
}
|
||||
)
|
||||
|
||||
# Regex for credential patterns to strip from error messages
|
||||
_CREDENTIAL_PATTERN = re.compile(
|
||||
r"(?:"
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r"ghp_[A-Za-z0-9_]{1,255}" # GitHub PAT
|
||||
r"|sk-[A-Za-z0-9_]{1,255}" # OpenAI-style key
|
||||
r"|Bearer\s+\S+" # Bearer token
|
||||
r"|token=[^\s&,;\"']{1,255}" # token=...
|
||||
r"|key=[^\s&,;\"']{1,255}" # key=...
|
||||
r"|API_KEY=[^\s&,;\"']{1,255}" # API_KEY=...
|
||||
r"|password=[^\s&,;\"']{1,255}" # password=...
|
||||
r"|secret=[^\s&,;\"']{1,255}" # secret=...
|
||||
r")",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@@ -148,7 +160,8 @@ _CREDENTIAL_PATTERN = re.compile(
|
||||
# Security helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_safe_env(user_env: Optional[dict]) -> dict:
|
||||
|
||||
def _build_safe_env(user_env: dict | None) -> dict:
|
||||
"""Build a filtered environment dict for stdio subprocesses.
|
||||
|
||||
Only passes through safe baseline variables (PATH, HOME, etc.) and XDG_*
|
||||
@@ -180,6 +193,7 @@ def _sanitize_error(text: str) -> str:
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
@@ -216,18 +230,22 @@ class SamplingHandler:
|
||||
self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
|
||||
self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
|
||||
self.max_tool_rounds = _safe_numeric(
|
||||
config.get("max_tool_rounds", 5), 5, int, minimum=0,
|
||||
config.get("max_tool_rounds", 5),
|
||||
5,
|
||||
int,
|
||||
minimum=0,
|
||||
)
|
||||
self.model_override = config.get("model")
|
||||
self.allowed_models = config.get("allowed_models", [])
|
||||
|
||||
_log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
|
||||
self.audit_level = _log_levels.get(
|
||||
str(config.get("log_level", "info")).lower(), logging.INFO,
|
||||
str(config.get("log_level", "info")).lower(),
|
||||
logging.INFO,
|
||||
)
|
||||
|
||||
# Per-instance state
|
||||
self._rate_timestamps: List[float] = []
|
||||
self._rate_timestamps: list[float] = []
|
||||
self._tool_loop_count = 0
|
||||
self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}
|
||||
|
||||
@@ -245,7 +263,7 @@ class SamplingHandler:
|
||||
|
||||
# -- Model resolution ----------------------------------------------------
|
||||
|
||||
def _resolve_model(self, preferences) -> Optional[str]:
|
||||
def _resolve_model(self, preferences) -> str | None:
|
||||
"""Config override > server hint > None (use default)."""
|
||||
if self.model_override:
|
||||
return self.model_override
|
||||
@@ -265,7 +283,7 @@ class SamplingHandler:
|
||||
items = block.content if isinstance(block.content, list) else [block.content]
|
||||
return "\n".join(item.text for item in items if hasattr(item, "text"))
|
||||
|
||||
def _convert_messages(self, params) -> List[dict]:
|
||||
def _convert_messages(self, params) -> list[dict]:
|
||||
"""Convert MCP SamplingMessages to OpenAI format.
|
||||
|
||||
Uses ``msg.content_as_list`` (SDK helper) so single-block and
|
||||
@@ -273,37 +291,47 @@ class SamplingHandler:
|
||||
with ``isinstance`` on real SDK types when available, falling back
|
||||
to duck-typing via ``hasattr`` for compatibility.
|
||||
"""
|
||||
messages: List[dict] = []
|
||||
messages: list[dict] = []
|
||||
for msg in params.messages:
|
||||
blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
|
||||
msg.content if isinstance(msg.content, list) else [msg.content]
|
||||
blocks = (
|
||||
msg.content_as_list
|
||||
if hasattr(msg, "content_as_list")
|
||||
else (msg.content if isinstance(msg.content, list) else [msg.content])
|
||||
)
|
||||
|
||||
# Separate blocks by kind
|
||||
tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
|
||||
tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
|
||||
content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]
|
||||
tool_uses = [
|
||||
b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")
|
||||
]
|
||||
content_blocks = [
|
||||
b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))
|
||||
]
|
||||
|
||||
# Emit tool result messages (role: tool)
|
||||
for tr in tool_results:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.toolUseId,
|
||||
"content": self._extract_tool_result_text(tr),
|
||||
}
|
||||
)
|
||||
|
||||
# Emit assistant tool_calls message
|
||||
if tool_uses:
|
||||
tc_list = []
|
||||
for tu in tool_uses:
|
||||
tc_list.append({
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
})
|
||||
tc_list.append(
|
||||
{
|
||||
"id": getattr(tu, "id", f"call_{len(tc_list)}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tu.name,
|
||||
"arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
|
||||
},
|
||||
}
|
||||
)
|
||||
msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
|
||||
# Include any accompanying text
|
||||
text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
|
||||
@@ -320,10 +348,12 @@ class SamplingHandler:
|
||||
if hasattr(block, "text"):
|
||||
parts.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "data") and hasattr(block, "mimeType"):
|
||||
parts.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
})
|
||||
parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported sampling content block type: %s (skipped)",
|
||||
@@ -352,16 +382,13 @@ class SamplingHandler:
|
||||
# Tool loop governance
|
||||
if self.max_tool_rounds == 0:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
|
||||
)
|
||||
return self._error(f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)")
|
||||
|
||||
self._tool_loop_count += 1
|
||||
if self._tool_loop_count > self.max_tool_rounds:
|
||||
self._tool_loop_count = 0
|
||||
return self._error(
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' "
|
||||
f"(max {self.max_tool_rounds} rounds)"
|
||||
f"Tool loop limit exceeded for server '{self.server_name}' (max {self.max_tool_rounds} rounds)"
|
||||
)
|
||||
|
||||
content_blocks = []
|
||||
@@ -372,25 +399,28 @@ class SamplingHandler:
|
||||
parsed = json.loads(args)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.warning(
|
||||
"MCP server '%s': malformed tool_calls arguments "
|
||||
"from LLM (wrapping as raw): %.100s",
|
||||
self.server_name, args,
|
||||
"MCP server '%s': malformed tool_calls arguments from LLM (wrapping as raw): %.100s",
|
||||
self.server_name,
|
||||
args,
|
||||
)
|
||||
parsed = {"_raw": args}
|
||||
else:
|
||||
parsed = args if isinstance(args, dict) else {"_raw": str(args)}
|
||||
|
||||
content_blocks.append(ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
))
|
||||
content_blocks.append(
|
||||
ToolUseContent(
|
||||
type="tool_use",
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
input=parsed,
|
||||
)
|
||||
)
|
||||
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
|
||||
self.server_name, response.model,
|
||||
self.server_name,
|
||||
response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
len(content_blocks),
|
||||
)
|
||||
@@ -410,7 +440,8 @@ class SamplingHandler:
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling response: model=%s, tokens=%s",
|
||||
self.server_name, response.model,
|
||||
self.server_name,
|
||||
response.model,
|
||||
getattr(getattr(response, "usage", None), "total_tokens", "?"),
|
||||
)
|
||||
|
||||
@@ -445,12 +476,12 @@ class SamplingHandler:
|
||||
if not self._check_rate_limit():
|
||||
logger.warning(
|
||||
"MCP server '%s' sampling rate limit exceeded (%d/min)",
|
||||
self.server_name, self.max_rpm,
|
||||
self.server_name,
|
||||
self.max_rpm,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' "
|
||||
f"({self.max_rpm} requests/minute)"
|
||||
f"Sampling rate limit exceeded for server '{self.server_name}' ({self.max_rpm} requests/minute)"
|
||||
)
|
||||
|
||||
# Resolve model
|
||||
@@ -458,6 +489,7 @@ class SamplingHandler:
|
||||
|
||||
# Get auxiliary LLM client
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
|
||||
client, default_model = get_text_auxiliary_client()
|
||||
if client is None:
|
||||
self.metrics["errors"] += 1
|
||||
@@ -469,7 +501,8 @@ class SamplingHandler:
|
||||
if self.allowed_models and resolved_model not in self.allowed_models:
|
||||
logger.warning(
|
||||
"MCP server '%s' requested model '%s' not in allowed_models",
|
||||
self.server_name, resolved_model,
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
)
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
@@ -515,7 +548,10 @@ class SamplingHandler:
|
||||
logger.log(
|
||||
self.audit_level,
|
||||
"MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
|
||||
self.server_name, resolved_model, max_tokens, len(messages),
|
||||
self.server_name,
|
||||
resolved_model,
|
||||
max_tokens,
|
||||
len(messages),
|
||||
)
|
||||
|
||||
# Offload sync LLM call to thread (non-blocking)
|
||||
@@ -524,19 +560,15 @@ class SamplingHandler:
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
asyncio.to_thread(_sync_call), timeout=self.timeout,
|
||||
asyncio.to_thread(_sync_call),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call timed out after {self.timeout}s "
|
||||
f"for server '{self.server_name}'"
|
||||
)
|
||||
return self._error(f"Sampling LLM call timed out after {self.timeout}s for server '{self.server_name}'")
|
||||
except Exception as exc:
|
||||
self.metrics["errors"] += 1
|
||||
return self._error(
|
||||
f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
|
||||
)
|
||||
return self._error(f"Sampling LLM call failed: {_sanitize_error(str(exc))}")
|
||||
|
||||
# Track metrics
|
||||
choice = response.choices[0]
|
||||
@@ -546,11 +578,7 @@ class SamplingHandler:
|
||||
self.metrics["tokens_used"] += total_tokens
|
||||
|
||||
# Dispatch based on response type
|
||||
if (
|
||||
choice.finish_reason == "tool_calls"
|
||||
and hasattr(choice.message, "tool_calls")
|
||||
and choice.message.tool_calls
|
||||
):
|
||||
if choice.finish_reason == "tool_calls" and hasattr(choice.message, "tool_calls") and choice.message.tool_calls:
|
||||
return self._build_tool_use_result(choice, response)
|
||||
|
||||
return self._build_text_result(choice, response)
|
||||
@@ -560,6 +588,7 @@ class SamplingHandler:
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MCPServerTask:
|
||||
"""Manages a single MCP server connection in a dedicated asyncio Task.
|
||||
|
||||
@@ -571,22 +600,29 @@ class MCPServerTask:
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"name",
|
||||
"session",
|
||||
"tool_timeout",
|
||||
"_task",
|
||||
"_ready",
|
||||
"_shutdown_event",
|
||||
"_tools",
|
||||
"_error",
|
||||
"_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.session: Optional[Any] = None
|
||||
self.session: Any | None = None
|
||||
self.tool_timeout: float = _DEFAULT_TOOL_TIMEOUT
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task: asyncio.Task | None = None
|
||||
self._ready = asyncio.Event()
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._error: Exception | None = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
self._sampling: SamplingHandler | None = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
@@ -599,9 +635,7 @@ class MCPServerTask:
|
||||
user_env = config.get("env")
|
||||
|
||||
if not command:
|
||||
raise ValueError(
|
||||
f"MCP server '{self.name}' has no 'command' in config"
|
||||
)
|
||||
raise ValueError(f"MCP server '{self.name}' has no 'command' in config")
|
||||
|
||||
safe_env = _build_safe_env(user_env)
|
||||
server_params = StdioServerParameters(
|
||||
@@ -650,11 +684,7 @@ class MCPServerTask:
|
||||
if self.session is None:
|
||||
return
|
||||
tools_result = await self.session.list_tools()
|
||||
self._tools = (
|
||||
tools_result.tools
|
||||
if hasattr(tools_result, "tools")
|
||||
else []
|
||||
)
|
||||
self._tools = tools_result.tools if hasattr(tools_result, "tools") else []
|
||||
|
||||
async def run(self, config: dict):
|
||||
"""Long-lived coroutine: connect, discover tools, wait, disconnect.
|
||||
@@ -704,24 +734,28 @@ class MCPServerTask:
|
||||
if self._shutdown_event.is_set():
|
||||
logger.debug(
|
||||
"MCP server '%s' disconnected during shutdown: %s",
|
||||
self.name, exc,
|
||||
self.name,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
retries += 1
|
||||
if retries > _MAX_RECONNECT_RETRIES:
|
||||
logger.warning(
|
||||
"MCP server '%s' failed after %d reconnection attempts, "
|
||||
"giving up: %s",
|
||||
self.name, _MAX_RECONNECT_RETRIES, exc,
|
||||
"MCP server '%s' failed after %d reconnection attempts, giving up: %s",
|
||||
self.name,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"MCP server '%s' connection lost (attempt %d/%d), "
|
||||
"reconnecting in %.0fs: %s",
|
||||
self.name, retries, _MAX_RECONNECT_RETRIES,
|
||||
backoff, exc,
|
||||
"MCP server '%s' connection lost (attempt %d/%d), reconnecting in %.0fs: %s",
|
||||
self.name,
|
||||
retries,
|
||||
_MAX_RECONNECT_RETRIES,
|
||||
backoff,
|
||||
exc,
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
backoff = min(backoff * 2, _MAX_BACKOFF_SECONDS)
|
||||
@@ -745,7 +779,7 @@ class MCPServerTask:
|
||||
if self._task and not self._task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
"MCP server '%s' shutdown timed out, cancelling task",
|
||||
self.name,
|
||||
@@ -762,11 +796,11 @@ class MCPServerTask:
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_servers: Dict[str, MCPServerTask] = {}
|
||||
_servers: dict[str, MCPServerTask] = {}
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
_mcp_loop: asyncio.AbstractEventLoop | None = None
|
||||
_mcp_thread: threading.Thread | None = None
|
||||
|
||||
# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
|
||||
_lock = threading.Lock()
|
||||
@@ -801,7 +835,8 @@ def _run_on_mcp_loop(coro, timeout: float = 30):
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_mcp_config() -> Dict[str, dict]:
|
||||
|
||||
def _load_mcp_config() -> dict[str, dict]:
|
||||
"""Read ``mcp_servers`` from the Hermes config file.
|
||||
|
||||
Returns a dict of ``{server_name: server_config}`` or empty dict.
|
||||
@@ -811,6 +846,7 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
@@ -825,6 +861,7 @@ def _load_mcp_config() -> Dict[str, dict]:
|
||||
# Server connection helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
"""Create an MCPServerTask, start it, and return when ready.
|
||||
|
||||
@@ -845,6 +882,7 @@ async def _connect_server(name: str, config: dict) -> MCPServerTask:
|
||||
# Handler / check-fn factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
"""Return a sync handler that calls an MCP tool via the background loop.
|
||||
|
||||
@@ -856,27 +894,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
for block in (result.content or []):
|
||||
for block in result.content or []:
|
||||
if hasattr(block, "text"):
|
||||
error_text += block.text
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
error_text or "MCP tool returned an error"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(error_text or "MCP tool returned an error")})
|
||||
|
||||
# Collect text from content blocks
|
||||
parts: List[str] = []
|
||||
for block in (result.content or []):
|
||||
parts: list[str] = []
|
||||
for block in result.content or []:
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
@@ -886,13 +918,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
server_name, tool_name, exc,
|
||||
server_name,
|
||||
tool_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -904,14 +934,12 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_resources()
|
||||
resources = []
|
||||
for r in (result.resources if hasattr(result, "resources") else []):
|
||||
for r in result.resources if hasattr(result, "resources") else []:
|
||||
entry = {}
|
||||
if hasattr(r, "uri"):
|
||||
entry["uri"] = str(r.uri)
|
||||
@@ -928,13 +956,11 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_resources failed: %s", server_name, exc,
|
||||
"MCP %s/list_resources failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -946,9 +972,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
uri = args.get("uri")
|
||||
if not uri:
|
||||
@@ -957,7 +981,7 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
async def _call():
|
||||
result = await server.session.read_resource(uri)
|
||||
# read_resource returns ReadResourceResult with .contents list
|
||||
parts: List[str] = []
|
||||
parts: list[str] = []
|
||||
contents = result.contents if hasattr(result, "contents") else []
|
||||
for block in contents:
|
||||
if hasattr(block, "text"):
|
||||
@@ -970,13 +994,11 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/read_resource failed: %s", server_name, exc,
|
||||
"MCP %s/read_resource failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -988,14 +1010,12 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.list_prompts()
|
||||
prompts = []
|
||||
for p in (result.prompts if hasattr(result, "prompts") else []):
|
||||
for p in result.prompts if hasattr(result, "prompts") else []:
|
||||
entry = {}
|
||||
if hasattr(p, "name"):
|
||||
entry["name"] = p.name
|
||||
@@ -1017,13 +1037,11 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/list_prompts failed: %s", server_name, exc,
|
||||
"MCP %s/list_prompts failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -1035,9 +1053,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
return json.dumps({"error": f"MCP server '{server_name}' is not connected"})
|
||||
|
||||
name = args.get("name")
|
||||
if not name:
|
||||
@@ -1048,7 +1064,7 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
result = await server.session.get_prompt(name, arguments=arguments)
|
||||
# GetPromptResult has .messages list
|
||||
messages = []
|
||||
for msg in (result.messages if hasattr(result, "messages") else []):
|
||||
for msg in result.messages if hasattr(result, "messages") else []:
|
||||
entry = {}
|
||||
if hasattr(msg, "role"):
|
||||
entry["role"] = msg.role
|
||||
@@ -1070,13 +1086,11 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"MCP %s/get_prompt failed: %s", server_name, exc,
|
||||
"MCP %s/get_prompt failed: %s",
|
||||
server_name,
|
||||
exc,
|
||||
)
|
||||
return json.dumps({
|
||||
"error": _sanitize_error(
|
||||
f"MCP call failed: {type(exc).__name__}: {exc}"
|
||||
)
|
||||
})
|
||||
return json.dumps({"error": _sanitize_error(f"MCP call failed: {type(exc).__name__}: {exc}")})
|
||||
|
||||
return _handler
|
||||
|
||||
@@ -1096,6 +1110,7 @@ def _make_check_fn(server_name: str):
|
||||
# Discovery & registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
@@ -1114,14 +1129,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
|
||||
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
|
||||
"parameters": mcp_tool.inputSchema
|
||||
if mcp_tool.inputSchema
|
||||
else {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
def _build_utility_schemas(server_name: str) -> list[dict]:
|
||||
"""Build schemas for the MCP utility tools (resources & prompts).
|
||||
|
||||
Returns a list of (schema, handler_factory_name) tuples encoded as dicts
|
||||
@@ -1192,9 +1209,9 @@ def _build_utility_schemas(server_name: str) -> List[dict]:
|
||||
]
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
|
||||
def _existing_tool_names() -> list[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: List[str] = []
|
||||
names: list[str] = []
|
||||
for sname, server in _servers.items():
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
@@ -1205,7 +1222,7 @@ def _existing_tool_names() -> List[str]:
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
async def _discover_and_register_server(name: str, config: dict) -> list[str]:
|
||||
"""Connect to a single MCP server, discover tools, and register them.
|
||||
|
||||
Also registers utility tools for MCP Resources and Prompts support
|
||||
@@ -1224,7 +1241,7 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
with _lock:
|
||||
_servers[name] = server
|
||||
|
||||
registered_names: List[str] = []
|
||||
registered_names: list[str] = []
|
||||
toolset_name = f"mcp-{name}"
|
||||
|
||||
for mcp_tool in server._tools:
|
||||
@@ -1277,7 +1294,9 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
transport_type = "HTTP" if "url" in config else "stdio"
|
||||
logger.info(
|
||||
"MCP server '%s' (%s): registered %d tool(s): %s",
|
||||
name, transport_type, len(registered_names),
|
||||
name,
|
||||
transport_type,
|
||||
len(registered_names),
|
||||
", ".join(registered_names),
|
||||
)
|
||||
return registered_names
|
||||
@@ -1287,7 +1306,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
|
||||
|
||||
def discover_mcp_tools() -> list[str]:
|
||||
"""Entry point: load config, connect to MCP servers, register tools.
|
||||
|
||||
Called from ``model_tools._discover_tools()``. Safe to call even when
|
||||
@@ -1318,12 +1338,12 @@ def discover_mcp_tools() -> List[str]:
|
||||
# Start the background event loop for MCP connections
|
||||
_ensure_mcp_loop()
|
||||
|
||||
all_tools: List[str] = []
|
||||
all_tools: list[str] = []
|
||||
failed_count = 0
|
||||
|
||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||
async def _discover_one(name: str, cfg: dict) -> list[str]:
|
||||
"""Connect to a single server and return its registered tool names."""
|
||||
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
|
||||
transport_desc = cfg.get("url", f"{cfg.get('command', '?')} {' '.join(cfg.get('args', [])[:2])}")
|
||||
try:
|
||||
registered = await _discover_and_register_server(name, cfg)
|
||||
transport_type = "HTTP" if "url" in cfg else "stdio"
|
||||
@@ -1331,7 +1351,8 @@ def discover_mcp_tools() -> List[str]:
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to connect to MCP server '%s': %s",
|
||||
name, exc,
|
||||
name,
|
||||
exc,
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -1358,6 +1379,7 @@ def discover_mcp_tools() -> List[str]:
|
||||
if all_tools:
|
||||
# Dynamically inject into all hermes-* platform toolsets
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
for ts_name, ts in TOOLSETS.items():
|
||||
if ts_name.startswith("hermes-"):
|
||||
for tool_name in all_tools:
|
||||
@@ -1377,13 +1399,13 @@ def discover_mcp_tools() -> List[str]:
|
||||
return _existing_tool_names()
|
||||
|
||||
|
||||
def get_mcp_status() -> List[dict]:
|
||||
def get_mcp_status() -> list[dict]:
|
||||
"""Return status of all configured MCP servers for banner display.
|
||||
|
||||
Returns a list of dicts with keys: name, transport, tools, connected.
|
||||
Includes both successfully connected servers and configured-but-failed ones.
|
||||
"""
|
||||
result: List[dict] = []
|
||||
result: list[dict] = []
|
||||
|
||||
# Get configured servers from config
|
||||
configured = _load_mcp_config()
|
||||
@@ -1407,12 +1429,14 @@ def get_mcp_status() -> List[dict]:
|
||||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
})
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": 0,
|
||||
"connected": False,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -1440,7 +1464,9 @@ def shutdown_mcp_servers():
|
||||
for server, result in zip(servers_snapshot, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.debug(
|
||||
"Error closing MCP server '%s': %s", server.name, result,
|
||||
"Error closing MCP server '%s': %s",
|
||||
server.name,
|
||||
result,
|
||||
)
|
||||
with _lock:
|
||||
_servers.clear()
|
||||
|
||||
@@ -29,7 +29,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,30 +46,38 @@ ENTRY_DELIMITER = "\n§\n"
|
||||
|
||||
_MEMORY_THREAT_PATTERNS = [
|
||||
# Prompt injection
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'you\s+are\s+now\s+', "role_hijack"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
(r"ignore\s+(previous|all|above|prior)\s+instructions", "prompt_injection"),
|
||||
(r"you\s+are\s+now\s+", "role_hijack"),
|
||||
(r"do\s+not\s+tell\s+the\s+user", "deception_hide"),
|
||||
(r"system\s+prompt\s+override", "sys_prompt_override"),
|
||||
(r"disregard\s+(your|all|any)\s+(instructions|rules|guidelines)", "disregard_rules"),
|
||||
(r"act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)", "bypass_restrictions"),
|
||||
# Exfiltration via curl/wget with secrets
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
|
||||
(r"curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_curl"),
|
||||
(r"wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)", "exfil_wget"),
|
||||
(r"cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)", "read_secrets"),
|
||||
# Persistence via shell rc
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
|
||||
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
|
||||
(r"authorized_keys", "ssh_backdoor"),
|
||||
(r"\$HOME/\.ssh|\~/\.ssh", "ssh_access"),
|
||||
(r"\$HOME/\.hermes/\.env|\~/\.hermes/\.env", "hermes_env"),
|
||||
]
|
||||
|
||||
# Subset of invisible chars for injection detection
|
||||
_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
"\u200b",
|
||||
"\u200c",
|
||||
"\u200d",
|
||||
"\u2060",
|
||||
"\ufeff",
|
||||
"\u202a",
|
||||
"\u202b",
|
||||
"\u202c",
|
||||
"\u202d",
|
||||
"\u202e",
|
||||
}
|
||||
|
||||
|
||||
def _scan_memory_content(content: str) -> Optional[str]:
|
||||
def _scan_memory_content(content: str) -> str | None:
|
||||
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
|
||||
# Check invisible unicode
|
||||
for char in _INVISIBLE_CHARS:
|
||||
@@ -96,12 +104,12 @@ class MemoryStore:
|
||||
"""
|
||||
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
|
||||
self.memory_entries: List[str] = []
|
||||
self.user_entries: List[str] = []
|
||||
self.memory_entries: list[str] = []
|
||||
self.user_entries: list[str] = []
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
# Frozen snapshot for system prompt -- set once at load_from_disk()
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
|
||||
self._system_prompt_snapshot: dict[str, str] = {"memory": "", "user": ""}
|
||||
|
||||
def load_from_disk(self):
|
||||
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
|
||||
@@ -129,12 +137,12 @@ class MemoryStore:
|
||||
elif target == "user":
|
||||
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
|
||||
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
def _entries_for(self, target: str) -> list[str]:
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: List[str]):
|
||||
def _set_entries(self, target: str, entries: list[str]):
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
else:
|
||||
@@ -151,7 +159,7 @@ class MemoryStore:
|
||||
return self.user_char_limit
|
||||
return self.memory_char_limit
|
||||
|
||||
def add(self, target: str, content: str) -> Dict[str, Any]:
|
||||
def add(self, target: str, content: str) -> dict[str, Any]:
|
||||
"""Append a new entry. Returns error if it would exceed the char limit."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
@@ -192,7 +200,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> dict[str, Any]:
|
||||
"""Find entry containing old_text substring, replace it with new_content."""
|
||||
old_text = old_text.strip()
|
||||
new_content = new_content.strip()
|
||||
@@ -247,7 +255,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
|
||||
def remove(self, target: str, old_text: str) -> dict[str, Any]:
|
||||
"""Remove the entry containing old_text substring."""
|
||||
old_text = old_text.strip()
|
||||
if not old_text:
|
||||
@@ -278,7 +286,7 @@ class MemoryStore:
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
def format_for_system_prompt(self, target: str) -> Optional[str]:
|
||||
def format_for_system_prompt(self, target: str) -> str | None:
|
||||
"""
|
||||
Return the frozen snapshot for system prompt injection.
|
||||
|
||||
@@ -293,7 +301,7 @@ class MemoryStore:
|
||||
|
||||
# -- Internal helpers --
|
||||
|
||||
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
|
||||
def _success_response(self, target: str, message: str = None) -> dict[str, Any]:
|
||||
entries = self._entries_for(target)
|
||||
current = self._char_count(target)
|
||||
limit = self._char_limit(target)
|
||||
@@ -310,7 +318,7 @@ class MemoryStore:
|
||||
resp["message"] = message
|
||||
return resp
|
||||
|
||||
def _render_block(self, target: str, entries: List[str]) -> str:
|
||||
def _render_block(self, target: str, entries: list[str]) -> str:
|
||||
"""Render a system prompt block with header and usage indicator."""
|
||||
if not entries:
|
||||
return ""
|
||||
@@ -329,7 +337,7 @@ class MemoryStore:
|
||||
return f"{separator}\n{header}\n{separator}\n{content}"
|
||||
|
||||
@staticmethod
|
||||
def _read_file(path: Path) -> List[str]:
|
||||
def _read_file(path: Path) -> list[str]:
|
||||
"""Read a memory file and split into entries.
|
||||
|
||||
No file locking needed: _write_file uses atomic rename, so readers
|
||||
@@ -339,7 +347,7 @@ class MemoryStore:
|
||||
return []
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except (OSError, IOError):
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
if not raw.strip():
|
||||
@@ -351,7 +359,7 @@ class MemoryStore:
|
||||
return [e for e in entries if e]
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, entries: List[str]):
|
||||
def _write_file(path: Path, entries: list[str]):
|
||||
"""Write entries to a memory file using atomic temp-file + rename.
|
||||
|
||||
Previous implementation used open("w") + flock, but "w" truncates the
|
||||
@@ -362,9 +370,7 @@ class MemoryStore:
|
||||
content = ENTRY_DELIMITER.join(entries) if entries else ""
|
||||
try:
|
||||
# Write to temp file in same directory (same filesystem for atomic rename)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
|
||||
)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp", prefix=".mem_")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
@@ -378,7 +384,7 @@ class MemoryStore:
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to write memory file {path}: {e}")
|
||||
|
||||
|
||||
@@ -387,7 +393,7 @@ def memory_tool(
|
||||
target: str = "memory",
|
||||
content: str = None,
|
||||
old_text: str = None,
|
||||
store: Optional[MemoryStore] = None,
|
||||
store: MemoryStore | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the memory tool. Dispatches to MemoryStore methods.
|
||||
@@ -395,10 +401,15 @@ def memory_tool(
|
||||
Returns JSON string with results.
|
||||
"""
|
||||
if store is None:
|
||||
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "Memory is not available. It may be disabled in config or this environment."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
if target not in ("memory", "user"):
|
||||
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False
|
||||
)
|
||||
|
||||
if action == "add":
|
||||
if not content:
|
||||
@@ -407,18 +418,26 @@ def memory_tool(
|
||||
|
||||
elif action == "replace":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False
|
||||
)
|
||||
result = store.replace(target, old_text, content)
|
||||
|
||||
elif action == "remove":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False
|
||||
)
|
||||
result = store.remove(target, old_text)
|
||||
|
||||
else:
|
||||
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False
|
||||
)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
@@ -457,23 +476,16 @@ MEMORY_SCHEMA = {
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "replace", "remove"],
|
||||
"description": "The action to perform."
|
||||
},
|
||||
"action": {"type": "string", "enum": ["add", "replace", "remove"], "description": "The action to perform."},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["memory", "user"],
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The entry content. Required for 'add' and 'replace'."
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile.",
|
||||
},
|
||||
"content": {"type": "string", "description": "The entry content. Required for 'add' and 'replace'."},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Short unique substring identifying the entry to replace or remove."
|
||||
"description": "Short unique substring identifying the entry to replace or remove.",
|
||||
},
|
||||
},
|
||||
"required": ["action", "target"],
|
||||
@@ -493,10 +505,7 @@ registry.register(
|
||||
target=args.get("target", "memory"),
|
||||
content=args.get("content"),
|
||||
old_text=args.get("old_text"),
|
||||
store=kw.get("store")),
|
||||
store=kw.get("store"),
|
||||
),
|
||||
check_fn=check_memory_requirements,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -38,21 +38,27 @@ Configuration:
|
||||
Usage:
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool
|
||||
import asyncio
|
||||
|
||||
|
||||
# Process a complex query
|
||||
result = await mixture_of_agents_tool(
|
||||
user_prompt="Solve this complex mathematical proof..."
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
|
||||
from typing import Any
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.openrouter_client import (
|
||||
check_api_key as check_openrouter_api_key,
|
||||
)
|
||||
from tools.openrouter_client import (
|
||||
get_async_client as _get_openrouter_client,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,9 +66,9 @@ logger = logging.getLogger(__name__)
|
||||
# Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
|
||||
REFERENCE_MODELS = [
|
||||
"anthropic/claude-opus-4.5",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.2-pro",
|
||||
"deepseek/deepseek-v3.2"
|
||||
"deepseek/deepseek-v3.2",
|
||||
]
|
||||
|
||||
# Aggregator model - synthesizes reference responses into final output
|
||||
@@ -83,18 +89,18 @@ Responses from models:"""
|
||||
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: list[str]) -> str:
|
||||
"""
|
||||
Construct the final system prompt for the aggregator including all model responses.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): Base system prompt for aggregation
|
||||
responses (List[str]): List of responses from reference models
|
||||
|
||||
|
||||
Returns:
|
||||
str: Complete system prompt with enumerated responses
|
||||
"""
|
||||
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
|
||||
response_text = "\n".join([f"{i + 1}. {response}" for i, response in enumerate(responses)])
|
||||
return f"{system_prompt}\n\n{response_text}"
|
||||
|
||||
|
||||
@@ -103,48 +109,43 @@ async def _run_reference_model_safe(
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 6
|
||||
max_retries: int = 6,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to use
|
||||
user_prompt (str): The user's query
|
||||
temperature (float): Sampling temperature for response generation
|
||||
max_tokens (int): Maximum tokens in response
|
||||
max_retries (int): Maximum number of retry attempts
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": user_prompt}],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not model.lower().startswith('gpt-'):
|
||||
if not model.lower().startswith("gpt-"):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("%s responded (%s characters)", model, len(content))
|
||||
return model, content, True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Log more detailed error information for debugging
|
||||
@@ -154,7 +155,7 @@ async def _run_reference_model_safe(
|
||||
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
else:
|
||||
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
@@ -167,60 +168,47 @@ async def _run_reference_model_safe(
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = AGGREGATOR_TEMPERATURE,
|
||||
max_tokens: int = None
|
||||
system_prompt: str, user_prompt: str, temperature: float = AGGREGATOR_TEMPERATURE, max_tokens: int = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the aggregator model to synthesize the final response.
|
||||
|
||||
|
||||
Args:
|
||||
system_prompt (str): System prompt with all reference responses
|
||||
user_prompt (str): Original user query
|
||||
temperature (float): Focused temperature for consistent aggregation
|
||||
max_tokens (int): Maximum tokens in final response
|
||||
|
||||
|
||||
Returns:
|
||||
str: Synthesized final response
|
||||
"""
|
||||
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
|
||||
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": AGGREGATOR_MODEL,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
"extra_body": {"reasoning": {"enabled": True, "effort": "xhigh"}},
|
||||
}
|
||||
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
|
||||
if not AGGREGATOR_MODEL.lower().startswith("gpt-"):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("Aggregation complete (%s characters)", len(content))
|
||||
return content
|
||||
|
||||
|
||||
async def mixture_of_agents_tool(
|
||||
user_prompt: str,
|
||||
reference_models: Optional[List[str]] = None,
|
||||
aggregator_model: Optional[str] = None
|
||||
user_prompt: str, reference_models: list[str] | None = None, aggregator_model: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Process a complex query using the Mixture-of-Agents methodology.
|
||||
|
||||
|
||||
This tool leverages multiple frontier language models to collaboratively solve
|
||||
extremely difficult problems requiring intense reasoning. It's particularly
|
||||
effective for:
|
||||
@@ -229,16 +217,16 @@ async def mixture_of_agents_tool(
|
||||
- Multi-step analytical reasoning tasks
|
||||
- Problems requiring diverse domain expertise
|
||||
- Tasks where single models show limitations
|
||||
|
||||
|
||||
The MoA approach uses a fixed 2-layer architecture:
|
||||
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
|
||||
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
|
||||
|
||||
|
||||
Args:
|
||||
user_prompt (str): The complex query or problem to solve
|
||||
reference_models (Optional[List[str]]): Custom reference models to use
|
||||
aggregator_model (Optional[str]): Custom aggregator model to use
|
||||
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the MoA results with the following structure:
|
||||
{
|
||||
@@ -250,12 +238,12 @@ async def mixture_of_agents_tool(
|
||||
},
|
||||
"processing_time": float
|
||||
}
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If MoA processing fails or API key is not set
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
@@ -263,7 +251,7 @@ async def mixture_of_agents_tool(
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
@@ -272,161 +260,152 @@ async def mixture_of_agents_tool(
|
||||
"failed_models": [],
|
||||
"final_response_length": 0,
|
||||
"processing_time_seconds": 0,
|
||||
"models_used": {}
|
||||
"models_used": {},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Starting Mixture-of-Agents processing...")
|
||||
logger.info("Query: %s", user_prompt[:100])
|
||||
|
||||
|
||||
# Validate API key availability
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable not set")
|
||||
|
||||
|
||||
# Use provided models or defaults
|
||||
ref_models = reference_models or REFERENCE_MODELS
|
||||
agg_model = aggregator_model or AGGREGATOR_MODEL
|
||||
|
||||
|
||||
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
|
||||
|
||||
|
||||
# Layer 1: Generate diverse responses from reference models (with failure handling)
|
||||
logger.info("Layer 1: Generating reference responses...")
|
||||
model_results = await asyncio.gather(*[
|
||||
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
|
||||
for model in ref_models
|
||||
])
|
||||
|
||||
model_results = await asyncio.gather(
|
||||
*[_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE) for model in ref_models]
|
||||
)
|
||||
|
||||
# Separate successful and failed responses
|
||||
successful_responses = []
|
||||
failed_models = []
|
||||
|
||||
|
||||
for model_name, content, success in model_results:
|
||||
if success:
|
||||
successful_responses.append(content)
|
||||
else:
|
||||
failed_models.append(model_name)
|
||||
|
||||
|
||||
successful_count = len(successful_responses)
|
||||
failed_count = len(failed_models)
|
||||
|
||||
|
||||
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
|
||||
|
||||
|
||||
if failed_models:
|
||||
logger.warning("Failed models: %s", ', '.join(failed_models))
|
||||
|
||||
logger.warning("Failed models: %s", ", ".join(failed_models))
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
raise ValueError(
|
||||
f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses."
|
||||
)
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
logger.info("Layer 2: Synthesizing final response...")
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(
|
||||
AGGREGATOR_SYSTEM_PROMPT,
|
||||
successful_responses
|
||||
)
|
||||
|
||||
final_response = await _run_aggregator_model(
|
||||
aggregator_system_prompt,
|
||||
user_prompt,
|
||||
AGGREGATOR_TEMPERATURE
|
||||
)
|
||||
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(AGGREGATOR_SYSTEM_PROMPT, successful_responses)
|
||||
|
||||
final_response = await _run_aggregator_model(aggregator_system_prompt, user_prompt, AGGREGATOR_TEMPERATURE)
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
logger.info("MoA processing completed in %.2f seconds", processing_time)
|
||||
|
||||
|
||||
# Prepare successful response (only final aggregated result, minimal fields)
|
||||
result = {
|
||||
"success": True,
|
||||
"response": final_response,
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
"models_used": {"reference_models": ref_models, "aggregator_model": agg_model},
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["final_response_length"] = len(final_response)
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
logger.error("%s", error_msg)
|
||||
|
||||
|
||||
# Calculate processing time even for errors
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
# Prepare error response (minimal fields)
|
||||
result = {
|
||||
"success": False,
|
||||
"response": "MoA processing failed. Please try again or use a single model for this query.",
|
||||
"models_used": {
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
},
|
||||
"error": error_msg
|
||||
"error": error_msg,
|
||||
}
|
||||
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_moa_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for MoA tools are met.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_openrouter_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
def get_debug_session_info() -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, List[str]]:
|
||||
def get_available_models() -> dict[str, list[str]]:
|
||||
"""
|
||||
Get information about available models for MoA processing.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Dictionary with reference and aggregator models
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_models": [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL],
|
||||
}
|
||||
|
||||
|
||||
def get_moa_configuration() -> Dict[str, Any]:
|
||||
def get_moa_configuration() -> dict[str, Any]:
|
||||
"""
|
||||
Get the current MoA configuration settings.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all configuration parameters
|
||||
"""
|
||||
@@ -437,7 +416,7 @@ def get_moa_configuration() -> Dict[str, Any]:
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"total_reference_models": len(REFERENCE_MODELS),
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail",
|
||||
}
|
||||
|
||||
|
||||
@@ -447,10 +426,10 @@ if __name__ == "__main__":
|
||||
"""
|
||||
print("🤖 Mixture-of-Agents Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_openrouter_api_key()
|
||||
|
||||
|
||||
if not api_available:
|
||||
print("❌ OPENROUTER_API_KEY environment variable not set")
|
||||
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
|
||||
@@ -458,26 +437,26 @@ if __name__ == "__main__":
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ OpenRouter API key found")
|
||||
|
||||
|
||||
print("🛠️ MoA tools ready for use!")
|
||||
|
||||
|
||||
# Show current configuration
|
||||
config = get_moa_configuration()
|
||||
print(f"\n⚙️ Current Configuration:")
|
||||
print("\n⚙️ Current Configuration:")
|
||||
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
|
||||
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
|
||||
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
|
||||
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
|
||||
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
|
||||
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
|
||||
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
|
||||
print(" import asyncio")
|
||||
@@ -488,24 +467,26 @@ if __name__ == "__main__":
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
|
||||
print("\nBest use cases:")
|
||||
print(" - Complex mathematical proofs and calculations")
|
||||
print(" - Advanced coding problems and algorithm design")
|
||||
print(" - Multi-step analytical reasoning tasks")
|
||||
print(" - Problems requiring diverse domain expertise")
|
||||
print(" - Tasks where single models show limitations")
|
||||
|
||||
|
||||
print("\nPerformance characteristics:")
|
||||
print(" - Higher latency due to multiple model calls")
|
||||
print(" - Significantly improved quality for complex tasks")
|
||||
print(" - Parallel processing for efficiency")
|
||||
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
|
||||
print(
|
||||
f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation"
|
||||
)
|
||||
print(" - Token-efficient: only returns final aggregated response")
|
||||
print(" - Resilient: continues with partial model failures")
|
||||
print(f" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - State-of-the-art results on challenging benchmarks")
|
||||
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export MOA_TOOLS_DEBUG=true")
|
||||
@@ -526,11 +507,11 @@ MOA_SCHEMA = {
|
||||
"properties": {
|
||||
"user_prompt": {
|
||||
"type": "string",
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning.",
|
||||
}
|
||||
},
|
||||
"required": ["user_prompt"]
|
||||
}
|
||||
"required": ["user_prompt"],
|
||||
},
|
||||
}
|
||||
|
||||
registry.register(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Shared OpenRouter API client for Hermes tools.
|
||||
|
||||
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
can share, eliminating the duplicated _get_openrouter_client() /
|
||||
_get_summarizer_client() pattern previously copy-pasted across web_tools,
|
||||
vision_tools, mixture_of_agents_tool, and session_search_tool.
|
||||
"""
|
||||
@@ -9,6 +9,7 @@ vision_tools, mixture_of_agents_tool, and session_search_tool.
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
_client: AsyncOpenAI | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ V4A Format:
|
||||
|
||||
Usage:
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
|
||||
operations, error = parse_v4a_patch(patch_content)
|
||||
if error:
|
||||
print(f"Parse error: {error}")
|
||||
@@ -30,8 +30,8 @@ Usage:
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
@@ -44,6 +44,7 @@ class OperationType(Enum):
|
||||
@dataclass
|
||||
class HunkLine:
|
||||
"""A single line in a patch hunk."""
|
||||
|
||||
prefix: str # ' ', '-', or '+'
|
||||
content: str
|
||||
|
||||
@@ -51,182 +52,174 @@ class HunkLine:
|
||||
@dataclass
|
||||
class Hunk:
|
||||
"""A group of changes within a file."""
|
||||
context_hint: Optional[str] = None
|
||||
lines: List[HunkLine] = field(default_factory=list)
|
||||
|
||||
context_hint: str | None = None
|
||||
lines: list[HunkLine] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchOperation:
|
||||
"""A single operation in a V4A patch."""
|
||||
|
||||
operation: OperationType
|
||||
file_path: str
|
||||
new_path: Optional[str] = None # For move operations
|
||||
hunks: List[Hunk] = field(default_factory=list)
|
||||
content: Optional[str] = None # For add file operations
|
||||
new_path: str | None = None # For move operations
|
||||
hunks: list[Hunk] = field(default_factory=list)
|
||||
content: str | None = None # For add file operations
|
||||
|
||||
|
||||
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
|
||||
def parse_v4a_patch(patch_content: str) -> tuple[list[PatchOperation], str | None]:
|
||||
"""
|
||||
Parse a V4A format patch.
|
||||
|
||||
|
||||
Args:
|
||||
patch_content: The patch text in V4A format
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (operations, error_message)
|
||||
- If successful: (list_of_operations, None)
|
||||
- If failed: ([], error_description)
|
||||
"""
|
||||
lines = patch_content.split('\n')
|
||||
operations: List[PatchOperation] = []
|
||||
|
||||
lines = patch_content.split("\n")
|
||||
operations: list[PatchOperation] = []
|
||||
|
||||
# Find patch boundaries
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if '*** Begin Patch' in line or '***Begin Patch' in line:
|
||||
if "*** Begin Patch" in line or "***Begin Patch" in line:
|
||||
start_idx = i
|
||||
elif '*** End Patch' in line or '***End Patch' in line:
|
||||
elif "*** End Patch" in line or "***End Patch" in line:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
|
||||
if start_idx is None:
|
||||
# Try to parse without explicit begin marker
|
||||
start_idx = -1
|
||||
|
||||
|
||||
if end_idx is None:
|
||||
end_idx = len(lines)
|
||||
|
||||
|
||||
# Parse operations between boundaries
|
||||
i = start_idx + 1
|
||||
current_op: Optional[PatchOperation] = None
|
||||
current_hunk: Optional[Hunk] = None
|
||||
|
||||
current_op: PatchOperation | None = None
|
||||
current_hunk: Hunk | None = None
|
||||
|
||||
while i < end_idx:
|
||||
line = lines[i]
|
||||
|
||||
|
||||
# Check for file operation markers
|
||||
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
|
||||
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
|
||||
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
|
||||
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
|
||||
|
||||
update_match = re.match(r"\*\*\*\s*Update\s+File:\s*(.+)", line)
|
||||
add_match = re.match(r"\*\*\*\s*Add\s+File:\s*(.+)", line)
|
||||
delete_match = re.match(r"\*\*\*\s*Delete\s+File:\s*(.+)", line)
|
||||
move_match = re.match(r"\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)", line)
|
||||
|
||||
if update_match:
|
||||
# Save previous operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.UPDATE,
|
||||
file_path=update_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.UPDATE, file_path=update_match.group(1).strip())
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif add_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.ADD,
|
||||
file_path=add_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.ADD, file_path=add_match.group(1).strip())
|
||||
current_hunk = Hunk()
|
||||
|
||||
|
||||
elif delete_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.DELETE,
|
||||
file_path=delete_match.group(1).strip()
|
||||
)
|
||||
|
||||
current_op = PatchOperation(operation=OperationType.DELETE, file_path=delete_match.group(1).strip())
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
|
||||
elif move_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.MOVE,
|
||||
file_path=move_match.group(1).strip(),
|
||||
new_path=move_match.group(2).strip()
|
||||
new_path=move_match.group(2).strip(),
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif line.startswith('@@'):
|
||||
|
||||
elif line.startswith("@@"):
|
||||
# Context hint / hunk marker
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
|
||||
|
||||
# Extract context hint
|
||||
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
|
||||
hint_match = re.match(r"@@\s*(.+?)\s*@@", line)
|
||||
hint = hint_match.group(1) if hint_match else None
|
||||
current_hunk = Hunk(context_hint=hint)
|
||||
|
||||
|
||||
elif current_op and line:
|
||||
# Parse hunk line
|
||||
if current_hunk is None:
|
||||
current_hunk = Hunk()
|
||||
|
||||
if line.startswith('+'):
|
||||
current_hunk.lines.append(HunkLine('+', line[1:]))
|
||||
elif line.startswith('-'):
|
||||
current_hunk.lines.append(HunkLine('-', line[1:]))
|
||||
elif line.startswith(' '):
|
||||
current_hunk.lines.append(HunkLine(' ', line[1:]))
|
||||
elif line.startswith('\\'):
|
||||
|
||||
if line.startswith("+"):
|
||||
current_hunk.lines.append(HunkLine("+", line[1:]))
|
||||
elif line.startswith("-"):
|
||||
current_hunk.lines.append(HunkLine("-", line[1:]))
|
||||
elif line.startswith(" "):
|
||||
current_hunk.lines.append(HunkLine(" ", line[1:]))
|
||||
elif line.startswith("\\"):
|
||||
# "\ No newline at end of file" marker - skip
|
||||
pass
|
||||
else:
|
||||
# Treat as context line (implicit space prefix)
|
||||
current_hunk.lines.append(HunkLine(' ', line))
|
||||
|
||||
current_hunk.lines.append(HunkLine(" ", line))
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
# Don't forget the last operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
def apply_v4a_operations(operations: list[PatchOperation], file_ops: Any) -> "PatchResult":
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
@@ -236,7 +229,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -244,7 +237,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -252,7 +245,7 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
@@ -260,19 +253,19 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
if hasattr(file_ops, "_check_lint"):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
|
||||
combined_diff = "\n".join(all_diffs)
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
@@ -281,123 +274,124 @@ def apply_v4a_operations(operations: List[PatchOperation],
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error='; '.join(errors)
|
||||
error="; ".join(errors),
|
||||
)
|
||||
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None
|
||||
lint=lint_results if lint_results else None,
|
||||
)
|
||||
|
||||
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply an add file operation."""
|
||||
# Extract content from hunks (all + lines)
|
||||
content_lines = []
|
||||
for hunk in op.hunks:
|
||||
for line in hunk.lines:
|
||||
if line.prefix == '+':
|
||||
if line.prefix == "+":
|
||||
content_lines.append(line.content)
|
||||
|
||||
content = '\n'.join(content_lines)
|
||||
|
||||
|
||||
content = "\n".join(content_lines)
|
||||
|
||||
result = file_ops.write_file(op.file_path, content)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
|
||||
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
|
||||
diff += '\n'.join(f"+{line}" for line in content_lines)
|
||||
|
||||
diff += "\n".join(f"+{line}" for line in content_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
|
||||
# Delete directly via shell command using the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split('\n'):
|
||||
if '|' in line:
|
||||
for line in read_result.content.split("\n"):
|
||||
if "|" in line:
|
||||
# Line format: " 123|content"
|
||||
parts = line.split('|', 1)
|
||||
parts = line.split("|", 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
|
||||
current_content = "\n".join(current_lines)
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == ' ':
|
||||
if line.prefix == " ":
|
||||
search_lines.append(line.content)
|
||||
replace_lines.append(line.content)
|
||||
elif line.prefix == '-':
|
||||
elif line.prefix == "-":
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == '+':
|
||||
elif line.prefix == "+":
|
||||
replace_lines.append(line.content)
|
||||
|
||||
|
||||
if search_lines:
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
search_pattern = "\n".join(search_lines)
|
||||
replacement = "\n".join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
@@ -408,31 +402,32 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
|
||||
if count > 0:
|
||||
new_content = new_content[:window_start] + window_new + new_content[window_end:]
|
||||
error = None
|
||||
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
if write_result.error:
|
||||
return False, write_result.error
|
||||
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}"
|
||||
tofile=f"b/{op.file_path}",
|
||||
)
|
||||
diff = ''.join(diff_lines)
|
||||
|
||||
diff = "".join(diff_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
@@ -34,7 +34,6 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
@@ -42,10 +41,11 @@ import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from tools.environments.local import _find_shell
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from tools.environments.local import _find_shell
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,30 +54,31 @@ logger = logging.getLogger(__name__)
|
||||
CHECKPOINT_PATH = Path(os.path.expanduser("~/.hermes/processes.json"))
|
||||
|
||||
# Limits
|
||||
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessSession:
|
||||
"""A tracked background process with output buffering."""
|
||||
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
|
||||
command: str # Original command string
|
||||
task_id: str = "" # Task/sandbox isolation key
|
||||
session_key: str = "" # Gateway session key (for reset protection)
|
||||
pid: Optional[int] = None # OS process ID
|
||||
process: Optional[subprocess.Popen] = None # Popen handle (local only)
|
||||
env_ref: Any = None # Reference to the environment object
|
||||
cwd: Optional[str] = None # Working directory
|
||||
started_at: float = 0.0 # time.time() of spawn
|
||||
exited: bool = False # Whether the process has finished
|
||||
exit_code: Optional[int] = None # Exit code (None if still running)
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
|
||||
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
|
||||
command: str # Original command string
|
||||
task_id: str = "" # Task/sandbox isolation key
|
||||
session_key: str = "" # Gateway session key (for reset protection)
|
||||
pid: int | None = None # OS process ID
|
||||
process: subprocess.Popen | None = None # Popen handle (local only)
|
||||
env_ref: Any = None # Reference to the environment object
|
||||
cwd: str | None = None # Working directory
|
||||
started_at: float = 0.0 # time.time() of spawn
|
||||
exited: bool = False # Whether the process has finished
|
||||
exit_code: int | None = None # Exit code (None if still running)
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_reader_thread: threading.Thread | None = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
|
||||
|
||||
@@ -100,12 +101,12 @@ class ProcessRegistry:
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._running: Dict[str, ProcessSession] = {}
|
||||
self._finished: Dict[str, ProcessSession] = {}
|
||||
self._running: dict[str, ProcessSession] = {}
|
||||
self._finished: dict[str, ProcessSession] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Side-channel for check_interval watchers (gateway reads after agent run)
|
||||
self.pending_watchers: List[Dict[str, Any]] = []
|
||||
self.pending_watchers: list[dict[str, Any]] = []
|
||||
|
||||
@staticmethod
|
||||
def _clean_shell_noise(text: str) -> str:
|
||||
@@ -149,6 +150,7 @@ class ProcessRegistry:
|
||||
# Try PTY mode for interactive CLI tools
|
||||
try:
|
||||
import ptyprocess
|
||||
|
||||
user_shell = _find_shell()
|
||||
pty_env = os.environ | (env_vars or {})
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
@@ -260,10 +262,7 @@ class ProcessRegistry:
|
||||
log_path = f"/tmp/hermes_bg_{session.id}.log"
|
||||
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
|
||||
quoted_command = shlex.quote(command)
|
||||
bg_command = (
|
||||
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
|
||||
f"echo $! > {pid_path} && cat {pid_path}"
|
||||
)
|
||||
bg_command = f"nohup bash -c {quoted_command} > {log_path} 2>&1 & echo $! > {pid_path} && cat {pid_path}"
|
||||
|
||||
try:
|
||||
result = env.execute(bg_command, timeout=timeout)
|
||||
@@ -313,7 +312,7 @@ class ProcessRegistry:
|
||||
with session._lock:
|
||||
session.output_buffer += chunk
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
except Exception as e:
|
||||
logger.debug("Process stdout reader ended: %s", e)
|
||||
|
||||
@@ -326,9 +325,7 @@ class ProcessRegistry:
|
||||
session.exit_code = session.process.returncode
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _env_poller_loop(
|
||||
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
|
||||
):
|
||||
def _env_poller_loop(self, session: ProcessSession, env: Any, log_path: str, pid_path: str):
|
||||
"""Background thread: poll a sandbox log file for non-local backends."""
|
||||
while not session.exited:
|
||||
time.sleep(2) # Poll every 2 seconds
|
||||
@@ -340,7 +337,7 @@ class ProcessRegistry:
|
||||
with session._lock:
|
||||
session.output_buffer = new_output
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
|
||||
# Check if process is still running
|
||||
check = env.execute(
|
||||
@@ -383,7 +380,7 @@ class ProcessRegistry:
|
||||
with session._lock:
|
||||
session.output_buffer += text
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars :]
|
||||
except EOFError:
|
||||
break
|
||||
except Exception:
|
||||
@@ -397,7 +394,7 @@ class ProcessRegistry:
|
||||
except Exception as e:
|
||||
logger.debug("PTY wait timed out or failed: %s", e)
|
||||
session.exited = True
|
||||
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
|
||||
session.exit_code = pty.exitstatus if hasattr(pty, "exitstatus") else -1
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _move_to_finished(self, session: ProcessSession):
|
||||
@@ -409,7 +406,7 @@ class ProcessRegistry:
|
||||
|
||||
# ----- Query Methods -----
|
||||
|
||||
def get(self, session_id: str) -> Optional[ProcessSession]:
|
||||
def get(self, session_id: str) -> ProcessSession | None:
|
||||
"""Get a session by ID (running or finished)."""
|
||||
with self._lock:
|
||||
return self._running.get(session_id) or self._finished.get(session_id)
|
||||
@@ -454,7 +451,7 @@ class ProcessRegistry:
|
||||
if offset == 0 and limit > 0:
|
||||
selected = lines[-limit:]
|
||||
else:
|
||||
selected = lines[offset:offset + limit]
|
||||
selected = lines[offset : offset + limit]
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
@@ -485,10 +482,7 @@ class ProcessRegistry:
|
||||
|
||||
if requested_timeout and requested_timeout > max_timeout:
|
||||
effective_timeout = max_timeout
|
||||
timeout_note = (
|
||||
f"Requested wait of {requested_timeout}s was clamped "
|
||||
f"to configured limit of {max_timeout}s"
|
||||
)
|
||||
timeout_note = f"Requested wait of {requested_timeout}s was clamped to configured limit of {max_timeout}s"
|
||||
else:
|
||||
effective_timeout = requested_timeout or max_timeout
|
||||
|
||||
@@ -581,7 +575,7 @@ class ProcessRegistry:
|
||||
return {"status": "already_exited", "error": "Process has already finished"}
|
||||
|
||||
# PTY mode -- write through pty handle (expects bytes)
|
||||
if hasattr(session, '_pty') and session._pty:
|
||||
if hasattr(session, "_pty") and session._pty:
|
||||
try:
|
||||
pty_data = data.encode("utf-8") if isinstance(data, str) else data
|
||||
session._pty.write(pty_data)
|
||||
@@ -635,26 +629,17 @@ class ProcessRegistry:
|
||||
def has_active_processes(self, task_id: str) -> bool:
|
||||
"""Check if there are active (running) processes for a task_id."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.task_id == task_id and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
return any(s.task_id == task_id and not s.exited for s in self._running.values())
|
||||
|
||||
def has_active_for_session(self, session_key: str) -> bool:
|
||||
"""Check if there are active processes for a gateway session key."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.session_key == session_key and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
return any(s.session_key == session_key and not s.exited for s in self._running.values())
|
||||
|
||||
def kill_all(self, task_id: str = None) -> int:
|
||||
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
|
||||
with self._lock:
|
||||
targets = [
|
||||
s for s in self._running.values()
|
||||
if (task_id is None or s.task_id == task_id) and not s.exited
|
||||
]
|
||||
targets = [s for s in self._running.values() if (task_id is None or s.task_id == task_id) and not s.exited]
|
||||
|
||||
killed = 0
|
||||
for session in targets:
|
||||
@@ -669,10 +654,7 @@ class ProcessRegistry:
|
||||
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
|
||||
# First prune expired finished sessions
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in self._finished.items()
|
||||
if (now - s.started_at) > FINISHED_TTL_SECONDS
|
||||
]
|
||||
expired = [sid for sid, s in self._finished.items() if (now - s.started_at) > FINISHED_TTL_SECONDS]
|
||||
for sid in expired:
|
||||
del self._finished[sid]
|
||||
|
||||
@@ -696,18 +678,21 @@ class ProcessRegistry:
|
||||
entries = []
|
||||
for s in self._running.values():
|
||||
if not s.exited:
|
||||
entries.append({
|
||||
"session_id": s.id,
|
||||
"command": s.command,
|
||||
"pid": s.pid,
|
||||
"cwd": s.cwd,
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
})
|
||||
|
||||
entries.append(
|
||||
{
|
||||
"session_id": s.id,
|
||||
"command": s.command,
|
||||
"pid": s.pid,
|
||||
"cwd": s.cwd,
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
}
|
||||
)
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
from utils import atomic_json_write
|
||||
|
||||
atomic_json_write(CHECKPOINT_PATH, entries)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
|
||||
@@ -759,6 +744,7 @@ class ProcessRegistry:
|
||||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
|
||||
atomic_json_write(CHECKPOINT_PATH, [])
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
|
||||
@@ -790,38 +776,32 @@ PROCESS_SCHEMA = {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
|
||||
"description": "Action to perform on background processes"
|
||||
"description": "Action to perform on background processes",
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
|
||||
"description": "Process session ID (from terminal background output). Required for all actions except 'list'.",
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
|
||||
"description": "Text to send to process stdin (for 'write' and 'submit' actions)",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
|
||||
"minimum": 1
|
||||
"minimum": 1,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line offset for 'log' action (default: last 200 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max lines to return for 'log' action",
|
||||
"minimum": 1
|
||||
}
|
||||
"offset": {"type": "integer", "description": "Line offset for 'log' action (default: last 200 lines)"},
|
||||
"limit": {"type": "integer", "description": "Max lines to return for 'log' action", "minimum": 1},
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
"required": ["action"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_process(args, **kw):
|
||||
import json as _json
|
||||
|
||||
task_id = kw.get("task_id")
|
||||
action = args.get("action", "")
|
||||
# Coerce to string — some models send session_id as an integer
|
||||
@@ -835,8 +815,10 @@ def _handle_process(args, **kw):
|
||||
if action == "poll":
|
||||
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
|
||||
elif action == "log":
|
||||
return _json.dumps(process_registry.read_log(
|
||||
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
|
||||
return _json.dumps(
|
||||
process_registry.read_log(session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
elif action == "wait":
|
||||
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
|
||||
elif action == "kill":
|
||||
@@ -845,7 +827,10 @@ def _handle_process(args, **kw):
|
||||
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
elif action == "submit":
|
||||
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
|
||||
return _json.dumps(
|
||||
{"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
registry.register(
|
||||
|
||||
@@ -16,7 +16,7 @@ Import chain (circular-import safe):
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,12 +25,17 @@ class ToolEntry:
|
||||
"""Metadata for a single registered tool."""
|
||||
|
||||
__slots__ = (
|
||||
"name", "toolset", "schema", "handler", "check_fn",
|
||||
"requires_env", "is_async", "description",
|
||||
"name",
|
||||
"toolset",
|
||||
"schema",
|
||||
"handler",
|
||||
"check_fn",
|
||||
"requires_env",
|
||||
"is_async",
|
||||
"description",
|
||||
)
|
||||
|
||||
def __init__(self, name, toolset, schema, handler, check_fn,
|
||||
requires_env, is_async, description):
|
||||
def __init__(self, name, toolset, schema, handler, check_fn, requires_env, is_async, description):
|
||||
self.name = name
|
||||
self.toolset = toolset
|
||||
self.schema = schema
|
||||
@@ -45,8 +50,8 @@ class ToolRegistry:
|
||||
"""Singleton registry that collects tool schemas + handlers from tool files."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
self._tools: dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: dict[str, Callable] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
@@ -81,7 +86,7 @@ class ToolRegistry:
|
||||
# Schema retrieval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
|
||||
def get_definitions(self, tool_names: set[str], quiet: bool = False) -> list[dict]:
|
||||
"""Return OpenAI-format tool schemas for the requested tool names.
|
||||
|
||||
Only tools whose ``check_fn()`` returns True (or have no check_fn)
|
||||
@@ -122,6 +127,7 @@ class ToolRegistry:
|
||||
try:
|
||||
if entry.is_async:
|
||||
from model_tools import _run_async
|
||||
|
||||
return _run_async(entry.handler(args, **kwargs))
|
||||
return entry.handler(args, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -132,16 +138,16 @@ class ToolRegistry:
|
||||
# Query helpers (replace redundant dicts in model_tools.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_all_tool_names(self) -> List[str]:
|
||||
def get_all_tool_names(self) -> list[str]:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
def get_toolset_for_tool(self, name: str) -> str | None:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
return entry.toolset if entry else None
|
||||
|
||||
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
||||
def get_tool_to_toolset_map(self) -> dict[str, str]:
|
||||
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
||||
return {name: e.toolset for name, e in self._tools.items()}
|
||||
|
||||
@@ -160,14 +166,14 @@ class ToolRegistry:
|
||||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||
return False
|
||||
|
||||
def check_toolset_requirements(self) -> Dict[str, bool]:
|
||||
def check_toolset_requirements(self) -> dict[str, bool]:
|
||||
"""Return ``{toolset: available_bool}`` for every toolset."""
|
||||
toolsets = set(e.toolset for e in self._tools.values())
|
||||
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
|
||||
|
||||
def get_available_toolsets(self) -> Dict[str, dict]:
|
||||
def get_available_toolsets(self) -> dict[str, dict]:
|
||||
"""Return toolset metadata for UI display."""
|
||||
toolsets: Dict[str, dict] = {}
|
||||
toolsets: dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in toolsets:
|
||||
@@ -184,9 +190,9 @@ class ToolRegistry:
|
||||
toolsets[ts]["requirements"].append(env)
|
||||
return toolsets
|
||||
|
||||
def get_toolset_requirements(self) -> Dict[str, dict]:
|
||||
def get_toolset_requirements(self) -> dict[str, dict]:
|
||||
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
||||
result: Dict[str, dict] = {}
|
||||
result: dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in result:
|
||||
@@ -217,11 +223,13 @@ class ToolRegistry:
|
||||
if self.is_toolset_available(ts):
|
||||
available.append(ts)
|
||||
else:
|
||||
unavailable.append({
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
})
|
||||
unavailable.append(
|
||||
{
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
}
|
||||
)
|
||||
return available, unavailable
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,19 +29,16 @@ SEND_MESSAGE_SCHEMA = {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "list"],
|
||||
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
|
||||
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', or 'platform:chat_id'. Examples: 'telegram', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'",
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message text to send"
|
||||
}
|
||||
"message": {"type": "string", "description": "The message text to send"},
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +56,7 @@ def _handle_list():
|
||||
"""Return formatted list of available messaging targets."""
|
||||
try:
|
||||
from gateway.channel_directory import format_directory_for_display
|
||||
|
||||
return json.dumps({"targets": format_directory_for_display()})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load channel directory: {e}"})
|
||||
@@ -79,26 +77,30 @@ def _handle_send(args):
|
||||
if chat_id and not chat_id.lstrip("-").isdigit():
|
||||
try:
|
||||
from gateway.channel_directory import resolve_channel_name
|
||||
|
||||
resolved = resolve_channel_name(platform_name, chat_id)
|
||||
if resolved:
|
||||
chat_id = resolved
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Use send_message(action='list') to see available targets."
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Use send_message(action='list') to see available targets."
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{chat_id}' on {platform_name}. "
|
||||
f"Try using a numeric channel ID instead."
|
||||
})
|
||||
return json.dumps(
|
||||
{"error": f"Could not resolve '{chat_id}' on {platform_name}. Try using a numeric channel ID instead."}
|
||||
)
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
if is_interrupted():
|
||||
return json.dumps({"error": "Interrupted"})
|
||||
|
||||
try:
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
from gateway.config import Platform, load_gateway_config
|
||||
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load gateway config: {e}"})
|
||||
@@ -117,7 +119,11 @@ def _handle_send(args):
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/gateway.json or environment variables."
|
||||
}
|
||||
)
|
||||
|
||||
used_home_channel = False
|
||||
if not chat_id:
|
||||
@@ -126,14 +132,17 @@ def _handle_send(args):
|
||||
chat_id = home.chat_id
|
||||
used_home_channel = True
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"No home channel set for {platform_name} to determine where to send the message. "
|
||||
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
|
||||
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
|
||||
})
|
||||
return json.dumps(
|
||||
{
|
||||
"error": f"No home channel set for {platform_name} to determine where to send the message. "
|
||||
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
|
||||
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
from model_tools import _run_async
|
||||
|
||||
result = _run_async(_send_to_platform(platform, pconfig, chat_id, message))
|
||||
if used_home_channel and isinstance(result, dict) and result.get("success"):
|
||||
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
|
||||
@@ -142,6 +151,7 @@ def _handle_send(args):
|
||||
if isinstance(result, dict) and result.get("success"):
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
|
||||
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
|
||||
if mirror_to_session(platform_name, chat_id, message, source_label=source_label):
|
||||
result["mirrored"] = True
|
||||
@@ -156,6 +166,7 @@ def _handle_send(args):
|
||||
async def _send_to_platform(platform, pconfig, chat_id, message):
|
||||
"""Route a message to the appropriate platform sender."""
|
||||
from gateway.config import Platform
|
||||
|
||||
if platform == Platform.TELEGRAM:
|
||||
return await _send_telegram(pconfig.token, chat_id, message)
|
||||
elif platform == Platform.DISCORD:
|
||||
@@ -171,6 +182,7 @@ async def _send_telegram(token, chat_id, message):
|
||||
"""Send via Telegram Bot API (one-shot, no polling needed)."""
|
||||
try:
|
||||
from telegram import Bot
|
||||
|
||||
bot = Bot(token=token)
|
||||
msg = await bot.send_message(chat_id=int(chat_id), text=message)
|
||||
return {"success": True, "platform": "telegram", "chat_id": chat_id, "message_id": str(msg.message_id)}
|
||||
@@ -189,7 +201,7 @@ async def _send_discord(token, chat_id, message):
|
||||
try:
|
||||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
|
||||
chunks = [message[i:i+2000] for i in range(0, len(message), 2000)]
|
||||
chunks = [message[i : i + 2000] for i in range(0, len(message), 2000)]
|
||||
message_ids = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for chunk in chunks:
|
||||
@@ -266,6 +278,7 @@ def _check_send_message():
|
||||
return True
|
||||
try:
|
||||
from gateway.status import is_gateway_running
|
||||
|
||||
return is_gateway_running()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -18,11 +18,8 @@ Flow:
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from typing import Any
|
||||
|
||||
from agent.auxiliary_client import get_async_text_auxiliary_client
|
||||
|
||||
@@ -33,7 +30,7 @@ MAX_SESSION_CHARS = 100_000
|
||||
MAX_SUMMARY_TOKENS = 10000
|
||||
|
||||
|
||||
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
def _format_timestamp(ts: int | float | str | None) -> str:
|
||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||
|
||||
Returns "unknown" for None, str(ts) if conversion fails.
|
||||
@@ -43,11 +40,13 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
try:
|
||||
if isinstance(ts, (int, float)):
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(ts)
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
if isinstance(ts, str):
|
||||
if ts.replace(".", "").replace("-", "").isdigit():
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(float(ts))
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
return ts
|
||||
@@ -59,7 +58,7 @@ def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
return str(ts)
|
||||
|
||||
|
||||
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
|
||||
def _format_conversation(messages: list[dict[str, Any]]) -> str:
|
||||
"""Format session messages into a readable transcript for summarization."""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
@@ -93,9 +92,7 @@ def _format_conversation(messages: List[Dict[str, Any]]) -> str:
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _truncate_around_matches(
|
||||
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
|
||||
) -> str:
|
||||
def _truncate_around_matches(full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS) -> str:
|
||||
"""
|
||||
Truncate a conversation transcript to max_chars, centered around
|
||||
where the query terms appear. Keeps content near matches, trims the edges.
|
||||
@@ -129,9 +126,7 @@ def _truncate_around_matches(
|
||||
return prefix + truncated + suffix
|
||||
|
||||
|
||||
async def _summarize_session(
|
||||
conversation_text: str, query: str, session_meta: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
async def _summarize_session(conversation_text: str, query: str, session_meta: dict[str, Any]) -> str | None:
|
||||
"""Summarize a single session conversation focused on the search query."""
|
||||
system_prompt = (
|
||||
"You are reviewing a past conversation transcript to help recall what happened. "
|
||||
@@ -163,7 +158,8 @@ async def _summarize_session(
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
from agent.auxiliary_client import get_auxiliary_extra_body, auxiliary_max_tokens_param
|
||||
from agent.auxiliary_client import auxiliary_max_tokens_param, get_auxiliary_extra_body
|
||||
|
||||
_extra = get_auxiliary_extra_body()
|
||||
response = await _async_aux_client.chat.completions.create(
|
||||
model=_SUMMARIZER_MODEL,
|
||||
@@ -221,13 +217,16 @@ def session_search(
|
||||
)
|
||||
|
||||
if not raw_results:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Resolve child sessions to their parent — delegation stores detailed
|
||||
# content in child sessions, but the user's conversation is the parent.
|
||||
@@ -283,12 +282,9 @@ def session_search(
|
||||
logging.warning(f"Failed to prepare session {session_id}: {e}")
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
async def _summarize_all() -> list[str | Exception]:
|
||||
"""Summarize all sessions in parallel."""
|
||||
coros = [
|
||||
_summarize_session(text, query, meta)
|
||||
for _, _, text, meta in tasks
|
||||
]
|
||||
coros = [_summarize_session(text, query, meta) for _, _, text, meta in tasks]
|
||||
return await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
||||
try:
|
||||
@@ -300,10 +296,13 @@ def session_search(
|
||||
results = asyncio.run(_summarize_all())
|
||||
except concurrent.futures.TimeoutError:
|
||||
logging.warning("Session summarization timed out after 60 seconds")
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for (session_id, match_info, _, _), result in zip(tasks, results):
|
||||
@@ -311,21 +310,26 @@ def session_search(
|
||||
logging.warning(f"Failed to summarize session {session_id}: {result}")
|
||||
continue
|
||||
if result:
|
||||
summaries.append({
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
})
|
||||
summaries.append(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
|
||||
@@ -337,6 +341,7 @@ def check_session_search_requirements() -> bool:
|
||||
return False
|
||||
try:
|
||||
from hermes_state import DEFAULT_DB_PATH
|
||||
|
||||
return DEFAULT_DB_PATH.parent.exists()
|
||||
except ImportError:
|
||||
return False
|
||||
@@ -356,7 +361,7 @@ SESSION_SEARCH_SCHEMA = {
|
||||
"Don't hesitate to search -- it's fast and cheap. Better to search and confirm "
|
||||
"than to guess or ask the user to repeat themselves.\n\n"
|
||||
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
|
||||
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
|
||||
'phrases for exact match ("docker networking"), boolean (python NOT java), prefix (deploy*). '
|
||||
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
|
||||
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
|
||||
"keyword searches in parallel. Returns summaries of the top matching sessions."
|
||||
@@ -395,6 +400,7 @@ registry.register(
|
||||
role_filter=args.get("role_filter"),
|
||||
limit=args.get("limit", 3),
|
||||
db=kw.get("db"),
|
||||
current_session_id=kw.get("current_session_id")),
|
||||
current_session_id=kw.get("current_session_id"),
|
||||
),
|
||||
check_fn=check_session_search_requirements,
|
||||
)
|
||||
|
||||
@@ -38,20 +38,21 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
|
||||
from tools.skills_guard import format_scan_report, scan_skill, should_allow_install
|
||||
|
||||
_GUARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
_GUARD_AVAILABLE = False
|
||||
|
||||
|
||||
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
|
||||
def _security_scan_skill(skill_dir: Path) -> str | None:
|
||||
"""Scan a skill directory after write. Returns error string if blocked, else None."""
|
||||
if not _GUARD_AVAILABLE:
|
||||
return None
|
||||
@@ -65,8 +66,8 @@ def _security_scan_skill(skill_dir: Path) -> Optional[str]:
|
||||
logger.warning("Security scan failed for %s: %s", skill_dir, e)
|
||||
return None
|
||||
|
||||
import yaml
|
||||
|
||||
import yaml
|
||||
|
||||
# All skills live in ~/.hermes/skills/ (single source of truth)
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
@@ -76,7 +77,7 @@ MAX_NAME_LENGTH = 64
|
||||
MAX_DESCRIPTION_LENGTH = 1024
|
||||
|
||||
# Characters allowed in skill names (filesystem-safe, URL-friendly)
|
||||
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
|
||||
VALID_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*$")
|
||||
|
||||
# Subdirectories allowed for write_file/remove_file
|
||||
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
|
||||
@@ -91,7 +92,8 @@ def check_skill_manage_requirements() -> bool:
|
||||
# Validation helpers
|
||||
# =============================================================================
|
||||
|
||||
def _validate_name(name: str) -> Optional[str]:
|
||||
|
||||
def _validate_name(name: str) -> str | None:
|
||||
"""Validate a skill name. Returns error message or None if valid."""
|
||||
if not name:
|
||||
return "Skill name is required."
|
||||
@@ -105,7 +107,7 @@ def _validate_name(name: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_frontmatter(content: str) -> Optional[str]:
|
||||
def _validate_frontmatter(content: str) -> str | None:
|
||||
"""
|
||||
Validate that SKILL.md content has proper frontmatter with required fields.
|
||||
Returns error message or None if valid.
|
||||
@@ -116,11 +118,11 @@ def _validate_frontmatter(content: str) -> Optional[str]:
|
||||
if not content.startswith("---"):
|
||||
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
|
||||
|
||||
end_match = re.search(r'\n---\s*\n', content[3:])
|
||||
end_match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not end_match:
|
||||
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
|
||||
|
||||
yaml_content = content[3:end_match.start() + 3]
|
||||
yaml_content = content[3 : end_match.start() + 3]
|
||||
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_content)
|
||||
@@ -137,7 +139,7 @@ def _validate_frontmatter(content: str) -> Optional[str]:
|
||||
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
|
||||
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
|
||||
|
||||
body = content[end_match.end() + 3:].strip()
|
||||
body = content[end_match.end() + 3 :].strip()
|
||||
if not body:
|
||||
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
|
||||
|
||||
@@ -151,7 +153,7 @@ def _resolve_skill_dir(name: str, category: str = None) -> Path:
|
||||
return SKILLS_DIR / name
|
||||
|
||||
|
||||
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
|
||||
def _find_skill(name: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Find a skill by name in ~/.hermes/skills/.
|
||||
Returns {"path": Path} or None.
|
||||
@@ -164,7 +166,7 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_file_path(file_path: str) -> Optional[str]:
|
||||
def _validate_file_path(file_path: str) -> str | None:
|
||||
"""
|
||||
Validate a file path for write_file/remove_file.
|
||||
Must be under an allowed subdirectory and not escape the skill dir.
|
||||
@@ -194,7 +196,8 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
||||
# Core actions
|
||||
# =============================================================================
|
||||
|
||||
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
|
||||
|
||||
def _create_skill(name: str, content: str, category: str = None) -> dict[str, Any]:
|
||||
"""Create a new user skill with SKILL.md content."""
|
||||
# Validate name
|
||||
err = _validate_name(name)
|
||||
@@ -209,10 +212,7 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
||||
# Check for name collisions across all directories
|
||||
existing = _find_skill(name)
|
||||
if existing:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"A skill named '{name}' already exists at {existing['path']}."
|
||||
}
|
||||
return {"success": False, "error": f"A skill named '{name}' already exists at {existing['path']}."}
|
||||
|
||||
# Create the skill directory
|
||||
skill_dir = _resolve_skill_dir(name, category)
|
||||
@@ -238,12 +238,12 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An
|
||||
result["category"] = category
|
||||
result["hint"] = (
|
||||
"To add reference files, templates, or scripts, use "
|
||||
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
|
||||
f"skill_manage(action='write_file', name='{name}', file_path='references/example.md', file_content='...')"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
def _edit_skill(name: str, content: str) -> dict[str, Any]:
|
||||
"""Replace the SKILL.md of any existing skill (full rewrite)."""
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
@@ -278,7 +278,7 @@ def _patch_skill(
|
||||
new_string: str,
|
||||
file_path: str = None,
|
||||
replace_all: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Targeted find-and-replace within a skill file.
|
||||
|
||||
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
|
||||
@@ -287,7 +287,10 @@ def _patch_skill(
|
||||
if not old_string:
|
||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
||||
if new_string is None:
|
||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "new_string is required for 'patch'. Use an empty string to delete matched text.",
|
||||
}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
@@ -357,7 +360,7 @@ def _patch_skill(
|
||||
}
|
||||
|
||||
|
||||
def _delete_skill(name: str) -> Dict[str, Any]:
|
||||
def _delete_skill(name: str) -> dict[str, Any]:
|
||||
"""Delete a skill."""
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
@@ -377,7 +380,7 @@ def _delete_skill(name: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
def _write_file(name: str, file_path: str, file_content: str) -> dict[str, Any]:
|
||||
"""Add or overwrite a supporting file within any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
@@ -412,7 +415,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
|
||||
def _remove_file(name: str, file_path: str) -> dict[str, Any]:
|
||||
"""Remove a supporting file from any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
@@ -456,6 +459,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
|
||||
# Main entry point
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def skill_manage(
|
||||
action: str,
|
||||
name: str,
|
||||
@@ -474,19 +478,37 @@ def skill_manage(
|
||||
"""
|
||||
if action == "create":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body).",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _create_skill(name, content, category)
|
||||
|
||||
elif action == "edit":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _edit_skill(name, content)
|
||||
|
||||
elif action == "patch":
|
||||
if not old_string:
|
||||
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "old_string is required for 'patch'. Provide the text to find."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if new_string is None:
|
||||
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "new_string is required for 'patch'. Use empty string to delete matched text.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
|
||||
|
||||
elif action == "delete":
|
||||
@@ -494,18 +516,31 @@ def skill_manage(
|
||||
|
||||
elif action == "write_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if file_content is None:
|
||||
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False
|
||||
)
|
||||
result = _write_file(name, file_path, file_content)
|
||||
|
||||
elif action == "remove_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False
|
||||
)
|
||||
result = _remove_file(name, file_path)
|
||||
|
||||
else:
|
||||
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
|
||||
result = {
|
||||
"success": False,
|
||||
"error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file",
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
@@ -540,14 +575,14 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
|
||||
"description": "The action to perform."
|
||||
"description": "The action to perform.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
|
||||
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
|
||||
)
|
||||
),
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
@@ -555,7 +590,7 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"Full SKILL.md content (YAML frontmatter + markdown body). "
|
||||
"Required for 'create' and 'edit'. For 'edit', read the skill "
|
||||
"first with skill_view() and provide the complete updated text."
|
||||
)
|
||||
),
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
@@ -563,18 +598,17 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"Text to find in the file (required for 'patch'). Must be unique "
|
||||
"unless replace_all=true. Include enough surrounding context to "
|
||||
"ensure uniqueness."
|
||||
)
|
||||
),
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Replacement text (required for 'patch'). Can be empty string "
|
||||
"to delete the matched text."
|
||||
)
|
||||
"Replacement text (required for 'patch'). Can be empty string to delete the matched text."
|
||||
),
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
|
||||
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false).",
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
@@ -582,7 +616,7 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"Optional category/domain for organizing the skill (e.g., 'devops', "
|
||||
"'data-science', 'mlops'). Creates a subdirectory grouping. "
|
||||
"Only used with 'create'."
|
||||
)
|
||||
),
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
@@ -591,12 +625,9 @@ SKILL_MANAGE_SCHEMA = {
|
||||
"For 'write_file'/'remove_file': required, must be under references/, "
|
||||
"templates/, scripts/, or assets/. "
|
||||
"For 'patch': optional, defaults to SKILL.md if omitted."
|
||||
)
|
||||
},
|
||||
"file_content": {
|
||||
"type": "string",
|
||||
"description": "Content for the file. Required for 'write_file'."
|
||||
),
|
||||
},
|
||||
"file_content": {"type": "string", "description": "Content for the file. Required for 'write_file'."},
|
||||
},
|
||||
"required": ["action", "name"],
|
||||
},
|
||||
@@ -619,5 +650,6 @@ registry.register(
|
||||
file_content=args.get("file_content"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False)),
|
||||
replace_all=args.get("replace_all", False),
|
||||
),
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user