mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 19:57:07 +08:00

Compare commits: feat/insig...pass-sessi (130 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e80320069b | |
| | f2027b8bff | |
| | a648022137 | |
| | 95b1130485 | |
| | 3fb8938cd3 | |
| | c5e8166c8b | |
| | 2b88568653 | |
| | 34b4fe495e | |
| | 4fdd6c0dac | |
| | 60b6abefd9 | |
| | 4d53b7ccaa | |
| | cd77c7100c | |
| | cf810c2950 | |
| | a23bcb81ce | |
| | d07d867718 | |
| | 666f2dd486 | |
| | 34792dd907 | |
| | 7ad6fc8a40 | |
| | f824c10429 | |
| | 132e5ec179 | |
| | 66d3e6a0c2 | |
| | 4a09ae2985 | |
| | 8c734f2f27 | |
| | 245d174359 | |
| | 77f47768dd | |
| | 90fa9e54ca | |
| | 9d3a44e0e8 | |
| | 932d596466 | |
| | d518f40e8b | |
| | f016cfca46 | |
| | b8120df860 | |
| | 0df7df52f3 | |
| | bfa27d0a68 | |
| | 5a20c486e3 | |
| | 78e19ebc95 | |
| | b383cafc44 | |
| | b10ff83566 | |
| | daa1f542f9 | |
| | d507f593d0 | |
| | f210510276 | |
| | 19b6f81ee7 | |
| | 76545ab365 | |
| | b8c3bc7841 | |
| | a680367568 | |
| | dfd37a4b31 | |
| | 5ee9b67d9b | |
| | 542faf225f | |
| | 5684c68121 | |
| | 4be783446a | |
| | 8d719b180a | |
| | bf048c8aec | |
| | c5a9d1ef9d | |
| | c7b6f423c7 | |
| | 6d34207167 | |
| | fcde9be10d | |
| | 3830bbda41 | |
| | 4447e7d71a | |
| | 7bccd904c7 | |
| | 313d522b61 | |
| | 9ee4fe41fe | |
| | 39ee3512cb | |
| | 42673556af | |
| | faab73ad58 | |
| | 7e36468511 | |
| | 9ba5d399e5 | |
| | 306d92a9d7 | |
| | 5baae0df88 | |
| | 24f6a193e7 | |
| | 8c0f8baf32 | |
| | d80c30cc92 | |
| | e64d646bad | |
| | b84f9e410c | |
| | ee5daba061 | |
| | 23e84de830 | |
| | 48e0dc8791 | |
| | fb0f579b16 | |
| | 5a711f32b1 | |
| | 4d34427cc7 | |
| | 41877183bc | |
| | 451a007fb1 | |
| | 0a82396718 | |
| | 5da55ea1e3 | |
| | 064c009deb | |
| | caab1cf453 | |
| | 55c70f3508 | |
| | d29249b8fa | |
| | f668e9fc75 | |
| | 74fe1e2254 | |
| | 348936752a | |
| | 69a36a3361 | |
| | 8712dd6d1c | |
| | 55a21fe37b | |
| | f55f625277 | |
| | 9dac85b069 | |
| | 99bd69baa8 | |
| | a62a137a4f | |
| | 82b18e8ac2 | |
| | 0111c9848d | |
| | ab9cadfeee | |
| | 8bf28e1441 | |
| | ce28f847ce | |
| | 5609117882 | |
| | b4fbb6fe10 | |
| | 82d7e9429e | |
| | e2821effb5 | |
| | 9742f11fda | |
| | 388dd4789c | |
| | fdebca4573 | |
| | 479dfc096a | |
| | 3c6c11b7c9 | |
| | bc091eb7ef | |
| | f75b1d21b4 | |
| | 94053d75a6 | |
| | 2a68099675 | |
| | 6cd3bc6640 | |
| | 211b55815e | |
| | 8ae4a6f824 | |
| | b98301677a | |
| | f2fdde5ba4 | |
| | 4f56e31dc7 | |
| | 6d3804770c | |
| | ab0f4126cf | |
| | 68fbae5692 | |
| | 32636ecf8a | |
| | 3221818b6e | |
| | a1c25046a9 | |
| | d10108f8ca | |
| | 8b520f9848 | |
| | a718aed1be | |
| | 5f29e7b63c | |
.env.example (32 lines changed)

@@ -13,6 +13,38 @@ OPENROUTER_API_KEY=
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
LLM_MODEL=anthropic/claude-opus-4.6

# =============================================================================
# LLM PROVIDER (z.ai / GLM)
# =============================================================================
# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
# Get your key at: https://z.ai or https://open.bigmodel.cn
GLM_API_KEY=
# GLM_BASE_URL=https://api.z.ai/api/paas/v4  # Override default base URL

# =============================================================================
# LLM PROVIDER (Kimi / Moonshot)
# =============================================================================
# Kimi Code provides access to Moonshot AI coding models (kimi-k2.5, etc.)
# Get your key at: https://platform.kimi.ai (Kimi Code console)
# Keys prefixed sk-kimi- use the Kimi Code API (api.kimi.com) by default.
# Legacy keys from platform.moonshot.ai need KIMI_BASE_URL override below.
KIMI_API_KEY=
# KIMI_BASE_URL=https://api.kimi.com/coding/v1  # Default for sk-kimi- keys
# KIMI_BASE_URL=https://api.moonshot.ai/v1      # For legacy Moonshot keys
# KIMI_BASE_URL=https://api.moonshot.cn/v1      # For Moonshot China keys

# =============================================================================
# LLM PROVIDER (MiniMax)
# =============================================================================
# MiniMax provides access to MiniMax models (global endpoint)
# Get your key at: https://www.minimax.io
MINIMAX_API_KEY=
# MINIMAX_BASE_URL=https://api.minimax.io/v1  # Override default base URL

# MiniMax China endpoint (for users in mainland China)
MINIMAX_CN_API_KEY=
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1  # Override default base URL

# =============================================================================
# TOOL API KEYS
# =============================================================================
.gitignore (vendored, 3 lines changed)

@@ -47,4 +47,5 @@ cli-config.yaml

# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
skills/.hub/
-ignored/
+ignored/
+.worktrees/
AGENTS.md (36 lines changed)

@@ -58,6 +58,7 @@ hermes-agent/
├── skills/           # Bundled skill sources
├── optional-skills/  # Official optional skills (not activated by default)
├── cli.py            # Interactive CLI orchestrator (HermesCLI class)
├── hermes_state.py   # SessionDB — SQLite session store (schema, titles, FTS5 search)
├── run_agent.py      # AIAgent class (core conversation loop)
├── model_tools.py    # Tool orchestration (thin layer over tools/registry.py)
├── toolsets.py       # Tool groupings

@@ -98,7 +99,7 @@ The main agent is implemented in `run_agent.py`:
class AIAgent:
    def __init__(
        self,
-       model: str = "anthropic/claude-sonnet-4",
+       model: str = "anthropic/claude-sonnet-4.6",
        api_key: str = None,
        base_url: str = "https://openrouter.ai/api/v1",
        max_iterations: int = 60,  # Max tool-calling loops

@@ -204,7 +205,7 @@ Every installed skill in `~/.hermes/skills/` is automatically registered as a sl
The skill name (from frontmatter or folder name) becomes the command: `axolotl` → `/axolotl`.

Implementation (`agent/skill_commands.py`, shared between CLI and gateway):
-1. `scan_skill_commands()` scans all SKILL.md files at startup
+1. `scan_skill_commands()` scans all SKILL.md files at startup, filtering out skills incompatible with the current OS platform (via the `platforms` frontmatter field)
2. `build_skill_invocation_message()` loads the SKILL.md content and builds a user-turn message
3. The message includes the full skill content, a list of supporting files (not loaded), and the user's instruction
4. Supporting files can be loaded on demand via the `skill_view` tool

@@ -226,6 +227,10 @@ The unified `hermes` command provides all functionality:
|---------|-------------|
| `hermes` | Interactive chat (default) |
| `hermes chat -q "..."` | Single query mode |
| `hermes -c` / `hermes --continue` | Resume the most recent session |
| `hermes -c "my project"` | Resume a session by name (latest in lineage) |
| `hermes --resume <session_id>` | Resume a specific session by ID or title |
| `hermes -w` / `hermes --worktree` | Start in isolated git worktree (for parallel agents) |
| `hermes setup` | Configure API keys and settings |
| `hermes config` | View current configuration |
| `hermes config edit` | Open config in editor |

@@ -239,6 +244,8 @@ The unified `hermes` command provides all functionality:
| `hermes gateway` | Start gateway (messaging + cron scheduler) |
| `hermes gateway setup` | Configure messaging platforms interactively |
| `hermes gateway install` | Install gateway as system service |
| `hermes sessions list` | List past sessions (title, preview, last active) |
| `hermes sessions rename <id> <title>` | Rename/title a session |
| `hermes cron list` | View scheduled jobs |
| `hermes cron status` | Check if cron scheduler is running |
| `hermes version` | Show version info |

@@ -657,6 +664,7 @@ SKILL.md files use YAML frontmatter (agentskills.io format):
name: skill-name
description: Brief description for listing
version: 1.0.0
platforms: [macos]  # Optional — restrict to specific OS (macos/linux/windows)
metadata:
  hermes:
    tags: [tag1, tag2]

@@ -665,6 +673,8 @@ metadata:
# Skill Content...
```

**Platform filtering** — Skills with a `platforms` field are automatically excluded from the system prompt index, `skills_list()`, and slash commands on incompatible platforms. Skills without the field load everywhere (backward compatible). See `skills/apple/` for macOS-only examples (iMessage, Reminders, Notes, FindMy).

**Skills Hub** — user-driven skill search/install from online registries and official optional skills. Sources: official optional skills (shipped with repo, labeled "official"), GitHub (openai/skills, anthropics/skills, custom taps), ClawHub, Claude marketplace, LobeHub. Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills browse/search/install` CLI commands or the `/skills` slash command in chat.

Key files:

@@ -675,6 +685,28 @@ Key files:

---

## Known Pitfalls

### DO NOT use `simple_term_menu` for interactive menus

`simple_term_menu` has rendering bugs in tmux, iTerm2, and other non-standard terminals. When the user scrolls with arrow keys, previously highlighted items "ghost" — duplicating upward and corrupting the display. This happens because the library uses ANSI cursor-up codes to redraw in place, and tmux/iTerm miscalculate positions when the menu is near the bottom of the viewport.

**Rule:** All interactive menus in `hermes_cli/` must use `curses` (Python stdlib) instead. See `tools_config.py` for the pattern — both `_prompt_choice()` (single-select) and `_prompt_toolset_checklist()` (multi-select with space toggle) use `curses.wrapper()`. The numbered-input fallback handles Windows where curses isn't available.

### DO NOT use `\033[K` (ANSI erase-to-EOL) in spinner/display code

The ANSI escape `\033[K` leaks as literal `?[K` text when `prompt_toolkit`'s `patch_stdout` is active. Use space-padding instead to clear lines: `f"\r{line}{' ' * pad}"`. See `agent/display.py` `KawaiiSpinner`.
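A minimal sketch of the space-padding pattern, for illustration only (the status frames and loop are made up; the real implementation lives in `agent/display.py`):

```python
import sys
import time

def render_status(line: str, prev_len: int) -> int:
    """Redraw an in-place status line without ANSI erase codes.

    Instead of the erase-to-EOL escape, pad with spaces so any leftover
    characters from a longer previous line are overwritten.
    """
    pad = max(0, prev_len - len(line))
    sys.stdout.write(f"\r{line}{' ' * pad}")
    sys.stdout.flush()
    return len(line)

prev = 0
for frame in ["working .", "working ..", "working ...", "done"]:
    prev = render_status(frame, prev)
    time.sleep(0.2)
print()  # move off the status line when finished
```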
### `_last_resolved_tool_names` is a process-global in `model_tools.py`

The `execute_code` sandbox uses `_last_resolved_tool_names` (set by `get_tool_definitions()`) to decide which tool stubs to generate. When subagents run with restricted toolsets, they overwrite this global. After delegation returns to the parent, `execute_code` may see the child's restricted list instead of the parent's full list. This is a known bug — `execute_code` calls after delegation may fail with `ImportError: cannot import name 'patch' from 'hermes_tools'`.

### Tests must not write to `~/.hermes/`

The `autouse` fixture `_isolate_hermes_home` in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Every test runs in isolation. If you add a test that creates `AIAgent` instances or writes session logs, the fixture handles cleanup automatically. Never hardcode `~/.hermes/` paths in tests.

---

## Testing Changes

After making changes:

@@ -118,7 +118,7 @@ hermes-agent/
├── cli.py           # HermesCLI class — interactive TUI, prompt_toolkit integration
├── model_tools.py   # Tool orchestration (thin layer over tools/registry.py)
├── toolsets.py      # Tool groupings and presets (hermes-cli, hermes-telegram, etc.)
-├── hermes_state.py  # SQLite session database with FTS5 full-text search
+├── hermes_state.py  # SQLite session database with FTS5 full-text search, session titles
├── batch_runner.py  # Parallel batch processing for trajectory generation
│
├── agent/           # Agent internals (extracted modules)

@@ -218,7 +218,7 @@ User message → AIAgent._run_agent_loop()

- **Self-registering tools**: Each tool file calls `registry.register()` at import time. `model_tools.py` triggers discovery by importing all tool modules.
- **Toolset grouping**: Tools are grouped into toolsets (`web`, `terminal`, `file`, `browser`, etc.) that can be enabled/disabled per platform.
-- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search. JSON logs go to `~/.hermes/sessions/`.
+- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search and unique session titles. JSON logs go to `~/.hermes/sessions/`.
- **Ephemeral injection**: System prompts and prefill messages are injected at API call time, never persisted to the database or logs.
- **Provider abstraction**: The agent works with any OpenAI-compatible API. Provider resolution happens at init time (Nous Portal OAuth, OpenRouter API key, or custom endpoint).
- **Provider routing**: When using OpenRouter, `provider_routing` in config.yaml controls provider selection (sort by throughput/latency/price, allow/ignore specific providers, data retention policies). These are injected as `extra_body.provider` in API requests.
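The FTS5-backed session search mentioned above can be pictured like this; the table and column names are illustrative, not the actual `hermes_state.py` schema, and the example assumes your Python's sqlite3 build includes FTS5:

```python
import sqlite3

db = sqlite3.connect(":memory:")
db.executescript("""
    CREATE TABLE sessions (id TEXT PRIMARY KEY, title TEXT);
    CREATE VIRTUAL TABLE session_fts USING fts5(session_id, content);
""")
db.execute("INSERT INTO sessions VALUES (?, ?)", ("abc123", "my project"))
db.execute(
    "INSERT INTO session_fts VALUES (?, ?)",
    ("abc123", "user asked about git worktrees and parallel agents"),
)

# MATCH uses FTS5 query syntax; a full-text session search boils down to this.
rows = db.execute(
    "SELECT session_id FROM session_fts WHERE session_fts MATCH ?",
    ("worktrees",),
).fetchall()
print(rows)  # [('abc123',)]
```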
@@ -325,6 +325,9 @@ description: Brief description (shown in skill search results)
version: 1.0.0
author: Your Name
license: MIT
platforms: [macos, linux]  # Optional — restrict to specific OS platforms
                           # Valid: macos, linux, windows
                           # Omit to load on all platforms (default)
metadata:
  hermes:
    tags: [Category, Subcategory, Keywords]

@@ -351,6 +354,18 @@ Known failure modes and how to handle them.
How the agent confirms it worked.
```

### Platform-specific skills

Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.

```yaml
platforms: [macos]         # macOS only (e.g., iMessage, Apple Reminders)
platforms: [macos, linux]  # macOS and Linux
platforms: [windows]       # Windows only
```

If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
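A sketch of what such a platform check can look like. `skill_matches_platform` lives in `tools/skills_tool.py` but its body is not shown in this diff, so the following is an assumption about its behavior rather than the actual code:

```python
import platform

# Assumed mapping from platform.system() to the frontmatter spellings.
_OS_ALIASES = {"Darwin": "macos", "Linux": "linux", "Windows": "windows"}

def skill_matches_platform(frontmatter: dict) -> bool:
    """Return True if the skill may load on the current OS.

    A missing or empty `platforms` field means "load everywhere",
    matching the backward-compatible behavior described above.
    """
    platforms = frontmatter.get("platforms") or []
    if not platforms:
        return True
    current = _OS_ALIASES.get(platform.system(), "")
    return current in platforms

print(skill_matches_platform({}))                        # True (no restriction)
print(skill_matches_platform({"platforms": ["macos"]}))  # True only on macOS
```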
### Skill guidelines

- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).
LICENSE (new file, 21 lines)

@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Nous Research

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -13,7 +13,7 @@

**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.

-Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
+Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.

<table>
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
TODO.md (129 lines, file deleted)

@@ -1,129 +0,0 @@
# Hermes Agent - Future Improvements

---

## 3. Local Browser Control via CDP 🌐

**Status:** Not started (currently Browserbase cloud only)
**Priority:** Medium

Support local Chrome/Chromium via Chrome DevTools Protocol alongside existing Browserbase cloud backend.

**What other agents do:**
- **OpenClaw**: Full CDP-based Chrome control with snapshots, actions, uploads, profiles, file chooser, PDF save, console messages, tab management. Uses local Chrome for persistent login sessions.
- **Cline**: Headless browser with Computer Use (click, type, scroll, screenshot, console logs)

**Our approach:**
- Add a `local` backend option to `browser_tool.py` using Playwright or raw CDP
- Config toggle: `browser.backend: local | browserbase | auto`
- `auto` mode: try local first, fall back to Browserbase
- Local advantages: free, persistent login sessions, no API key needed
- Local disadvantages: no CAPTCHA solving, no stealth mode, requires Chrome installed
- Reuse the same 10-tool interface -- just swap the backend
- Later: Chrome profile management for persistent sessions across restarts

---

## 4. Signal Integration 📡

**Status:** Not started
**Priority:** Low

New platform adapter using signal-cli daemon (JSON-RPC HTTP + SSE). Requires Java runtime and phone number registration.

**Reference:** OpenClaw has Signal support via signal-cli.

---

## 5. Plugin/Extension System 🔌

**Status:** Partially implemented (event hooks exist in `gateway/hooks.py`)
**Priority:** Medium

Full Python plugin interface that goes beyond the current hook system.

**What other agents do:**
- **OpenClaw**: Plugin SDK with tool-send capabilities, lifecycle phase hooks (before-agent-start, after-tool-call, model-override), plugin registry with install/uninstall.
- **Pi**: Extensions are TypeScript modules that can register tools, commands, keyboard shortcuts, custom UI widgets, overlays, status lines, dialogs, compaction hooks, raw terminal input listeners. Extremely comprehensive.
- **OpenCode**: MCP client support (stdio, SSE, StreamableHTTP), OAuth auth for MCP servers. Also has Copilot/Codex plugins.
- **Codex**: Full MCP integration with skill dependencies.
- **Cline**: MCP integration + lifecycle hooks with cancellation support.

**Our approach (phased):**

### Phase 1: Enhanced hooks
- Expand the existing `gateway/hooks.py` to support more events: `before-tool-call`, `after-tool-call`, `before-response`, `context-compress`, `session-end`
- Allow hooks to modify tool results (e.g., filter sensitive output)

### Phase 2: Plugin interface
- `~/.hermes/plugins/<name>/plugin.yaml` + `handler.py`
- Plugins can: register new tools, add CLI commands, subscribe to events, inject system prompt sections
- `hermes plugin list|install|uninstall|create` CLI commands
- Plugin discovery and validation on startup

### Phase 3: MCP support (industry standard) ✅ DONE
- ✅ MCP client that connects to external MCP servers (stdio + HTTP/StreamableHTTP)
- ✅ Config: `mcp_servers` in config.yaml with connection details
- ✅ Each MCP server's tools auto-registered as a dynamic toolset
- Future: Resources, Prompts, Progress notifications, `hermes mcp` CLI command

---

## 6. MCP (Model Context Protocol) Support 🔗 ✅ DONE

**Status:** Implemented (PR #301)
**Priority:** Complete

Native MCP client support with stdio and HTTP/StreamableHTTP transports, auto-discovery, reconnection with exponential backoff, env var filtering, and credential stripping. See `docs/mcp.md` for full documentation.

**Still TODO:**
- `hermes mcp` CLI subcommand (list/test/status)
- `hermes tools` UI integration for MCP toolsets
- MCP Resources and Prompts support
- OAuth authentication for remote servers
- Progress notifications for long-running tools

---

## 8. Filesystem Checkpointing / Rollback 🔄

**Status:** Not started
**Priority:** Low-Medium

Automatic filesystem snapshots after each agent loop iteration so the user can roll back destructive changes to their project.

**What other agents do:**
- **Cline**: Workspace checkpoints at each step with Compare/Restore UI
- **OpenCode**: Git-backed workspace snapshots per step, with weekly gc
- **Codex**: Sandboxed execution with commit-per-step, rollback on failure

**Our approach:**
- After each tool call (or batch of tool calls in a single turn) that modifies files, create a lightweight checkpoint of the affected files
- Git-based when the project is a repo: auto-commit to a detached/temporary branch (`hermes/checkpoints/<session>`) after each agent turn, squash or discard on session end
- Non-git fallback: tar snapshots of changed files in `~/.hermes/checkpoints/<session_id>/`
- `hermes rollback` CLI command to restore to a previous checkpoint
- Agent-accessible via a `checkpoint` tool: `list` (show available restore points), `restore` (roll back to a named point), `diff` (show what changed since a checkpoint)
- Configurable: off by default (opt-in via `config.yaml`), since auto-committing can be surprising
- Cleanup: checkpoints expire after session ends (or configurable retention period)
- Integration with the terminal backend: works with local, SSH, and Docker backends (snapshots happen on the execution host)

---

## Implementation Priority Order

### Tier 1: Next Up

1. ~~MCP Support -- #6~~ ✅ Done (PR #301)

### Tier 2: Quality of Life

3. Local Browser Control via CDP -- #3
4. Plugin/Extension System -- #5

### Tier 3: Nice to Have

5. Session Branching / Checkpoints -- #7
6. Filesystem Checkpointing / Rollback -- #8
7. Signal Integration -- #4
@@ -10,7 +10,9 @@ Resolution order for text tasks:
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
   wrapped to look like a chat.completions client)
-5. None
+5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
+   — checked via PROVIDER_REGISTRY entries with auth_type='api_key'
+6. None

Resolution order for vision/multimodal tasks:
1. OpenRouter

@@ -31,6 +33,14 @@ from hermes_constants import OPENROUTER_BASE_URL

logger = logging.getLogger(__name__)

# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
    "zai": "glm-4.5-flash",
    "kimi-coding": "kimi-k2-turbo-preview",
    "minimax": "MiniMax-M2.5-highspeed",
    "minimax-cn": "MiniMax-M2.5-highspeed",
}

# OpenRouter app attribution headers
_OR_HEADERS = {
    "HTTP-Referer": "https://github.com/NousResearch/hermes-agent",

@@ -282,12 +292,58 @@ def _read_codex_access_token() -> Optional[str]:
    return None


def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
    """Try each API-key provider in PROVIDER_REGISTRY order.

    Returns (client, model) for the first provider whose env var is set,
    or (None, None) if none are configured.
    """
    try:
        from hermes_cli.auth import PROVIDER_REGISTRY
    except ImportError:
        logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
        return None, None

    for provider_id, pconfig in PROVIDER_REGISTRY.items():
        if pconfig.auth_type != "api_key":
            continue
        # Check if any of the provider's env vars are set
        api_key = ""
        for env_var in pconfig.api_key_env_vars:
            val = os.getenv(env_var, "").strip()
            if val:
                api_key = val
                break
        if not api_key:
            continue
        # Resolve base URL (with optional env-var override)
        # Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
        env_url = ""
        if pconfig.base_url_env_var:
            env_url = os.getenv(pconfig.base_url_env_var, "").strip()
        if env_url:
            base_url = env_url.rstrip("/")
        elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
            base_url = "https://api.kimi.com/coding/v1"
        else:
            base_url = pconfig.inference_base_url
        model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
        logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
        extra = {}
        if "api.kimi.com" in base_url.lower():
            extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
        return OpenAI(api_key=api_key, base_url=base_url, **extra), model

    return None, None


# ── Public API ──────────────────────────────────────────────────────────────

def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
    """Return (client, model_slug) for text-only auxiliary tasks.

-   Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth -> (None, None).
+   Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth
+   -> direct API-key providers -> (None, None).
    """
    # 1. OpenRouter
    or_key = os.getenv("OPENROUTER_API_KEY")

@@ -323,7 +379,12 @@ def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
        real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
        return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL

-   # 5. Nothing available
+   # 5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, etc.)
+   api_client, api_model = _resolve_api_key_provider()
+   if api_client is not None:
+       return api_client, api_model
+
+   # 6. Nothing available
    logger.debug("Auxiliary text client: none available")
    return None, None
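For orientation, this is roughly how a caller consumes the resolver; the API key value is a placeholder:

```python
import os
from agent.auxiliary_client import get_text_auxiliary_client

os.environ["GLM_API_KEY"] = "sk-placeholder"  # any one configured provider suffices

client, model = get_text_auxiliary_client()
if client is None:
    print("no auxiliary model configured")
else:
    # client is an OpenAI-compatible client; model is a cheap slug such as
    # "glm-4.5-flash", intended for side tasks (titles, summaries, etc.).
    resp = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Say hi in five words."}],
    )
    print(resp.choices[0].message.content)
```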
@@ -350,6 +411,8 @@ def get_async_text_auxiliary_client():
    }
    if "openrouter" in str(sync_client.base_url).lower():
        async_kwargs["default_headers"] = dict(_OR_HEADERS)
    elif "api.kimi.com" in str(sync_client.base_url).lower():
        async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
    return AsyncOpenAI(**async_kwargs), model
@@ -7,7 +7,7 @@ protecting head and tail context.

import logging
import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

from agent.auxiliary_client import get_text_auxiliary_client
from agent.model_metadata import (

@@ -82,11 +82,14 @@ class ContextCompressor:
            "compression_count": self.compression_count,
        }

-   def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> str:
-       """Generate a concise summary of conversation turns using a fast model."""
-       if not self.client:
-           return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed to save space. The assistant performed various actions and received responses."
+   def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
+       """Generate a concise summary of conversation turns.
+
+       Tries the auxiliary model first, then falls back to the user's main
+       model. Returns None if all attempts fail — the caller should drop
+       the middle turns without a summary rather than inject a useless
+       placeholder.
+       """
        parts = []
        for msg in turns_to_summarize:
            role = msg.get("role", "unknown")

@@ -117,28 +120,28 @@ TURNS TO SUMMARIZE:

Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""

-       try:
-           return self._call_summary_model(self.client, self.summary_model, prompt)
-       except Exception as e:
-           logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
+       # 1. Try the auxiliary model (cheap/fast)
+       if self.client:
+           try:
+               return self._call_summary_model(self.client, self.summary_model, prompt)
+           except Exception as e:
+               logging.warning(f"Failed to generate context summary with auxiliary model: {e}")

-       # Fallback: try the main model's endpoint. This handles the common
-       # case where the user switched providers (e.g. OpenRouter → local LLM)
-       # but a stale API key causes the auxiliary client to pick the old
-       # provider which then fails (402, auth error, etc.).
-       fallback_client, fallback_model = self._get_fallback_client()
-       if fallback_client is not None:
-           try:
-               logger.info("Retrying context summary with fallback client (%s)", fallback_model)
-               summary = self._call_summary_model(fallback_client, fallback_model, prompt)
-               # Success — swap in the working client for future compressions
-               self.client = fallback_client
-               self.summary_model = fallback_model
-               return summary
-           except Exception as fallback_err:
-               logging.warning(f"Fallback summary model also failed: {fallback_err}")
+       # 2. Fallback: try the user's main model endpoint
+       fallback_client, fallback_model = self._get_fallback_client()
+       if fallback_client is not None:
+           try:
+               logger.info("Retrying context summary with main model (%s)", fallback_model)
+               summary = self._call_summary_model(fallback_client, fallback_model, prompt)
+               self.client = fallback_client
+               self.summary_model = fallback_model
+               return summary
+           except Exception as fallback_err:
+               logging.warning(f"Main model summary also failed: {fallback_err}")

-       return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed. The assistant performed tool calls and received responses."
+       # 3. All models failed — return None so the caller drops turns without a summary
+       logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
+       return None

    def _call_summary_model(self, client, model: str, prompt: str) -> str:
        """Make the actual LLM call to generate a summary. Raises on failure."""

@@ -196,10 +199,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
            logger.debug("Could not build fallback auxiliary client: %s", exc)
        return None, None

    # ------------------------------------------------------------------
    # Tool-call / tool-result pair integrity helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_tool_call_id(tc) -> str:
        """Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
        if isinstance(tc, dict):
            return tc.get("id", "")
        return getattr(tc, "id", "") or ""

    def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Fix orphaned tool_call / tool_result pairs after compression.

        Two failure modes:
        1. A tool *result* references a call_id whose assistant tool_call was
           removed (summarized/truncated). The API rejects this with
           "No tool call found for function call output with call_id ...".
        2. An assistant message has tool_calls whose results were dropped.
           The API rejects this because every tool_call must be followed by
           a tool result with the matching call_id.

        This method removes orphaned results and inserts stub results for
        orphaned calls so the message list is always well-formed.
        """
        surviving_call_ids: set = set()
        for msg in messages:
            if msg.get("role") == "assistant":
                for tc in msg.get("tool_calls") or []:
                    cid = self._get_tool_call_id(tc)
                    if cid:
                        surviving_call_ids.add(cid)

        result_call_ids: set = set()
        for msg in messages:
            if msg.get("role") == "tool":
                cid = msg.get("tool_call_id")
                if cid:
                    result_call_ids.add(cid)

        # 1. Remove tool results whose call_id has no matching assistant tool_call
        orphaned_results = result_call_ids - surviving_call_ids
        if orphaned_results:
            messages = [
                m for m in messages
                if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
            ]
            if not self.quiet_mode:
                logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))

        # 2. Add stub results for assistant tool_calls whose results were dropped
        missing_results = surviving_call_ids - result_call_ids
        if missing_results:
            patched: List[Dict[str, Any]] = []
            for msg in messages:
                patched.append(msg)
                if msg.get("role") == "assistant":
                    for tc in msg.get("tool_calls") or []:
                        cid = self._get_tool_call_id(tc)
                        if cid in missing_results:
                            patched.append({
                                "role": "tool",
                                "content": "[Result from earlier conversation — see context summary above]",
                                "tool_call_id": cid,
                            })
            messages = patched
            if not self.quiet_mode:
                logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))

        return messages

    def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
        """Push a compress-start boundary forward past any orphan tool results.

        If ``messages[idx]`` is a tool result, slide forward until we hit a
        non-tool message so we don't start the summarised region mid-group.
        """
        while idx < len(messages) and messages[idx].get("role") == "tool":
            idx += 1
        return idx

    def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
        """Pull a compress-end boundary backward to avoid splitting a
        tool_call / result group.

        If the message just before ``idx`` is an assistant message with
        tool_calls, those tool results will start at ``idx`` and would be
        separated from their parent. Move backwards to include the whole
        group in the summarised region.
        """
        if idx <= 0 or idx >= len(messages):
            return idx
        prev = messages[idx - 1]
        if prev.get("role") == "assistant" and prev.get("tool_calls"):
            # The results for this assistant turn sit at idx..idx+k.
            # Include the assistant message in the summarised region too.
            idx -= 1
        return idx

    def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
        """Compress conversation messages by summarizing middle turns.

        Keeps first N + last N turns, summarizes everything in between.
        After compression, orphaned tool_call / tool_result pairs are cleaned
        up so the API never receives mismatched IDs.
        """
        n_messages = len(messages)
        if n_messages <= self.protect_first_n + self.protect_last_n + 1:

@@ -212,6 +316,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
        if compress_start >= compress_end:
            return messages

        # Adjust boundaries to avoid splitting tool_call/result groups.
        compress_start = self._align_boundary_forward(messages, compress_start)
        compress_end = self._align_boundary_backward(messages, compress_end)
        if compress_start >= compress_end:
            return messages

        turns_to_summarize = messages[compress_start:compress_end]
        display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)

@@ -219,24 +329,6 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
        print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
        print(f"   📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")

-       # Truncation fallback when no auxiliary model is available
-       if self.client is None:
-           print("⚠️ Context compression: no auxiliary model available. Falling back to message truncation.")
-           # Keep system message(s) at the front and the protected tail;
-           # simply drop the oldest non-system messages until under threshold.
-           kept = []
-           for msg in messages:
-               if msg.get("role") == "system":
-                   kept.append(msg.copy())
-               else:
-                   break
-           tail = messages[-self.protect_last_n:]
-           kept.extend(m.copy() for m in tail)
-           self.compression_count += 1
-           if not self.quiet_mode:
-               print(f"   ✂️ Truncated: {len(messages)} → {len(kept)} messages (dropped middle turns)")
-           return kept

        if not self.quiet_mode:
            print(f"   🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")

@@ -249,13 +341,19 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
            msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
        compressed.append(msg)

-       compressed.append({"role": "user", "content": summary})
+       if summary:
+           compressed.append({"role": "user", "content": summary})
+       else:
+           if not self.quiet_mode:
+               print("   ⚠️ No summary model available — middle turns dropped without summary")

        for i in range(compress_end, n_messages):
            compressed.append(messages[i].copy())

        self.compression_count += 1

        compressed = self._sanitize_tool_pairs(compressed)

        if not self.quiet_mode:
            new_estimate = estimate_messages_tokens_rough(compressed)
            saved_estimate = display_tokens - new_estimate
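A worked example of the trigger arithmetic in `compress()`; the 75% threshold, context size, and protect counts are assumed values for illustration:

```python
# Assumed configuration; real values come from model metadata and config.
context_length = 131_072   # model context window (tokens)
threshold_percent = 0.75   # compress once usage crosses this fraction
protect_first_n = 2        # turns kept verbatim at the head
protect_last_n = 6         # turns kept verbatim at the tail

threshold_tokens = int(context_length * threshold_percent)
print(threshold_tokens)    # 98304; compression triggers at or above this

# With 40 messages, turns 0-1 and 34-39 survive untouched; turns 2-33 are
# summarized into a single "[CONTEXT SUMMARY]:" user message (after the
# boundary-alignment helpers nudge the cut points off any tool_call group).
n_messages = 40
compress_start = protect_first_n
compress_end = n_messages - protect_last_n
print(compress_start, compress_end)  # 2 34
```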
@@ -55,6 +55,20 @@ MODEL_PRICING = {
    # Meta (via providers)
    "llama-4-maverick": {"input": 0.50, "output": 0.70},
    "llama-4-scout": {"input": 0.20, "output": 0.30},
    # Z.AI / GLM (direct provider — pricing not published externally, treat as local)
    "glm-5": {"input": 0.0, "output": 0.0},
    "glm-4.7": {"input": 0.0, "output": 0.0},
    "glm-4.5": {"input": 0.0, "output": 0.0},
    "glm-4.5-flash": {"input": 0.0, "output": 0.0},
    # Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
    "kimi-k2.5": {"input": 0.0, "output": 0.0},
    "kimi-k2-thinking": {"input": 0.0, "output": 0.0},
    "kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
    "kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
    # MiniMax (direct provider — pricing not published externally, treat as local)
    "MiniMax-M2.5": {"input": 0.0, "output": 0.0},
    "MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
    "MiniMax-M2.1": {"input": 0.0, "output": 0.0},
}

# Fallback: unknown/custom models get zero cost (we can't assume pricing
@@ -49,6 +49,17 @@ DEFAULT_CONTEXT_LENGTHS = {
    "meta-llama/llama-3.3-70b-instruct": 131072,
    "deepseek/deepseek-chat-v3": 65536,
    "qwen/qwen-2.5-72b-instruct": 32768,
    "glm-4.7": 202752,
    "glm-5": 202752,
    "glm-4.5": 131072,
    "glm-4.5-flash": 131072,
    "kimi-k2.5": 262144,
    "kimi-k2-thinking": 262144,
    "kimi-k2-turbo-preview": 262144,
    "kimi-k2-0905-preview": 131072,
    "MiniMax-M2.5": 204800,
    "MiniMax-M2.5-highspeed": 204800,
    "MiniMax-M2.1": 204800,
}
@@ -66,7 +66,8 @@ DEFAULT_AGENT_IDENTITY = (
    "range of tasks including answering questions, writing and editing code, "
    "analyzing information, creative work, and executing actions via your tools. "
    "You communicate clearly, admit uncertainty when appropriate, and prioritize "
-   "being genuinely useful over being verbose unless otherwise directed below."
+   "being genuinely useful over being verbose unless otherwise directed below. "
+   "Be targeted and efficient in your exploration and investigations."
)

MEMORY_GUIDANCE = (

@@ -102,12 +103,24 @@ PLATFORM_HINTS = {
        "You are on a text messaging communication platform, Telegram. "
        "Please do not use markdown as it does not render. "
        "You can send media files natively: to deliver a file to the user, "
-       "include MEDIA:/absolute/path/to/file in your response. Audio "
-       "(.ogg) sends as voice bubbles. You can also include image URLs "
-       "in markdown format and they will be sent as native photos."
+       "include MEDIA:/absolute/path/to/file in your response. Images "
+       "(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
+       "bubbles, and videos (.mp4) play inline. You can also include image "
+       "URLs in markdown format and they will be sent as native photos."
    ),
    "discord": (
-       "You are in a Discord server or group chat communicating with your user."
+       "You are in a Discord server or group chat communicating with your user. "
+       "You can send media files natively: include MEDIA:/absolute/path/to/file "
+       "in your response. Images (.png, .jpg, .webp) are sent as photo "
+       "attachments, audio as file attachments. You can also include image URLs "
+       "in markdown format and they will be sent as attachments."
    ),
+   "slack": (
+       "You are in a Slack workspace communicating with your user. "
+       "You can send media files natively: include MEDIA:/absolute/path/to/file "
+       "in your response. Images (.png, .jpg, .webp) are uploaded as photo "
+       "attachments, audio as file attachments. You can also include image URLs "
+       "in markdown format and they will be uploaded as attachments."
+   ),
    "cli": (
        "You are a CLI AI Agent. Try not to use markdown but simple text "
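The `MEDIA:` convention above is a plain-text protocol: the gateway scans the model's reply for `MEDIA:/absolute/path` markers and uploads those files natively. A hedged sketch of such a scan (the regex and function are illustrative, not the gateway's actual implementation):

```python
import re
from pathlib import Path

# Illustrative pattern: matches "MEDIA:/absolute/path/to/file" tokens.
_MEDIA_RE = re.compile(r"MEDIA:(/\S+)")

def split_media(reply: str) -> tuple[str, list[Path]]:
    """Separate MEDIA: markers from the visible text of a reply."""
    paths = [Path(p) for p in _MEDIA_RE.findall(reply)]
    text = _MEDIA_RE.sub("", reply).strip()
    return text, [p for p in paths if p.exists()]

text, files = split_media("Here is the chart. MEDIA:/tmp/chart.png")
print(text)   # "Here is the chart."
print(files)  # [PosixPath('/tmp/chart.png')] if the file exists, else []
```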
@@ -142,12 +155,28 @@ def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
    return ""


def _skill_is_platform_compatible(skill_file: Path) -> bool:
    """Quick check if a SKILL.md is compatible with the current OS platform.

    Reads just enough to parse the ``platforms`` frontmatter field.
    Skills without the field (the vast majority) are always compatible.
    """
    try:
        from tools.skills_tool import _parse_frontmatter, skill_matches_platform
        raw = skill_file.read_text(encoding="utf-8")[:2000]
        frontmatter, _ = _parse_frontmatter(raw)
        return skill_matches_platform(frontmatter)
    except Exception:
        return True  # Err on the side of showing the skill


def build_skills_system_prompt() -> str:
    """Build a compact skill index for the system prompt.

    Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
    Includes per-skill descriptions from frontmatter so the model can
    match skills by meaning, not just name.
    Filters out skills incompatible with the current OS platform.
    """
    hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
    skills_dir = hermes_home / "skills"

@@ -159,6 +188,9 @@ def build_skills_system_prompt() -> str:
    # Each entry: (skill_name, description)
    skills_by_category: dict[str, list[tuple[str, str]]] = {}
    for skill_file in skills_dir.rglob("SKILL.md"):
        # Skip skills incompatible with the current OS platform
        if not _skill_is_platform_compatible(skill_file):
            continue
        rel_path = skill_file.relative_to(skills_dir)
        parts = rel_path.parts
        if len(parts) >= 2:
@@ -22,7 +22,7 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
    global _skill_commands
    _skill_commands = {}
    try:
-       from tools.skills_tool import SKILLS_DIR, _parse_frontmatter
+       from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
        if not SKILLS_DIR.exists():
            return _skill_commands
        for skill_md in SKILLS_DIR.rglob("SKILL.md"):

@@ -31,6 +31,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
        try:
            content = skill_md.read_text(encoding='utf-8')
            frontmatter, body = _parse_frontmatter(content)
            # Skip skills incompatible with the current OS platform
            if not skill_matches_platform(frontmatter):
                continue
            name = frontmatter.get('name', skill_md.parent.name)
            description = frontmatter.get('description', '')
            if not description:
@@ -1112,7 +1112,7 @@ def main(
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
-   model: str = "anthropic/claude-sonnet-4-20250514",
+   model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,

@@ -1155,7 +1155,7 @@ def main(
    providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
    provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
    max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
-   reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
+   reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
    reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
    prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
    max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)

@@ -1216,7 +1216,7 @@ def main(
    providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None

    # Build reasoning_config from CLI flags
-   # --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
+   # --reasoning_disabled takes priority, then --reasoning_effort, then default (medium)
    reasoning_config = None
    if reasoning_disabled:
        # Completely disable reasoning/thinking tokens
@@ -13,6 +13,10 @@ model:
  # "auto"       - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
  # "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
  # "nous"       - Always use Nous Portal (requires: hermes login)
  # "zai"        - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
  # "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
  # "minimax"    - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
  # "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
  # Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
  provider: "auto"

@@ -46,6 +50,16 @@ model:
  # # Data policy: "allow" (default) or "deny" to exclude providers that may store data
  # # data_collection: "deny"

# =============================================================================
# Git Worktree Isolation
# =============================================================================
# When enabled, each CLI session creates an isolated git worktree so multiple
# agents can work on the same repo concurrently without file collisions.
# Equivalent to always passing --worktree / -w on the command line.
#
# worktree: true   # Always create a worktree when in a git repo
# worktree: false  # Default — only create when -w flag is passed

# =============================================================================
# Terminal Tool Configuration
# =============================================================================

@@ -281,7 +295,7 @@ agent:
  # Reasoning effort level (OpenRouter and Nous Portal)
  # Controls how much "thinking" the model does before responding.
  # Options: "xhigh" (max), "high", "medium", "low", "minimal", "none" (disable)
- reasoning_effort: "xhigh"
+ reasoning_effort: "medium"

  # Predefined personalities (use with /personality command)
  personalities:
cron/jobs.py (45 lines changed)

@@ -14,6 +14,8 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, List, Any

from hermes_time import now as _hermes_now

try:
    from croniter import croniter
    HAS_CRONITER = True

@@ -128,7 +130,7 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
    # Duration like "30m", "2h", "1d" → one-shot from now
    try:
        minutes = parse_duration(schedule)
-       run_at = datetime.now() + timedelta(minutes=minutes)
+       run_at = _hermes_now() + timedelta(minutes=minutes)
        return {
            "kind": "once",
            "run_at": run_at.isoformat(),
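For orientation, the shapes `parse_schedule()` can return; the `"once"` keys appear in the hunk above, while the `"interval"` and `"cron"` shapes are inferred from `compute_next_run()` below:

```python
# parse_schedule("30m") -> one-shot, 30 minutes from now (per the hunk above):
# {"kind": "once", "run_at": "2026-05-08T20:30:00+08:00", ...}

# Shapes consumed by compute_next_run() below (field names taken from that code):
# {"kind": "interval", "minutes": 15}        # recurring every 15 minutes
# {"kind": "cron", "expr": "0 9 * * 1-5"}    # cron expression; needs croniter
```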
@@ -146,37 +148,50 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
    )


def _ensure_aware(dt: datetime) -> datetime:
    """Make a naive datetime tz-aware using the configured timezone.

    Handles backward compatibility: timestamps stored before timezone support
    are naive (server-local). We assume they were in the same timezone as
    the current configuration so comparisons work without crashing.
    """
    if dt.tzinfo is None:
        tz = _hermes_now().tzinfo
        return dt.replace(tzinfo=tz)
    return dt


def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
    """
    Compute the next run time for a schedule.

    Returns ISO timestamp string, or None if no more runs.
    """
-   now = datetime.now()
+   now = _hermes_now()

    if schedule["kind"] == "once":
-       run_at = datetime.fromisoformat(schedule["run_at"])
+       run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
        # If in the future, return it; if in the past, no more runs
        return schedule["run_at"] if run_at > now else None

    elif schedule["kind"] == "interval":
        minutes = schedule["minutes"]
        if last_run_at:
            # Next run is last_run + interval
-           last = datetime.fromisoformat(last_run_at)
+           last = _ensure_aware(datetime.fromisoformat(last_run_at))
            next_run = last + timedelta(minutes=minutes)
        else:
            # First run is now + interval
            next_run = now + timedelta(minutes=minutes)
        return next_run.isoformat()

    elif schedule["kind"] == "cron":
        if not HAS_CRONITER:
            return None
        cron = croniter(schedule["expr"], now)
        next_run = cron.get_next(datetime)
        return next_run.isoformat()

    return None
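Why `_ensure_aware` exists: Python refuses to compare naive and aware datetimes, so a single legacy naive timestamp in the jobs file would crash the scheduler. A minimal reproduction:

```python
from datetime import datetime, timezone

naive = datetime(2026, 5, 8, 12, 0)                       # stored before tz support
aware = datetime(2026, 5, 8, 12, 0, tzinfo=timezone.utc)  # like _hermes_now()

try:
    naive < aware
except TypeError as e:
    print(e)  # can't compare offset-naive and offset-aware datetimes

# _ensure_aware's fix: stamp the configured zone onto legacy naive values.
patched = naive.replace(tzinfo=aware.tzinfo)
print(patched < aware)  # False; the comparison now works
```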
@@ -204,7 +219,7 @@ def save_jobs(jobs: List[Dict[str, Any]]):
    fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
-           json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
+           json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, JOBS_FILE)

@@ -249,7 +264,7 @@ def create_job(
    deliver = "origin" if origin else "local"

    job_id = uuid.uuid4().hex[:12]
-   now = datetime.now().isoformat()
+   now = _hermes_now().isoformat()

    job = {
        "id": job_id,

@@ -328,7 +343,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
    jobs = load_jobs()
    for i, job in enumerate(jobs):
        if job["id"] == job_id:
-           now = datetime.now().isoformat()
+           now = _hermes_now().isoformat()
            job["last_run_at"] = now
            job["last_status"] = "ok" if success else "error"
            job["last_error"] = error if not success else None

@@ -361,7 +376,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):

def get_due_jobs() -> List[Dict[str, Any]]:
    """Get all jobs that are due to run now."""
-   now = datetime.now()
+   now = _hermes_now()
    jobs = load_jobs()
    due = []

@@ -373,7 +388,7 @@ def get_due_jobs() -> List[Dict[str, Any]]:
        if not next_run:
            continue

-       next_run_dt = datetime.fromisoformat(next_run)
+       next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
        if next_run_dt <= now:
            due.append(job)

@@ -386,7 +401,7 @@ def save_job_output(job_id: str, output: str):
    job_output_dir = OUTPUT_DIR / job_id
    job_output_dir.mkdir(parents=True, exist_ok=True)

-   timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+   timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = job_output_dir / f"{timestamp}.md"

    with open(output_file, 'w', encoding='utf-8') as f:
@@ -27,6 +27,8 @@ from datetime import datetime
from pathlib import Path
from typing import Optional

from hermes_time import now as _hermes_now

logger = logging.getLogger(__name__)

# Add parent directory to path for imports

@@ -174,6 +176,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:

    model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"

    # Load config.yaml for model, reasoning, prefill, toolsets, provider routing
    _cfg = {}
    try:
        import yaml
        _cfg_path = str(_hermes_home / "config.yaml")

@@ -188,6 +192,41 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
    except Exception:
        pass

    # Reasoning config from env or config.yaml
    reasoning_config = None
    effort = os.getenv("HERMES_REASONING_EFFORT", "")
    if not effort:
        effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
    if effort and effort.lower() != "none":
        valid = ("xhigh", "high", "medium", "low", "minimal")
        if effort.lower() in valid:
            reasoning_config = {"enabled": True, "effort": effort.lower()}
    elif effort.lower() == "none":
        reasoning_config = {"enabled": False}

    # Prefill messages from env or config.yaml
    prefill_messages = None
    prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
    if prefill_file:
        import json as _json
        pfpath = Path(prefill_file).expanduser()
        if not pfpath.is_absolute():
            pfpath = _hermes_home / pfpath
        if pfpath.exists():
            try:
                with open(pfpath, "r", encoding="utf-8") as _pf:
                    prefill_messages = _json.load(_pf)
                if not isinstance(prefill_messages, list):
                    prefill_messages = None
            except Exception:
                prefill_messages = None

    # Max iterations
    max_iterations = _cfg.get("agent", {}).get("max_turns") or _cfg.get("max_turns") or 90

    # Provider routing
    pr = _cfg.get("provider_routing", {})

    from hermes_cli.runtime_provider import (
        resolve_runtime_provider,
        format_runtime_provider_error,

@@ -206,8 +245,15 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
        base_url=runtime.get("base_url"),
        provider=runtime.get("provider"),
        api_mode=runtime.get("api_mode"),
        max_iterations=max_iterations,
        reasoning_config=reasoning_config,
        prefill_messages=prefill_messages,
        providers_allowed=pr.get("only"),
        providers_ignored=pr.get("ignore"),
        providers_order=pr.get("order"),
        provider_sort=pr.get("sort"),
        quiet_mode=True,
-       session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+       session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
    )

    result = agent.run_conversation(prompt)

@@ -219,7 +265,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
    output = f"""# Cron Job: {job_name}

**Job ID:** {job_id}
-**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}

## Prompt

@@ -241,7 +287,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
    output = f"""# Cron Job: {job_name} (FAILED)

**Job ID:** {job_id}
-**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}

## Prompt

@@ -297,11 +343,11 @@ def tick(verbose: bool = True) -> int:
    due_jobs = get_due_jobs()

    if verbose and not due_jobs:
-       logger.info("%s - No jobs due", datetime.now().strftime('%H:%M:%S'))
+       logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
        return 0

    if verbose:
        logger.info("%s - %s job(s) due", datetime.now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
|
||||
@@ -115,8 +115,9 @@
- `edit_message(chat_id, message_id, content)` — edit sent messages

### What's missing:
-- **Telegram:** No override for `send_document` or `send_image_file` — falls back to text!
-- **Discord:** No override for `send_document` — falls back to text!
+- **Telegram:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
+- **Discord:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
+- **Slack:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **WhatsApp:** Has `send_document` and `send_image_file` via bridge — COMPLETE.
- The base class defaults just send "📎 File: /path" as text — useless for actual file delivery.

@@ -126,13 +127,13 @@
- `send()` — MarkdownV2 text with fallback to plain
- `send_voice()` — `.ogg`/`.opus` as `send_voice()`, others as `send_audio()`
- `send_image()` — URL-based via `send_photo()`
+- `send_image_file()` — local file via `send_photo(photo=open(path, 'rb'))` ✅
- `send_animation()` — GIF via `send_animation()`
- `send_typing()` — "typing" chat action
- `edit_message()` — edit text messages

### MISSING:
- **`send_document()` NOT overridden** — Need to add `self._bot.send_document(chat_id, document=open(file_path, 'rb'), ...)` (see the sketch below)
-- **`send_image_file()` NOT overridden** — Need to add `self._bot.send_photo(chat_id, photo=open(path, 'rb'), ...)`
- **`send_video()` NOT overridden** — Need to add `self._bot.send_video(...)`
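
A minimal sketch of the missing Telegram `send_document` override, modelled on the `send_image_file` override added later in this change set. The signature (`chat_id, file_path, caption, reply_to`) and the `SendResult` fallback convention are assumed to mirror the base class; treat this as illustrative, not as code that ships in this commit:

```python
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
) -> SendResult:
    """Send a local file natively as a Telegram document attachment."""
    if not self._bot:
        return SendResult(success=False, error="Not connected")
    try:
        import os
        if not os.path.exists(file_path):
            return SendResult(success=False, error=f"File not found: {file_path}")
        with open(file_path, "rb") as doc:
            msg = await self._bot.send_document(
                chat_id=int(chat_id),
                document=doc,
                caption=caption[:1024] if caption else None,
                reply_to_message_id=int(reply_to) if reply_to else None,
            )
        return SendResult(success=True, message_id=str(msg.message_id))
    except Exception as e:
        print(f"[{self.name}] Failed to send document: {e}")
        # Fall back to the base-class text behaviour ("📎 File: /path")
        return await super().send_document(chat_id, file_path, caption, reply_to)
```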

## 8. gateway/platforms/discord.py — Send Method Analysis

@@ -141,12 +142,12 @@
- `send()` — text messages with chunking
- `send_voice()` — discord.File attachment
- `send_image()` — downloads URL, creates discord.File attachment
+- `send_image_file()` — local file via discord.File attachment ✅
- `send_typing()` — channel.typing()
- `edit_message()` — edit text messages

### MISSING:
- **`send_document()` NOT overridden** — Need to add discord.File attachment (see the sketch below)
-- **`send_image_file()` NOT overridden** — Need to add discord.File from local path
- **`send_video()` NOT overridden** — Need to add discord.File attachment
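
And the Discord counterpart, modelled on the `send_image_file` override added to `DiscordAdapter` in this change set — again an illustrative sketch under the same assumed base-class signature, not part of the commit:

```python
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
) -> SendResult:
    """Send a local file natively as a Discord file attachment."""
    if not self._client:
        return SendResult(success=False, error="Not connected")
    try:
        import io
        import os

        channel = self._client.get_channel(int(chat_id))
        if not channel:
            channel = await self._client.fetch_channel(int(chat_id))
        if not os.path.exists(file_path):
            return SendResult(success=False, error=f"File not found: {file_path}")

        with open(file_path, "rb") as f:
            file = discord.File(io.BytesIO(f.read()), filename=os.path.basename(file_path))
        msg = await channel.send(content=caption if caption else None, file=file)
        return SendResult(success=True, message_id=str(msg.id))
    except Exception as e:
        print(f"[{self.name}] Failed to send document: {e}")
        # Fall back to the base-class text behaviour
        return await super().send_document(chat_id, file_path, caption, reply_to)
```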

## 9. gateway/run.py — User File Attachment Handling

@@ -195,8 +195,12 @@ environments/
│   └── hermes_swe_env.py
│
└── benchmarks/              # Evaluation benchmarks
-    └── terminalbench_2/
-        └── terminalbench2_env.py
+    ├── terminalbench_2/     # 89 terminal tasks, Modal sandboxes
+    │   └── terminalbench2_env.py
+    ├── tblite/              # 100 calibrated tasks (fast TB2 proxy)
+    │   └── tblite_env.py
+    └── yc_bench/            # Long-horizon strategic benchmark
+        └── yc_bench_env.py
```

## Concrete Environments

environments/benchmarks/yc_bench/README.md (new file, 115 lines)
@@ -0,0 +1,115 @@
# YC-Bench: Long-Horizon Agent Benchmark

[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.

Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.

## Setup

```bash
# Install yc-bench (optional dependency)
pip install "hermes-agent[yc-bench]"

# Or install from source
git clone https://github.com/collinear-ai/yc-bench
cd yc-bench && pip install -e .

# Verify
yc-bench --help
```

## Running

```bash
# From the repo root:
bash environments/benchmarks/yc_bench/run_eval.sh

# Or directly:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
    --config environments/benchmarks/yc_bench/default.yaml

# Override model:
bash environments/benchmarks/yc_bench/run_eval.sh \
    --openai.model_name anthropic/claude-opus-4-20250514

# Quick single-preset test:
bash environments/benchmarks/yc_bench/run_eval.sh \
    --env.presets '["fast_test"]' --env.seeds '[1]'
```

## How It Works

### Architecture

```
HermesAgentLoop (our agent)
  -> terminal tool -> subprocess("yc-bench company status") -> JSON output
  -> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
  -> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
  -> ... (100-500 turns per run)
```

The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.

### Simulation Mechanics

- **4 skill domains**: research, inference, data_environment, training
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
- **Throughput splitting**: `effective_rate = base_rate / N`, where N is the number of active tasks assigned to an employee (worked example below)
- **Financial pressure**: Monthly payroll, bankruptcy = game over
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
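
A quick worked example of the throughput split above and the reward formula quoted in the environment's system prompt (`base × (1 + scale × (prestige − 1))`). The numbers are illustrative, and `scale` is a preset-dependent parameter, not a documented default:

```python
# Throughput splitting: an employee with base_rate 2.0 assigned to 2
# active tasks contributes 2.0 / 2 = 1.0 effective rate to each task.
base_rate, active_tasks = 2.0, 2
effective_rate = base_rate / active_tasks  # 1.0 per task

# Reward scaling: a $10,000 base task at prestige 3.0 with an assumed
# scale of 0.5 pays 10_000 * (1 + 0.5 * (3.0 - 1)) = $20,000.
base, scale, prestige = 10_000, 0.5, 3.0
reward = base * (1 + scale * (prestige - 1))  # 20000.0
```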

### Difficulty Presets

| Preset | Employees | Tasks | Focus |
|-----------|-----------|-------|-------|
| tutorial | 3 | 50 | Basic loop mechanics |
| easy | 5 | 100 | Throughput awareness |
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
| **hard** | 7 | 200 | Precise ETA reasoning |
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |

Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.

### Scoring

```
composite = 0.5 × survival + 0.5 × normalised_funds
```

- **Survival** (binary): Did the company avoid bankruptcy?
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital (see the sketch below)
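
The normalisation can be made concrete with a minimal sketch that mirrors the `_compute_composite_score` helper in `yc_bench_env.py` (shown in full later in this change set):

```python
import math

INITIAL_FUNDS_CENTS = 25_000_000  # $250K seed capital, in cents

def composite_score(final_funds_cents: int, survived: bool) -> float:
    """0.5 * survival + 0.5 * log-scale normalised funds, capped at 100x."""
    survival_score = 1.0 if survived else 0.0
    if final_funds_cents <= 0:
        funds_score = 0.0
    else:
        ratio = final_funds_cents / INITIAL_FUNDS_CENTS
        funds_score = min(math.log1p(ratio) / math.log1p(100.0), 1.0)
    return 0.5 * survival_score + 0.5 * funds_score

# Surviving with exactly the initial $250K scores ~0.575 (0.5 + 0.5 * ~0.15):
print(composite_score(25_000_000, True))
```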

## Configuration

Key fields in `default.yaml`:

| Field | Default | Description |
|-------|---------|-------------|
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
| `max_agent_turns` | 200 | Max LLM calls per run |
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
| `survival_weight` | 0.5 | Weight of survival in composite score |
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
| `horizon_years` | null | Override horizon (null = auto from preset) |

## Cost & Time Estimates

Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:

| Preset | Turns | Time | Est. Cost |
|--------|-------|------|-----------|
| fast_test | ~50 | 5-10 min | $1-5 |
| medium | ~200 | 20-40 min | $5-15 |
| hard | ~300 | 30-60 min | $10-25 |

Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.

## References

- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
environments/benchmarks/yc_bench/__init__.py (new file, empty)

environments/benchmarks/yc_bench/default.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
# YC-Bench Evaluation -- Default Configuration
#
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Usage:
#   python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
#     --config environments/benchmarks/yc_bench/default.yaml
#
#   # Override model:
#   python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
#     --config environments/benchmarks/yc_bench/default.yaml \
#     --openai.model_name anthropic/claude-opus-4-20250514

env:
  enabled_toolsets: ["terminal"]
  max_agent_turns: 200
  max_token_length: 32000
  agent_temperature: 0.0
  terminal_backend: "local"
  terminal_timeout: 60
  presets: ["fast_test", "medium", "hard"]
  seeds: [1, 2, 3]
  run_timeout: 3600          # 60 min wall-clock per run, auto-FAIL if exceeded
  survival_weight: 0.5       # weight of binary survival in composite score
  funds_weight: 0.5          # weight of normalised final funds in composite score
  db_dir: "/tmp/yc_bench_dbs"
  company_name: "BenchCo"
  start_date: "01/01/2025"   # MM/DD/YYYY (yc-bench convention)
  tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
  use_wandb: true
  wandb_name: "yc-bench"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"

openai:
  base_url: "https://openrouter.ai/api/v1"
  model_name: "anthropic/claude-sonnet-4.6"
  server_type: "openai"
  health_check: false
  # api_key loaded from OPENROUTER_API_KEY in .env
environments/benchmarks/yc_bench/run_eval.sh (new executable file, 34 lines)
@@ -0,0 +1,34 @@
#!/bin/bash

# YC-Bench Evaluation
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Run from repo root:
#   bash environments/benchmarks/yc_bench/run_eval.sh
#
# Override model:
#   bash environments/benchmarks/yc_bench/run_eval.sh \
#     --openai.model_name anthropic/claude-opus-4-20250514
#
# Run a single preset:
#   bash environments/benchmarks/yc_bench/run_eval.sh \
#     --env.presets '["fast_test"]' --env.seeds '[1]'

set -euo pipefail

mkdir -p logs evals/yc-bench
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"

echo "YC-Bench Evaluation"
echo "Log: $LOG_FILE"
echo ""

PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
  python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
  --config environments/benchmarks/yc_bench/default.yaml \
  "$@" \
  2>&1 | tee "$LOG_FILE"

echo ""
echo "Log saved to: $LOG_FILE"
environments/benchmarks/yc_bench/yc_bench_env.py (new file, 847 lines)
@@ -0,0 +1,847 @@
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install hermes-agent[yc-bench]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
You are the autonomous CEO of an early-stage AI startup in a deterministic
business simulation. You manage the company exclusively through the `yc-bench`
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
without going bankrupt, while **maximising final funds**.

## Simulation Mechanics

- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
  tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
- **Domains**: There are 4 skill domains: **research**, **inference**,
  **data_environment**, and **training**. Each has its own prestige level
  (1.0-10.0). Higher prestige unlocks better-paying tasks.
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
  skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
  is the number of active tasks assigned to that employee. Focus beats breadth.
- **Payroll**: Deducted automatically on the first business day of each month.
  Running out of funds = bankruptcy = game over.
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
  Time only advances when you call `yc-bench sim resume`.

## Task Lifecycle

1. Browse market tasks with `market browse`
2. Accept a task with `task accept` (this sets its deadline)
3. Assign employees with `task assign`
4. Dispatch with `task dispatch` to start work
5. Call `sim resume` to advance time and let employees make progress
6. Tasks complete when all domain requirements are fulfilled

**Penalties for failure vary by difficulty preset.** Completing a task on time
earns full reward + prestige gain. Missing a deadline or cancelling a task
incurs prestige penalties -- cancelling is always more costly than letting a
task fail, so cancel only as a last resort.

## CLI Commands

### Observe
- `yc-bench company status` -- funds, prestige, runway
- `yc-bench employee list` -- skills, salary, active tasks
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
- `yc-bench task list [--status active|planned]` -- your tasks
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
- `yc-bench report monthly` -- monthly P&L

### Act
- `yc-bench task accept --task-id UUID` -- accept from market
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
- `yc-bench sim resume` -- advance simulation clock

### Memory (persists across context truncation)
- `yc-bench scratchpad read` -- read your persistent notes
- `yc-bench scratchpad write --content "text"` -- overwrite notes
- `yc-bench scratchpad append --content "text"` -- append to notes
- `yc-bench scratchpad clear` -- clear notes

## Strategy Guidelines

1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
   high-reward tasks. Don't spread thin across all 4 domains early on.
2. **Focus employees** -- assigning one employee to N tasks divides their
   throughput by N. Keep assignments concentrated.
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
   employee assignments. This persists even if conversation context is truncated.
4. **Monitor runway** -- always know how many months of payroll you can cover.
   Accept high-reward tasks before payroll dates.
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
   into prestige loss, locking you out of profitable contracts.
6. Use `finance ledger` and `report monthly` to track revenue trends.

## Your Turn

Each turn:
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
2. Check for completed tasks and pending deadlines.
3. Browse market for profitable tasks within your prestige level.
4. Accept, assign, and dispatch tasks strategically.
5. Call `yc-bench sim resume` to advance time.
6. Repeat until the simulation ends.

Think step by step before acting."""

# Starting funds in cents ($250,000)
INITIAL_FUNDS_CENTS = 25_000_000

# Default horizon per preset (years)
_PRESET_HORIZONS = {
    "tutorial": 1,
    "easy": 1,
    "medium": 1,
    "hard": 1,
    "nightmare": 1,
    "fast_test": 1,
    "default": 3,
    "high_reward": 1,
}


# =============================================================================
# Configuration
# =============================================================================

class YCBenchEvalConfig(HermesAgentEnvConfig):
    """
    Configuration for the YC-Bench evaluation environment.

    Extends HermesAgentEnvConfig with YC-Bench-specific settings for
    preset selection, seed control, scoring, and simulation parameters.
    """

    presets: List[str] = Field(
        default=["fast_test", "medium", "hard"],
        description="YC-Bench preset names to evaluate.",
    )
    seeds: List[int] = Field(
        default=[1, 2, 3],
        description="Random seeds -- each preset x seed = one run.",
    )
    run_timeout: int = Field(
        default=3600,
        description="Maximum wall-clock seconds per run. Default 60 minutes.",
    )
    survival_weight: float = Field(
        default=0.5,
        description="Weight of survival (0/1) in composite score.",
    )
    funds_weight: float = Field(
        default=0.5,
        description="Weight of normalised final funds in composite score.",
    )
    db_dir: str = Field(
        default="/tmp/yc_bench_dbs",
        description="Directory for per-run SQLite databases.",
    )
    horizon_years: Optional[int] = Field(
        default=None,
        description=(
            "Simulation horizon in years. If None (default), inferred from "
            "preset name (1 year for most, 3 for 'default')."
        ),
    )
    company_name: str = Field(
        default="BenchCo",
        description="Name of the simulated company.",
    )
    start_date: str = Field(
        default="01/01/2025",
        description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
    )


# =============================================================================
# Scoring helpers
# =============================================================================

def _read_final_score(db_path: str) -> Dict[str, Any]:
    """
    Read final game state from a YC-Bench SQLite database.

    Returns dict with final_funds_cents (int), survived (bool),
    terminal_reason (str).

    Note: yc-bench table names are plural -- 'companies' not 'company',
    'sim_events' not 'simulation_log'.
    """
    if not os.path.exists(db_path):
        logger.warning("DB not found at %s", db_path)
        return {
            "final_funds_cents": 0,
            "survived": False,
            "terminal_reason": "db_missing",
        }

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()

        # Read final funds from the 'companies' table
        cur.execute("SELECT funds_cents FROM companies LIMIT 1")
        row = cur.fetchone()
        funds = row[0] if row else 0

        # Determine terminal reason from 'sim_events' table
        terminal_reason = "unknown"
        try:
            cur.execute(
                "SELECT event_type FROM sim_events "
                "WHERE event_type IN ('bankruptcy', 'horizon_end') "
                "ORDER BY scheduled_at DESC LIMIT 1"
            )
            event_row = cur.fetchone()
            if event_row:
                terminal_reason = event_row[0]
        except sqlite3.OperationalError:
            # Table may not exist if simulation didn't progress
            pass

        survived = funds >= 0 and terminal_reason != "bankruptcy"
        return {
            "final_funds_cents": funds,
            "survived": survived,
            "terminal_reason": terminal_reason,
        }

    except Exception as e:
        logger.error("Failed to read DB %s: %s", db_path, e)
        return {
            "final_funds_cents": 0,
            "survived": False,
            "terminal_reason": f"db_error: {e}",
        }
    finally:
        if conn:
            conn.close()


def _compute_composite_score(
    final_funds_cents: int,
    survived: bool,
    survival_weight: float = 0.5,
    funds_weight: float = 0.5,
    initial_funds_cents: int = INITIAL_FUNDS_CENTS,
) -> float:
    """
    Compute composite score from survival and final funds.

    Score = survival_weight * survival_score
          + funds_weight * normalised_funds_score

    Normalised funds uses log-scale relative to initial capital:
    - funds <= 0: 0.0
    - funds == initial: ~0.15
    - funds == 10x: ~0.52
    - funds == 100x: 1.0
    """
    survival_score = 1.0 if survived else 0.0

    if final_funds_cents <= 0:
        funds_score = 0.0
    else:
        max_ratio = 100.0
        ratio = final_funds_cents / max(initial_funds_cents, 1)
        funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)

    return survival_weight * survival_score + funds_weight * funds_score


# =============================================================================
# Main Environment
# =============================================================================

class YCBenchEvalEnv(HermesAgentBaseEnv):
    """
    YC-Bench long-horizon agent benchmark environment (eval-only).

    Each eval item is a (preset, seed) pair. The environment initialises the
    simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
    a competing built-in agent loop). The HermesAgentLoop then drives the
    interaction by calling individual yc-bench CLI commands via the terminal tool.

    After the agent loop ends, the SQLite DB is read to extract the final score.

    Scoring:
        composite = 0.5 * survival + 0.5 * normalised_funds
    """

    name = "yc-bench"
    env_config_cls = YCBenchEvalConfig

    @classmethod
    def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
        env_config = YCBenchEvalConfig(
            enabled_toolsets=["terminal"],
            disabled_toolsets=None,
            distribution=None,
            max_agent_turns=200,
            max_token_length=32000,
            agent_temperature=0.0,
            system_prompt=YC_BENCH_SYSTEM_PROMPT,
            terminal_backend="local",
            terminal_timeout=60,
            presets=["fast_test", "medium", "hard"],
            seeds=[1, 2, 3],
            run_timeout=3600,
            survival_weight=0.5,
            funds_weight=0.5,
            db_dir="/tmp/yc_bench_dbs",
            eval_handling=EvalHandlingEnum.STOP_TRAIN,
            group_size=1,
            steps_per_eval=1,
            total_steps=1,
            tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
            use_wandb=True,
            wandb_name="yc-bench",
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            APIServerConfig(
                base_url="https://openrouter.ai/api/v1",
                model_name="anthropic/claude-sonnet-4.6",
                server_type="openai",
                api_key=os.getenv("OPENROUTER_API_KEY", ""),
                health_check=False,
            )
        ]

        return env_config, server_configs

    # =========================================================================
    # Setup
    # =========================================================================

    async def setup(self):
        """Verify yc-bench is installed and build the eval matrix."""
        # Verify yc-bench CLI is available
        try:
            result = subprocess.run(
                ["yc-bench", "--help"], capture_output=True, text=True, timeout=10
            )
            if result.returncode != 0:
                raise FileNotFoundError
        except (FileNotFoundError, subprocess.TimeoutExpired):
            raise RuntimeError(
                "yc-bench CLI not found. Install with:\n"
                '  pip install "hermes-agent[yc-bench]"\n'
                "Or: git clone https://github.com/collinear-ai/yc-bench "
                "&& cd yc-bench && pip install -e ."
            )
        print("yc-bench CLI verified.")

        # Build eval matrix: preset x seed
        self.all_eval_items = [
            {"preset": preset, "seed": seed}
            for preset in self.config.presets
            for seed in self.config.seeds
        ]
        self.iter = 0

        os.makedirs(self.config.db_dir, exist_ok=True)
        self.eval_metrics: List[Tuple[str, float]] = []

        # Streaming JSONL log for crash-safe result persistence
        log_dir = os.path.join(os.path.dirname(__file__), "logs")
        os.makedirs(log_dir, exist_ok=True)
        run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
        self._streaming_file = open(self._streaming_path, "w")
        self._streaming_lock = threading.Lock()

        print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
        for item in self.all_eval_items:
            print(f"  preset={item['preset']!r} seed={item['seed']}")
        print(f"Streaming results to: {self._streaming_path}\n")

    def _save_result(self, result: Dict[str, Any]):
        """Write a single run result to the streaming JSONL file immediately."""
        if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
            return
        with self._streaming_lock:
            self._streaming_file.write(
                json.dumps(result, ensure_ascii=False, default=str) + "\n"
            )
            self._streaming_file.flush()

    # =========================================================================
    # Training pipeline stubs (eval-only -- not used)
    # =========================================================================

    async def get_next_item(self):
        item = self.all_eval_items[self.iter % len(self.all_eval_items)]
        self.iter += 1
        return item

    def format_prompt(self, item: Dict[str, Any]) -> str:
        preset = item["preset"]
        seed = item["seed"]
        return (
            f"A new YC-Bench simulation has been initialized "
            f"(preset='{preset}', seed={seed}).\n"
            f"Your company '{self.config.company_name}' is ready.\n\n"
            "Begin by calling:\n"
            "1. `yc-bench company status` -- see your starting funds and prestige\n"
            "2. `yc-bench employee list` -- see your team and their skills\n"
            "3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
            "you can take\n\n"
            "Then accept 2-3 tasks, assign employees, dispatch them, and call "
            "`yc-bench sim resume` to advance time. Repeat this loop until the "
            "simulation ends (horizon reached or bankruptcy)."
        )

    async def compute_reward(self, item, result, ctx) -> float:
        return 0.0

    async def collect_trajectories(self, item):
        return None, []

    async def score(self, rollout_group_data):
        return None

    # =========================================================================
    # Per-run evaluation
    # =========================================================================

    async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
        """
        Evaluate a single (preset, seed) run.

        1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
        2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
        3. Runs HermesAgentLoop with terminal tool
        4. Reads SQLite DB to compute final score
        5. Returns result dict with survival, funds, and composite score
        """
        preset = eval_item["preset"]
        seed = eval_item["seed"]
        run_id = str(uuid.uuid4())[:8]
        run_key = f"{preset}_seed{seed}_{run_id}"

        from tqdm import tqdm
        tqdm.write(f"  [START] preset={preset!r} seed={seed} (run_id={run_id})")
        run_start = time.time()

        # Isolated DB per run -- prevents cross-run state leakage
        db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
        os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
        os.environ["YC_BENCH_EXPERIMENT"] = preset

        # Determine horizon: explicit config override > preset lookup > default 1
        horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)

        try:
            # ----------------------------------------------------------
            # Step 1: Initialise the simulation via CLI
            # IMPORTANT: We use `sim init`, NOT `yc-bench run`.
            # `yc-bench run` starts yc-bench's own LLM agent loop (via
            # LiteLLM), which would compete with our HermesAgentLoop.
            # `sim init` just sets up the world and returns.
            # ----------------------------------------------------------
            init_cmd = [
                "yc-bench", "sim", "init",
                "--seed", str(seed),
                "--start-date", self.config.start_date,
                "--company-name", self.config.company_name,
                "--horizon-years", str(horizon),
            ]
            init_result = subprocess.run(
                init_cmd, capture_output=True, text=True, timeout=30,
            )
            if init_result.returncode != 0:
                error_msg = (init_result.stderr or init_result.stdout).strip()
                raise RuntimeError(f"yc-bench sim init failed: {error_msg}")

            tqdm.write(f"  Simulation initialized (horizon={horizon}yr)")

            # ----------------------------------------------------------
            # Step 2: Run the HermesAgentLoop
            # ----------------------------------------------------------
            tools, valid_names = self._resolve_tools_for_group()

            messages: List[Dict[str, Any]] = [
                {"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
                {"role": "user", "content": self.format_prompt(eval_item)},
            ]

            agent = HermesAgentLoop(
                server=self.server,
                tool_schemas=tools,
                valid_tool_names=valid_names,
                max_turns=self.config.max_agent_turns,
                task_id=run_id,
                temperature=self.config.agent_temperature,
                max_tokens=self.config.max_token_length,
                extra_body=self.config.extra_body,
            )
            result = await agent.run(messages)

            # ----------------------------------------------------------
            # Step 3: Read final score from the simulation DB
            # ----------------------------------------------------------
            score_data = _read_final_score(db_path)
            final_funds = score_data["final_funds_cents"]
            survived = score_data["survived"]
            terminal_reason = score_data["terminal_reason"]

            composite = _compute_composite_score(
                final_funds_cents=final_funds,
                survived=survived,
                survival_weight=self.config.survival_weight,
                funds_weight=self.config.funds_weight,
            )

            elapsed = time.time() - run_start
            status = "SURVIVED" if survived else "BANKRUPT"
            if final_funds >= 0:
                funds_str = f"${final_funds / 100:,.0f}"
            else:
                funds_str = f"-${abs(final_funds) / 100:,.0f}"

            tqdm.write(
                f"  [{status}] preset={preset!r} seed={seed} "
                f"funds={funds_str} score={composite:.3f} "
                f"turns={result.turns_used} ({elapsed:.0f}s)"
            )

            out = {
                "preset": preset,
                "seed": seed,
                "survived": survived,
                "final_funds_cents": final_funds,
                "final_funds_usd": final_funds / 100,
                "terminal_reason": terminal_reason,
                "composite_score": composite,
                "turns_used": result.turns_used,
                "finished_naturally": result.finished_naturally,
                "elapsed_seconds": elapsed,
                "db_path": db_path,
                "messages": result.messages,
            }
            self._save_result(out)
            return out

        except Exception as e:
            elapsed = time.time() - run_start
            logger.error("Run %s failed: %s", run_key, e, exc_info=True)
            tqdm.write(
                f"  [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
            )
            out = {
                "preset": preset,
                "seed": seed,
                "survived": False,
                "final_funds_cents": 0,
                "final_funds_usd": 0.0,
                "terminal_reason": f"error: {e}",
                "composite_score": 0.0,
                "turns_used": 0,
                "error": str(e),
                "elapsed_seconds": elapsed,
            }
            self._save_result(out)
            return out

    # =========================================================================
    # Evaluate
    # =========================================================================

    async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
        """Wrap a single rollout with a wall-clock timeout."""
        preset = item["preset"]
        seed = item["seed"]
        try:
            return await asyncio.wait_for(
                self.rollout_and_score_eval(item),
                timeout=self.config.run_timeout,
            )
        except asyncio.TimeoutError:
            from tqdm import tqdm
            tqdm.write(
                f"  [TIMEOUT] preset={preset!r} seed={seed} "
                f"(exceeded {self.config.run_timeout}s)"
            )
            out = {
                "preset": preset,
                "seed": seed,
                "survived": False,
                "final_funds_cents": 0,
                "final_funds_usd": 0.0,
                "terminal_reason": f"timeout ({self.config.run_timeout}s)",
                "composite_score": 0.0,
                "turns_used": 0,
                "error": "timeout",
            }
            self._save_result(out)
            return out

    async def evaluate(self, *args, **kwargs) -> None:
        """
        Run YC-Bench evaluation over all (preset, seed) combinations.

        Runs sequentially -- each run is 100-500 turns, so parallelising would
        be prohibitively expensive and would cause env var conflicts.
        """
        start_time = time.time()
        from tqdm import tqdm

        # --- tqdm-compatible logging handler (TB2 pattern) ---
        class _TqdmHandler(logging.Handler):
            def emit(self, record):
                try:
                    tqdm.write(self.format(record))
                except Exception:
                    self.handleError(record)

        root = logging.getLogger()
        handler = _TqdmHandler()
        handler.setFormatter(
            logging.Formatter("%(levelname)s %(name)s: %(message)s")
        )
        root.handlers = [handler]
        for noisy in ("httpx", "openai"):
            logging.getLogger(noisy).setLevel(logging.WARNING)

        # --- Print config summary ---
        print(f"\n{'='*60}")
        print("Starting YC-Bench Evaluation")
        print(f"{'='*60}")
        print(f"  Presets:       {self.config.presets}")
        print(f"  Seeds:         {self.config.seeds}")
        print(f"  Total runs:    {len(self.all_eval_items)}")
        print(f"  Max turns/run: {self.config.max_agent_turns}")
        print(f"  Run timeout:   {self.config.run_timeout}s")
        print(f"{'='*60}\n")

        results = []
        pbar = tqdm(
            total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
        )

        try:
            for item in self.all_eval_items:
                result = await self._run_with_timeout(item)
                results.append(result)
                survived_count = sum(1 for r in results if r.get("survived"))
                pbar.set_postfix_str(
                    f"survived={survived_count}/{len(results)}"
                )
                pbar.update(1)

        except (KeyboardInterrupt, asyncio.CancelledError):
            tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
            pbar.close()
            try:
                from tools.terminal_tool import cleanup_all_environments
                cleanup_all_environments()
            except Exception:
                pass
            if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
                self._streaming_file.close()
            return

        pbar.close()
        end_time = time.time()

        # --- Compute metrics ---
        valid = [r for r in results if r is not None]
        if not valid:
            print("Warning: No valid results.")
            return

        total = len(valid)
        survived_total = sum(1 for r in valid if r.get("survived"))
        survival_rate = survived_total / total if total else 0.0
        avg_score = (
            sum(r.get("composite_score", 0) for r in valid) / total
            if total
            else 0.0
        )

        preset_results: Dict[str, List[Dict]] = defaultdict(list)
        for r in valid:
            preset_results[r["preset"]].append(r)

        eval_metrics = {
            "eval/survival_rate": survival_rate,
            "eval/avg_composite_score": avg_score,
            "eval/total_runs": total,
            "eval/survived_runs": survived_total,
            "eval/evaluation_time_seconds": end_time - start_time,
        }

        for preset, items in sorted(preset_results.items()):
            ps = sum(1 for r in items if r.get("survived"))
            pt = len(items)
            pa = (
                sum(r.get("composite_score", 0) for r in items) / pt
                if pt
                else 0
            )
            key = preset.replace("-", "_")
            eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
            eval_metrics[f"eval/avg_score_{key}"] = pa

        self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]

        # --- Print summary ---
        print(f"\n{'='*60}")
        print("YC-Bench Evaluation Results")
        print(f"{'='*60}")
        print(
            f"Overall survival rate:   {survival_rate:.1%} "
            f"({survived_total}/{total})"
        )
        print(f"Average composite score: {avg_score:.4f}")
        print(f"Evaluation time:         {end_time - start_time:.1f}s")

        print("\nPer-preset breakdown:")
        for preset, items in sorted(preset_results.items()):
            ps = sum(1 for r in items if r.get("survived"))
            pt = len(items)
            pa = (
                sum(r.get("composite_score", 0) for r in items) / pt
                if pt
                else 0
            )
            print(f"  {preset}: {ps}/{pt} survived  avg_score={pa:.4f}")
            for r in items:
                status = "SURVIVED" if r.get("survived") else "BANKRUPT"
                funds = r.get("final_funds_usd", 0)
                print(
                    f"    seed={r['seed']} [{status}] "
                    f"${funds:,.0f} "
                    f"score={r.get('composite_score', 0):.3f}"
                )

        print(f"{'='*60}\n")

        # --- Log results ---
        samples = [
            {k: v for k, v in r.items() if k != "messages"} for r in valid
        ]

        try:
            await self.evaluate_log(
                metrics=eval_metrics,
                samples=samples,
                start_time=start_time,
                end_time=end_time,
                generation_parameters={
                    "temperature": self.config.agent_temperature,
                    "max_tokens": self.config.max_token_length,
                    "max_agent_turns": self.config.max_agent_turns,
                },
            )
        except Exception as e:
            print(f"Error logging results: {e}")

        # --- Cleanup (TB2 pattern) ---
        if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
            self._streaming_file.close()
            print(f"Results saved to: {self._streaming_path}")

        try:
            from tools.terminal_tool import cleanup_all_environments
            cleanup_all_environments()
        except Exception:
            pass

        try:
            from environments.agent_loop import _tool_executor
            _tool_executor.shutdown(wait=False, cancel_futures=True)
        except Exception:
            pass

    # =========================================================================
    # Wandb logging
    # =========================================================================

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log YC-Bench-specific metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}
        for k, v in self.eval_metrics:
            wandb_metrics[k] = v
        self.eval_metrics = []
        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    YCBenchEvalEnv.cli()

@@ -701,6 +701,8 @@ class BasePlatformAdapter(ABC):

        # Extract image URLs and send them as native platform attachments
        images, text_content = self.extract_images(response)
+       if images:
+           logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))

        # Send the text portion first (if any remains after extractions)
        if text_content:

@@ -727,10 +729,13 @@ class BasePlatformAdapter(ABC):
        human_delay = self._get_human_delay()

        # Send extracted images as native attachments
        if images:
+           logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
            for image_url, alt_text in images:
                if human_delay > 0:
                    await asyncio.sleep(human_delay)
                try:
+                   logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
+                   # Route animated GIFs through send_animation for proper playback
                    if self._is_animation_url(image_url):
                        img_result = await self.send_animation(

@@ -745,9 +750,9 @@ class BasePlatformAdapter(ABC):
                            caption=alt_text if alt_text else None,
                        )
                    if not img_result.success:
-                       print(f"[{self.name}] Failed to send image: {img_result.error}")
+                       logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
                except Exception as img_err:
-                   print(f"[{self.name}] Error sending image: {img_err}")
+                   logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)

        # Send extracted media files — route by file type
        _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}

@@ -267,6 +267,43 @@ class DiscordAdapter(BasePlatformAdapter):
            print(f"[{self.name}] Failed to send audio: {e}")
            return await super().send_voice(chat_id, audio_path, caption, reply_to)

+    async def send_image_file(
+        self,
+        chat_id: str,
+        image_path: str,
+        caption: Optional[str] = None,
+        reply_to: Optional[str] = None,
+    ) -> SendResult:
+        """Send a local image file natively as a Discord file attachment."""
+        if not self._client:
+            return SendResult(success=False, error="Not connected")
+
+        try:
+            import io
+
+            channel = self._client.get_channel(int(chat_id))
+            if not channel:
+                channel = await self._client.fetch_channel(int(chat_id))
+            if not channel:
+                return SendResult(success=False, error=f"Channel {chat_id} not found")
+
+            if not os.path.exists(image_path):
+                return SendResult(success=False, error=f"Image file not found: {image_path}")
+
+            filename = os.path.basename(image_path)
+
+            with open(image_path, "rb") as f:
+                file = discord.File(io.BytesIO(f.read()), filename=filename)
+            msg = await channel.send(
+                content=caption if caption else None,
+                file=file,
+            )
+            return SendResult(success=True, message_id=str(msg.id))
+
+        except Exception as e:
+            print(f"[{self.name}] Failed to send local image: {e}")
+            return await super().send_image_file(chat_id, image_path, caption, reply_to)
+
    async def send_image(
        self,
        chat_id: str,

@@ -179,6 +179,35 @@ class SlackAdapter(BasePlatformAdapter):
        """Slack doesn't have a direct typing indicator API for bots."""
        pass

+    async def send_image_file(
+        self,
+        chat_id: str,
+        image_path: str,
+        caption: Optional[str] = None,
+        reply_to: Optional[str] = None,
+    ) -> SendResult:
+        """Send a local image file to Slack by uploading it."""
+        if not self._app:
+            return SendResult(success=False, error="Not connected")
+
+        try:
+            import os
+            if not os.path.exists(image_path):
+                return SendResult(success=False, error=f"Image file not found: {image_path}")
+
+            result = await self._app.client.files_upload_v2(
+                channel=chat_id,
+                file=image_path,
+                filename=os.path.basename(image_path),
+                initial_comment=caption or "",
+                thread_ts=reply_to,
+            )
+            return SendResult(success=True, raw_response=result)
+
+        except Exception as e:
+            print(f"[{self.name}] Failed to send local image: {e}")
+            return await super().send_image_file(chat_id, image_path, caption, reply_to)
+
    async def send_image(
        self,
        chat_id: str,

@@ -8,10 +8,13 @@ Uses python-telegram-bot library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram.ext import (
|
||||
@@ -73,6 +76,19 @@ def _escape_mdv2(text: str) -> str:
|
||||
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
|
||||
|
||||
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
@@ -199,9 +215,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
text=plain_chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
@@ -286,6 +306,34 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -293,12 +341,16 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo."""
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_url,
|
||||
@@ -307,9 +359,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
|
||||
# Fallback: send as text link
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
logger.warning("[%s] URL-based send_photo failed (%s), trying file upload", self.name, e)
|
||||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e2:
|
||||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
|
||||
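The three-tier image fallback above (URL send, then byte upload, then a plain link via the base adapter) can be read in isolation as roughly the sketch below; that Bot.send_photo accepts either a URL or raw bytes is python-telegram-bot behavior, and the helper name is illustrative:

import httpx

async def send_photo_with_fallback(bot, chat_id: int, image_url: str, caption=None):
    # 1) Let Telegram fetch the URL itself (images under ~5MB).
    try:
        return await bot.send_photo(chat_id=chat_id, photo=image_url, caption=caption)
    except Exception:
        pass
    # 2) Download the bytes ourselves and upload them (up to ~10MB).
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.get(image_url)
        resp.raise_for_status()
        return await bot.send_photo(chat_id=chat_id, photo=resp.content, caption=caption)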
575 gateway/run.py
@@ -75,6 +75,7 @@ if _config_path.exists():
"container_memory": "TERMINAL_CONTAINER_MEMORY",
"container_disk": "TERMINAL_CONTAINER_DISK",
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
"sandbox_dir": "TERMINAL_SANDBOX_DIR",
}
for _cfg_key, _env_var in _terminal_env_map.items():
if _cfg_key in _terminal_cfg:
@@ -93,6 +94,11 @@ if _config_path.exists():
if _agent_cfg and isinstance(_agent_cfg, dict):
if "max_turns" in _agent_cfg:
os.environ["HERMES_MAX_ITERATIONS"] = str(_agent_cfg["max_turns"])
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
_tz_cfg = _cfg.get("timezone", "")
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
except Exception:
pass  # Non-fatal; gateway can still run with .env values

@@ -102,11 +108,13 @@ os.environ["HERMES_QUIET"] = "1"
# Enable interactive exec approval for dangerous commands on messaging platforms
os.environ["HERMES_EXEC_ASK"] = "1"

# Set terminal working directory for messaging platforms
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
# This is separate from CLI which uses the directory where `hermes` is run
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
os.environ["TERMINAL_CWD"] = messaging_cwd
# Set terminal working directory for messaging platforms.
# If the user set an explicit path in config.yaml (not "." or "auto"),
# respect it. Otherwise use MESSAGING_CWD or default to home directory.
_configured_cwd = os.environ.get("TERMINAL_CWD", "")
if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"):
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
os.environ["TERMINAL_CWD"] = messaging_cwd

from gateway.config import (
Platform,
@@ -173,7 +181,6 @@ class GatewayRunner:
self.session_store = SessionStore(
self.config.sessions_dir, self.config,
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
on_auto_reset=self._flush_memories_before_reset,
)
self.delivery_router = DeliveryRouter(self.config)
self._running = False
@@ -204,15 +211,14 @@ class GatewayRunner:
from gateway.hooks import HookRegistry
self.hooks = HookRegistry()

def _flush_memories_before_reset(self, old_entry):
"""Prompt the agent to save memories/skills before an auto-reset.

Called synchronously by SessionStore before destroying an expired session.
Loads the transcript, gives the agent a real turn with memory + skills
tools, and explicitly asks it to preserve anything worth keeping.
def _flush_memories_for_session(self, old_session_id: str):
"""Prompt the agent to save memories/skills before context is lost.

Synchronous worker — meant to be called via run_in_executor from
an async context so it doesn't block the event loop.
"""
try:
history = self.session_store.load_transcript(old_entry.session_id)
history = self.session_store.load_transcript(old_session_id)
if not history or len(history) < 4:
return

@@ -226,7 +232,7 @@ class GatewayRunner:
max_iterations=8,
quiet_mode=True,
enabled_toolsets=["memory", "skills"],
session_id=old_entry.session_id,
session_id=old_session_id,
)

# Build conversation history from transcript
@@ -255,9 +261,14 @@ class GatewayRunner:
user_message=flush_prompt,
conversation_history=msgs,
)
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
except Exception as e:
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)

async def _async_flush_memories(self, old_session_id: str):
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
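The helper above is the generic "don't block the event loop" pattern: push a synchronous worker into the default thread pool, and let callers detach it with create_task when they don't need the result. A minimal standalone sketch (names are illustrative):

import asyncio

def blocking_flush(session_id: str) -> None:
    ...  # synchronous work, e.g. an agent turn that saves memories

async def flush_in_background(session_id: str) -> None:
    loop = asyncio.get_event_loop()
    await loop.run_in_executor(None, blocking_flush, session_id)

# Fire-and-forget from an async handler:
# asyncio.create_task(flush_in_background(old_session_id))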
@staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]:
@@ -325,7 +336,7 @@ class GatewayRunner:

Checks HERMES_REASONING_EFFORT env var first, then agent.reasoning_effort
in config.yaml. Valid: "xhigh", "high", "medium", "low", "minimal", "none".
Returns None to use default (xhigh).
Returns None to use default (medium).
"""
effort = os.getenv("HERMES_REASONING_EFFORT", "")
if not effort:
@@ -346,7 +357,7 @@ class GatewayRunner:
valid = ("xhigh", "high", "medium", "low", "minimal")
if effort in valid:
return {"enabled": True, "effort": effort}
logger.warning("Unknown reasoning_effort '%s', using default (xhigh)", effort)
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
return None

@staticmethod
@@ -459,10 +470,50 @@ class GatewayRunner:
# Check if we're restarting after a /update command
await self._send_update_notification()

# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())

logger.info("Press Ctrl+C to stop")

return True

async def _session_expiry_watcher(self, interval: int = 300):
"""Background task that proactively flushes memories for expired sessions.

Runs every `interval` seconds (default 5 min). For each session that
has expired according to its reset policy, flushes memories in a thread
pool and marks the session so it won't be flushed again.

This means memories are already saved by the time the user sends their
next message, so there's no blocking delay.
"""
await asyncio.sleep(60)  # initial delay — let the gateway fully start
while self._running:
try:
self.session_store._ensure_loaded()
for key, entry in list(self.session_store._entries.items()):
if entry.session_id in self.session_store._pre_flushed_sessions:
continue  # already flushed this session
if not self.session_store._is_session_expired(entry):
continue  # session still active
# Session has expired — flush memories in the background
logger.info(
"Session %s expired (key=%s), flushing memories proactively",
entry.session_id, key,
)
try:
await self._async_flush_memories(entry.session_id)
self.session_store._pre_flushed_sessions.add(entry.session_id)
except Exception as e:
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
except Exception as e:
logger.debug("Session expiry watcher error: %s", e)
# Sleep in small increments so we can stop quickly
for _ in range(interval):
if not self._running:
break
await asyncio.sleep(1)

async def stop(self) -> None:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
@@ -659,7 +710,8 @@ class GatewayRunner:
# Emit command:* hook for any recognized slash command
_known_commands = {"new", "reset", "help", "status", "stop", "model",
"personality", "retry", "undo", "sethome", "set-home",
"compress", "usage", "insights", "reload-mcp", "update"}
"compress", "usage", "insights", "reload-mcp", "update",
"title"}
if command and command in _known_commands:
await self.hooks.emit(f"command:{command}", {
"platform": source.platform.value if source.platform else "",
@@ -683,6 +735,9 @@ class GatewayRunner:
if command == "model":
return await self._handle_model_command(event)

if command == "provider":
return await self._handle_provider_command(event)

if command == "personality":
return await self._handle_personality_command(event)

@@ -709,6 +764,9 @@ class GatewayRunner:

if command == "update":
return await self._handle_update_command(event)

if command == "title":
return await self._handle_title_command(event)

# Skill slash commands: /skill-name loads the skill and sends to agent
if command:
@@ -783,6 +841,167 @@ class GatewayRunner:
# Load conversation history from transcript
history = self.session_store.load_transcript(session_entry.session_id)

# -----------------------------------------------------------------
# Session hygiene: auto-compress pathologically large transcripts
#
# Long-lived gateway sessions can accumulate enough history that
# every new message rehydrates an oversized transcript, causing
# repeated truncation/context failures. Detect this early and
# compress proactively — before the agent even starts. (#628)
# -----------------------------------------------------------------
if history and len(history) >= 4:
from agent.model_metadata import estimate_messages_tokens_rough

# Read thresholds from config.yaml → session_hygiene section
_hygiene_cfg = {}
try:
_hyg_cfg_path = _hermes_home / "config.yaml"
if _hyg_cfg_path.exists():
import yaml as _hyg_yaml
with open(_hyg_cfg_path) as _hyg_f:
_hyg_data = _hyg_yaml.safe_load(_hyg_f) or {}
_hygiene_cfg = _hyg_data.get("session_hygiene", {})
if not isinstance(_hygiene_cfg, dict):
_hygiene_cfg = {}
except Exception:
pass

_compress_token_threshold = int(
_hygiene_cfg.get("auto_compress_tokens", 100_000)
)
_compress_msg_threshold = int(
_hygiene_cfg.get("auto_compress_messages", 200)
)
_warn_token_threshold = int(
_hygiene_cfg.get("warn_tokens", 200_000)
)

_msg_count = len(history)
_approx_tokens = estimate_messages_tokens_rough(history)

_needs_compress = (
_approx_tokens >= _compress_token_threshold
or _msg_count >= _compress_msg_threshold
)

if _needs_compress:
logger.info(
"Session hygiene: %s messages, ~%s tokens — auto-compressing "
"(thresholds: %s msgs / %s tokens)",
_msg_count, f"{_approx_tokens:,}",
_compress_msg_threshold, f"{_compress_token_threshold:,}",
)

_hyg_adapter = self.adapters.get(source.platform)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"🗜️ Session is large ({_msg_count} messages, "
f"~{_approx_tokens:,} tokens). Auto-compressing..."
)
except Exception:
pass
try:
from run_agent import AIAgent

_hyg_runtime = _resolve_runtime_agent_kwargs()
if _hyg_runtime.get("api_key"):
_hyg_msgs = [
{"role": m.get("role"), "content": m.get("content")}
for m in history
if m.get("role") in ("user", "assistant")
and m.get("content")
]

if len(_hyg_msgs) >= 4:
_hyg_agent = AIAgent(
**_hyg_runtime,
max_iterations=4,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=session_entry.session_id,
)

loop = asyncio.get_event_loop()
_compressed, _ = await loop.run_in_executor(
None,
lambda: _hyg_agent._compress_context(
_hyg_msgs, "",
approx_tokens=_approx_tokens,
),
)

self.session_store.rewrite_transcript(
session_entry.session_id, _compressed
)
history = _compressed
_new_count = len(_compressed)
_new_tokens = estimate_messages_tokens_rough(
_compressed
)

logger.info(
"Session hygiene: compressed %s → %s msgs, "
"~%s → ~%s tokens",
_msg_count, _new_count,
f"{_approx_tokens:,}", f"{_new_tokens:,}",
)

if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"🗜️ Compressed: {_msg_count} → "
f"{_new_count} messages, "
f"~{_approx_tokens:,} → "
f"~{_new_tokens:,} tokens"
)
except Exception:
pass

# Still too large after compression — warn user
if _new_tokens >= _warn_token_threshold:
logger.warning(
"Session hygiene: still ~%s tokens after "
"compression — suggesting /reset",
f"{_new_tokens:,}",
)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
"⚠️ Session is still very large "
"after compression "
f"(~{_new_tokens:,} tokens). "
"Consider using /reset to start "
"fresh if you experience issues."
)
except Exception:
pass

except Exception as e:
logger.warning(
"Session hygiene auto-compress failed: %s", e
)
# Compression failed and session is dangerously large
if _approx_tokens >= _warn_token_threshold:
_hyg_adapter = self.adapters.get(source.platform)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"⚠️ Session is very large "
f"({_msg_count} messages, "
f"~{_approx_tokens:,} tokens) and "
"auto-compression failed. Consider "
"using /compress or /reset to avoid "
"issues."
)
except Exception:
pass

# First-message onboarding -- only on the very first interaction ever
if not history and not self.session_store.has_any_sessions():
context_prompt += (
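The thresholds read in the hygiene block above correspond to a session_hygiene section in config.yaml; a sketch using the defaults hard-coded in the code (the key names are taken from the reads above):

session_hygiene:
  auto_compress_tokens: 100000   # auto-compress past ~100k estimated tokens
  auto_compress_messages: 200    # ...or past 200 transcript messages
  warn_tokens: 200000            # warn and suggest /reset if still above this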
@@ -1007,33 +1226,12 @@ class GatewayRunner:
# Get existing session key
session_key = self.session_store._generate_session_key(source)

# Memory flush before reset: load the old transcript and let a
# temporary agent save memories before the session is wiped.
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
try:
old_entry = self.session_store._entries.get(session_key)
if old_entry:
old_history = self.session_store.load_transcript(old_entry.session_id)
if old_history:
from run_agent import AIAgent
loop = asyncio.get_event_loop()
_flush_kwargs = _resolve_runtime_agent_kwargs()
def _do_flush():
tmp_agent = AIAgent(
**_flush_kwargs,
max_iterations=5,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=old_entry.session_id,
)
# Build simple message list from transcript
msgs = []
for m in old_history:
role = m.get("role")
content = m.get("content")
if role in ("user", "assistant") and content:
msgs.append({"role": role, "content": content})
tmp_agent.flush_memories(msgs)
await loop.run_in_executor(None, _do_flush)
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
except Exception as e:
logger.debug("Gateway memory flush on reset failed: %s", e)

@@ -1100,12 +1298,14 @@ class GatewayRunner:
"`/reset` — Reset conversation history",
"`/status` — Show session info",
"`/stop` — Interrupt the running agent",
"`/model [name]` — Show or change the model",
"`/model [provider:model]` — Show/change model (or switch provider)",
"`/provider` — Show available providers and auth status",
"`/personality [name]` — Set a personality",
"`/retry` — Retry your last message",
"`/undo` — Remove the last exchange",
"`/sethome` — Set this chat as the home channel",
"`/compress` — Compress conversation context",
"`/title [name]` — Set or show the session title",
"`/usage` — Show token usage for this session",
"`/insights [days]` — Show usage insights and analytics",
"`/reload-mcp` — Reload MCP servers from config",
@@ -1126,13 +1326,20 @@ class GatewayRunner:
async def _handle_model_command(self, event: MessageEvent) -> str:
"""Handle /model command - show or change the current model."""
import yaml
from hermes_cli.models import (
parse_model_input,
validate_requested_model,
curated_models_for_provider,
normalize_provider,
_PROVIDER_LABELS,
)

args = event.get_command_args().strip()
config_path = _hermes_home / 'config.yaml'

# Resolve current model the same way the agent init does:
# env vars first, then config.yaml always overrides.
# Resolve current model and provider from config
current = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
current_provider = "openrouter"
try:
if config_path.exists():
with open(config_path) as f:
@@ -1142,39 +1349,164 @@ class GatewayRunner:
current = model_cfg
elif isinstance(model_cfg, dict):
current = model_cfg.get("default", current)
current_provider = model_cfg.get("provider", current_provider)
except Exception:
pass

# Resolve "auto" to the actual provider using credential detection
current_provider = normalize_provider(current_provider)
if current_provider == "auto":
try:
from hermes_cli.auth import resolve_provider as _resolve_provider
current_provider = _resolve_provider(current_provider)
except Exception:
current_provider = "openrouter"

if not args:
return f"🤖 **Current model:** `{current}`\n\nTo change: `/model provider/model-name`"
provider_label = _PROVIDER_LABELS.get(current_provider, current_provider)
lines = [
f"🤖 **Current model:** `{current}`",
f"**Provider:** {provider_label}",
"",
]
curated = curated_models_for_provider(current_provider)
if curated:
lines.append(f"**Available models ({provider_label}):**")
for mid, desc in curated:
marker = " ←" if mid == current else ""
label = f" _{desc}_" if desc else ""
lines.append(f"• `{mid}`{label}{marker}")
lines.append("")
lines.append("To change: `/model model-name`")
lines.append("Switch provider: `/model provider:model-name`")
return "\n".join(lines)

if "/" not in args:
return (
f"🤖 Invalid model format: `{args}`\n\n"
f"Use `provider/model-name` format, e.g.:\n"
f"• `anthropic/claude-sonnet-4`\n"
f"• `google/gemini-2.5-pro`\n"
f"• `openai/gpt-4o`"
)
# Parse provider:model syntax
target_provider, new_model = parse_model_input(args, current_provider)
provider_changed = target_provider != current_provider

# Write to config.yaml (source of truth), same pattern as CLI save_config_value.
# Resolve credentials for the target provider (for API probe)
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
base_url = "https://openrouter.ai/api/v1"
if provider_changed:
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=target_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception as e:
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
else:
# Use current provider's base_url from config or registry
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=current_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception:
pass

# Validate the model against the live API
try:
validation = validate_requested_model(
new_model,
target_provider,
api_key=api_key,
base_url=base_url,
)
except Exception:
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}

if not validation.get("accepted"):
msg = validation.get("message", "Invalid model")
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
return f"⚠️ {msg}{tip}"

# Persist to config only if validation approves
if validation.get("persist"):
try:
user_config = {}
if config_path.exists():
with open(config_path) as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = new_model
if provider_changed:
user_config["model"]["provider"] = target_provider
with open(config_path, 'w') as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"

# Set env vars so the next agent run picks up the change
os.environ["HERMES_MODEL"] = new_model
if provider_changed:
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider

provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""

warning = ""
if validation.get("message"):
warning = f"\n⚠️ {validation['message']}"

if validation.get("persist"):
persist_note = "saved to config"
else:
persist_note = "this session only — will revert on restart"
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}\n_(takes effect on next message)_"
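Assuming parse_model_input splits an optional provider: prefix off the argument and otherwise keeps the current provider (the function is imported above; this reimplementation is illustrative only), the handler's inputs resolve roughly like this:

def parse_model_input_sketch(args: str, current_provider: str) -> tuple[str, str]:
    # "zai:glm-5"                 -> ("zai", "glm-5")
    # "anthropic/claude-opus-4.6" -> (current_provider, "anthropic/claude-opus-4.6")
    if ":" in args and "/" not in args.split(":", 1)[0]:
        provider, model = args.split(":", 1)
        return provider.strip(), model.strip()
    return current_provider, args.strip()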
async def _handle_provider_command(self, event: MessageEvent) -> str:
"""Handle /provider command - show available providers."""
import yaml
from hermes_cli.models import (
list_available_providers,
normalize_provider,
_PROVIDER_LABELS,
)

# Resolve current provider from config
current_provider = "openrouter"
config_path = _hermes_home / 'config.yaml'
try:
user_config = {}
if config_path.exists():
with open(config_path) as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = args
with open(config_path, 'w') as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"
cfg = yaml.safe_load(f) or {}
model_cfg = cfg.get("model", {})
if isinstance(model_cfg, dict):
current_provider = model_cfg.get("provider", current_provider)
except Exception:
pass

# Also set env var so code reading it before the next agent init sees the update.
os.environ["HERMES_MODEL"] = args
current_provider = normalize_provider(current_provider)
if current_provider == "auto":
try:
from hermes_cli.auth import resolve_provider as _resolve_provider
current_provider = _resolve_provider(current_provider)
except Exception:
current_provider = "openrouter"

return f"🤖 Model changed to `{args}`\n_(takes effect on next message)_"
current_label = _PROVIDER_LABELS.get(current_provider, current_provider)

lines = [
f"🔌 **Current provider:** {current_label} (`{current_provider}`)",
"",
"**Available providers:**",
]

providers = list_available_providers()
for p in providers:
marker = " ← active" if p["id"] == current_provider else ""
auth = "✅" if p["authenticated"] else "❌"
aliases = f" _(also: {', '.join(p['aliases'])})_" if p["aliases"] else ""
lines.append(f"{auth} `{p['id']}` — {p['label']}{aliases}{marker}")

lines.append("")
lines.append("Switch: `/model provider:model-name`")
lines.append("Setup: `hermes setup`")
return "\n".join(lines)

async def _handle_personality_command(self, event: MessageEvent) -> str:
"""Handle /personality command - list or set a personality."""
@@ -1364,6 +1696,40 @@ class GatewayRunner:
logger.warning("Manual compress failed: %s", e)
return f"Compression failed: {e}"

async def _handle_title_command(self, event: MessageEvent) -> str:
"""Handle /title command — set or show the current session's title."""
source = event.source
session_entry = self.session_store.get_or_create_session(source)
session_id = session_entry.session_id

if not self._session_db:
return "Session database not available."

title_arg = event.get_command_args().strip()
if title_arg:
# Sanitize the title before setting
try:
sanitized = self._session_db.sanitize_title(title_arg)
except ValueError as e:
return f"⚠️ {e}"
if not sanitized:
return "⚠️ Title is empty after cleanup. Please use printable characters."
# Set the title
try:
if self._session_db.set_session_title(session_id, sanitized):
return f"✏️ Session title set: **{sanitized}**"
else:
return "Session not found in database."
except ValueError as e:
return f"⚠️ {e}"
else:
# Show the current title
title = self._session_db.get_session_title(session_id)
if title:
return f"📌 Session title: **{title}**"
else:
return "No title set. Usage: `/title My Session Name`"

async def _handle_usage_command(self, event: MessageEvent) -> str:
"""Handle /usage command -- show token usage for the session's last agent run."""
source = event.source
@@ -2092,7 +2458,7 @@ class GatewayRunner:
os.environ["HERMES_SESSION_KEY"] = session_key or ""

# Read from env var or use default (same as CLI)
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))

# Map platform enum to the platform hint key the agent understands.
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
@@ -2432,34 +2798,77 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
logger.info("Cron ticker stopped")


async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool:
"""
Start the gateway and run until interrupted.

This is the main entry point for running the gateway.
Returns True if the gateway ran successfully, False if it failed to start.
A False return causes a non-zero exit code so systemd can auto-restart.

Args:
config: Optional gateway configuration override.
replace: If True, kill any existing gateway instance before starting.
Useful for systemd services to avoid restart-loop deadlocks
when the previous process hasn't fully exited yet.
"""
# ── Duplicate-instance guard ──────────────────────────────────────
# Prevent two gateways from running under the same HERMES_HOME.
# The PID file is scoped to HERMES_HOME, so future multi-profile
# setups (each profile using a distinct HERMES_HOME) will naturally
# allow concurrent instances without tripping this guard.
from gateway.status import get_running_pid
import time as _time
from gateway.status import get_running_pid, remove_pid_file
existing_pid = get_running_pid()
if existing_pid is not None and existing_pid != os.getpid():
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
logger.error(
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
existing_pid, hermes_home,
)
print(
f"\n❌ Gateway already running (PID {existing_pid}).\n"
f" Use 'hermes gateway restart' to replace it,\n"
f" or 'hermes gateway stop' to kill it first.\n"
)
return False
if replace:
logger.info(
"Replacing existing gateway instance (PID %d) with --replace.",
existing_pid,
)
try:
os.kill(existing_pid, signal.SIGTERM)
except ProcessLookupError:
pass # Already gone
except PermissionError:
logger.error(
"Permission denied killing PID %d. Cannot replace.",
existing_pid,
)
return False
# Wait up to 10 seconds for the old process to exit
for _ in range(20):
try:
os.kill(existing_pid, 0)
_time.sleep(0.5)
except (ProcessLookupError, PermissionError):
break # Process is gone
else:
# Still alive after 10s — force kill
logger.warning(
"Old gateway (PID %d) did not exit after SIGTERM, sending SIGKILL.",
existing_pid,
)
try:
os.kill(existing_pid, signal.SIGKILL)
_time.sleep(0.5)
except (ProcessLookupError, PermissionError):
pass
remove_pid_file()
else:
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
logger.error(
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
existing_pid, hermes_home,
)
print(
f"\n❌ Gateway already running (PID {existing_pid}).\n"
f" Use 'hermes gateway restart' to replace it,\n"
f" or 'hermes gateway stop' to kill it first.\n"
f" Or use 'hermes gateway run --replace' to auto-replace.\n"
)
return False

# Sync bundled skills on gateway start (fast -- skips unchanged)
try:
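The --replace path is the standard terminate, wait, escalate dance; pulled out of start_gateway it would look like this sketch:

import os
import signal
import time

def replace_process(pid: int, grace_seconds: float = 10.0) -> None:
    """SIGTERM the old process, poll until it exits, then SIGKILL."""
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        return  # already gone
    deadline = time.monotonic() + grace_seconds
    while time.monotonic() < deadline:
        try:
            os.kill(pid, 0)  # signal 0 only probes; raises once the PID is gone
        except (ProcessLookupError, PermissionError):
            return
        time.sleep(0.5)
    os.kill(pid, signal.SIGKILL)  # grace period expired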
@@ -311,7 +311,9 @@ class SessionStore:
self._entries: Dict[str, SessionEntry] = {}
self._loaded = False
self._has_active_processes_fn = has_active_processes_fn
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher

# Initialize SQLite session database
self._db = None
@@ -353,6 +355,44 @@ class SessionStore:
"""Generate a session key from a source."""
return build_session_key(source)

def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.

Works from the entry alone — no SessionSource needed.
Used by the background expiry watcher to proactively flush memories.
Sessions with active background processes are never considered expired.
"""
if self._has_active_processes_fn:
if self._has_active_processes_fn(entry.session_key):
return False

policy = self.config.get_reset_policy(
platform=entry.platform,
session_type=entry.chat_type,
)

if policy.mode == "none":
return False

now = datetime.now()

if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
if now > idle_deadline:
return True

if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0, second=0, microsecond=0,
)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
if entry.updated_at < today_reset:
return True

return False

def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
"""
Check if a session should be reset based on policy.
@@ -439,13 +479,11 @@ class SessionStore:
self._save()
return entry
else:
# Session is being auto-reset — flush memories before destroying
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
if self._on_auto_reset:
try:
self._on_auto_reset(entry)
except Exception as e:
logger.debug("Auto-reset callback failed: %s", e)
self._pre_flushed_sessions.discard(entry.session_id)
if self._db:
try:
self._db.end_session(entry.session_id, "session_reset")
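The subtle piece of _is_session_expired is the daily mode: the most recent reset boundary is today at at_hour, shifted back a day when that hour hasn't arrived yet. A worked sketch:

from datetime import datetime, timedelta

def last_daily_reset(now: datetime, at_hour: int) -> datetime:
    boundary = now.replace(hour=at_hour, minute=0, second=0, microsecond=0)
    if now.hour < at_hour:
        boundary -= timedelta(days=1)  # e.g. at 02:30 with at_hour=4 -> yesterday 04:00
    return boundary

# A session is stale in "daily" mode when entry.updated_at < last_daily_reset(now, at_hour).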
@@ -72,15 +72,19 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120

@dataclass
class ProviderConfig:
"""Describes a known OAuth provider."""
"""Describes a known inference provider."""
id: str
name: str
auth_type: str # "oauth_device_code" or "api_key"
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
portal_base_url: str = ""
inference_base_url: str = ""
client_id: str = ""
scope: str = ""
extra: Dict[str, Any] = field(default_factory=dict)
# For API-key providers: env vars to check (in priority order)
api_key_env_vars: tuple = ()
# Optional env var for base URL override
base_url_env_var: str = ""


PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
@@ -99,9 +103,118 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
auth_type="oauth_external",
inference_base_url=DEFAULT_CODEX_BASE_URL,
),
"zai": ProviderConfig(
id="zai",
name="Z.AI / GLM",
auth_type="api_key",
inference_base_url="https://api.z.ai/api/paas/v4",
api_key_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
base_url_env_var="GLM_BASE_URL",
),
"kimi-coding": ProviderConfig(
id="kimi-coding",
name="Kimi / Moonshot",
auth_type="api_key",
inference_base_url="https://api.moonshot.ai/v1",
api_key_env_vars=("KIMI_API_KEY",),
base_url_env_var="KIMI_BASE_URL",
),
"minimax": ProviderConfig(
id="minimax",
name="MiniMax",
auth_type="api_key",
inference_base_url="https://api.minimax.io/v1",
api_key_env_vars=("MINIMAX_API_KEY",),
base_url_env_var="MINIMAX_BASE_URL",
),
"minimax-cn": ProviderConfig(
id="minimax-cn",
name="MiniMax (China)",
auth_type="api_key",
inference_base_url="https://api.minimaxi.com/v1",
api_key_env_vars=("MINIMAX_CN_API_KEY",),
base_url_env_var="MINIMAX_CN_BASE_URL",
),
}


# =============================================================================
# Kimi Code Endpoint Detection
# =============================================================================

# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
# KIMI_BASE_URL explicitly.
KIMI_CODE_BASE_URL = "https://api.kimi.com/coding/v1"


def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) -> str:
"""Return the correct Kimi base URL based on the API key prefix.

If the user has explicitly set KIMI_BASE_URL, that always wins.
Otherwise, sk-kimi- prefixed keys route to api.kimi.com/coding/v1.
"""
if env_override:
return env_override
if api_key.startswith("sk-kimi-"):
return KIMI_CODE_BASE_URL
return default_url


# =============================================================================
# Z.AI Endpoint Detection
# =============================================================================

# Z.AI has separate billing for general vs coding plans, and global vs China
# endpoints. A key that works on one may return "Insufficient balance" on
# another. We probe at setup time and store the working endpoint.

ZAI_ENDPOINTS = [
# (id, base_url, default_model, label)
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
]


def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
"""Probe z.ai endpoints to find one that accepts this API key.

Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
first working endpoint, or None if all fail.
"""
for ep_id, base_url, model, label in ZAI_ENDPOINTS:
try:
resp = httpx.post(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json={
"model": model,
"stream": False,
"max_tokens": 1,
"messages": [{"role": "user", "content": "ping"}],
},
timeout=timeout,
)
if resp.status_code == 200:
logger.debug("Z.AI endpoint probe: %s (%s) OK", ep_id, base_url)
return {
"id": ep_id,
"base_url": base_url,
"model": model,
"label": label,
}
logger.debug("Z.AI endpoint probe: %s returned %s", ep_id, resp.status_code)
except Exception as exc:
logger.debug("Z.AI endpoint probe: %s failed: %s", ep_id, exc)
return None
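Putting the two helpers together, setup code can resolve a working endpoint in a few lines; the wiring below is a sketch (env var names as registered above):

import os

kimi_key = os.getenv("KIMI_API_KEY", "")
kimi_url = _resolve_kimi_base_url(
    kimi_key, "https://api.moonshot.ai/v1", os.getenv("KIMI_BASE_URL", "")
)  # sk-kimi-... keys land on api.kimi.com/coding/v1; legacy keys keep the default

zai_key = os.getenv("GLM_API_KEY", "")
endpoint = detect_zai_endpoint(zai_key)
if endpoint:
    print(f"z.ai: {endpoint['label']} endpoint, {endpoint['base_url']} ({endpoint['model']})")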


# =============================================================================
# Error Types
# =============================================================================
@@ -355,10 +468,19 @@ def resolve_provider(
1. active_provider in auth.json with valid credentials
2. Explicit CLI api_key/base_url -> "openrouter"
3. OPENAI_API_KEY or OPENROUTER_API_KEY env vars -> "openrouter"
4. Fallback: "openrouter"
4. Provider-specific API keys (GLM, Kimi, MiniMax) -> that provider
5. Fallback: "openrouter"
"""
normalized = (requested or "auto").strip().lower()

# Normalize provider aliases
_PROVIDER_ALIASES = {
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
"kimi": "kimi-coding", "moonshot": "kimi-coding",
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
}
normalized = _PROVIDER_ALIASES.get(normalized, normalized)

if normalized in {"openrouter", "custom"}:
return "openrouter"
if normalized in PROVIDER_REGISTRY:
@@ -387,6 +509,14 @@ def resolve_provider(
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
return "openrouter"

# Auto-detect API-key providers by checking their env vars
for pid, pconfig in PROVIDER_REGISTRY.items():
if pconfig.auth_type != "api_key":
continue
for env_var in pconfig.api_key_env_vars:
if os.getenv(env_var, "").strip():
return pid

return "openrouter"
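With the new step 4, a machine whose only credential is a provider-specific key now resolves to that provider instead of falling through to openrouter. An illustrative check, assuming no active_provider is stored in auth.json and no OpenRouter/OpenAI keys are set:

import os

os.environ.pop("OPENROUTER_API_KEY", None)
os.environ.pop("OPENAI_API_KEY", None)
os.environ["GLM_API_KEY"] = "example-key"  # hypothetical value

assert resolve_provider("auto") == "zai"   # detected via GLM_API_KEY in step 4
assert resolve_provider("glm") == "zai"    # alias normalization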
@@ -1230,6 +1360,42 @@ def get_codex_auth_status() -> Dict[str, Any]:
}


def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
return {"configured": False}

api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break

env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()

if provider_id == "kimi-coding":
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
elif env_url:
base_url = env_url
else:
base_url = pconfig.inference_base_url

return {
"configured": bool(api_key),
"provider": provider_id,
"name": pconfig.name,
"key_source": key_source,
"base_url": base_url,
"logged_in": bool(api_key), # compat with OAuth status shape
}


def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
"""Generic auth status dispatcher."""
target = provider_id or get_active_provider()
@@ -1237,9 +1403,54 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
return get_nous_auth_status()
if target == "openai-codex":
return get_codex_auth_status()
# API-key providers
pconfig = PROVIDER_REGISTRY.get(target)
if pconfig and pconfig.auth_type == "api_key":
return get_api_key_provider_status(target)
return {"logged_in": False}


def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
"""Resolve API key and base URL for an API-key provider.

Returns dict with: provider, api_key, base_url, source.
"""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
raise AuthError(
f"Provider '{provider_id}' is not an API-key provider.",
provider=provider_id,
code="invalid_provider",
)

api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break

env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()

if provider_id == "kimi-coding":
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
elif env_url:
base_url = env_url.rstrip("/")
else:
base_url = pconfig.inference_base_url

return {
"provider": provider_id,
"api_key": api_key,
"base_url": base_url.rstrip("/"),
"source": key_source or "default",
}


# =============================================================================
# External credential detection
# =============================================================================
@@ -1,10 +1,15 @@
"""Welcome banner, ASCII art, and skills summary for the CLI.
"""Welcome banner, ASCII art, skills summary, and update check for the CLI.

Pure display functions with no HermesCLI state dependency.
"""

import json
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional

from rich.console import Console
from rich.panel import Panel
@@ -13,6 +18,8 @@ from rich.table import Table
from prompt_toolkit import print_formatted_text as _pt_print
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI

logger = logging.getLogger(__name__)


# =========================================================================
# ANSI building blocks for conversation display
@@ -95,6 +102,72 @@ def get_available_skills() -> Dict[str, List[str]]:
return skills_by_category


# =========================================================================
# Update check
# =========================================================================

# Cache update check results for 6 hours to avoid repeated git fetches
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600


def check_for_updates() -> Optional[int]:
"""Check how many commits behind origin/main the local repo is.

Does a ``git fetch`` at most once every 6 hours (cached to
``~/.hermes/.update_check``). Returns the number of commits behind,
or ``None`` if the check fails or isn't applicable.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
repo_dir = hermes_home / "hermes-agent"
cache_file = hermes_home / ".update_check"

# Must be a git repo
if not (repo_dir / ".git").exists():
return None

# Read cache
now = time.time()
try:
if cache_file.exists():
cached = json.loads(cache_file.read_text())
if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS:
return cached.get("behind")
except Exception:
pass

# Fetch latest refs (fast — only downloads ref metadata, no files)
try:
subprocess.run(
["git", "fetch", "origin", "--quiet"],
capture_output=True, timeout=10,
cwd=str(repo_dir),
)
except Exception:
pass # Offline or timeout — use stale refs, that's fine

# Count commits behind
try:
result = subprocess.run(
["git", "rev-list", "--count", "HEAD..origin/main"],
capture_output=True, text=True, timeout=5,
cwd=str(repo_dir),
)
if result.returncode == 0:
behind = int(result.stdout.strip())
else:
behind = None
except Exception:
behind = None

# Write cache
try:
cache_file.write_text(json.dumps({"ts": now, "behind": behind}))
except Exception:
pass

return behind


# =========================================================================
# Welcome banner
# =========================================================================
@@ -259,6 +332,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
summary_parts.append("/help for commands")
right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")

# Update check — show if behind origin/main
try:
behind = check_for_updates()
if behind and behind > 0:
commits_word = "commit" if behind == 1 else "commits"
right_lines.append(
f"[bold yellow]⚠ {behind} {commits_word} behind[/]"
f"[dim yellow] — run [bold]hermes update[/bold] to update[/]"
)
except Exception:
pass # Never break the banner over an update check

right_content = "\n".join(right_lines)
layout_table.add_row(left_content, right_content)
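The cache written by check_for_updates is a two-field JSON blob, so the banner's repeated calls stay cheap; inspecting it by hand looks like:

import json
from pathlib import Path

cache_file = Path.home() / ".hermes" / ".update_check"
cache = json.loads(cache_file.read_text())
# e.g. {"ts": 1730000000.0, "behind": 3}; refreshed once the entry is >6h old
print(cache.get("behind"))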
@@ -1,9 +1,15 @@
"""Slash command definitions and autocomplete for the Hermes CLI.

Contains the COMMANDS dict and the SlashCommandCompleter class.
These are pure data/UI with no HermesCLI state dependency.
Contains the shared built-in ``COMMANDS`` dict and ``SlashCommandCompleter``.
The completer can optionally include dynamic skill slash commands supplied by the
interactive CLI.
"""

from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any

from prompt_toolkit.completion import Completer, Completion


@@ -12,6 +18,7 @@ COMMANDS = {
"/tools": "List available tools",
"/toolsets": "List available toolsets",
"/model": "Show or change the current model",
"/provider": "Show available providers and current provider",
"/prompt": "View/set custom system prompt",
"/personality": "Set a predefined personality",
"/clear": "Clear screen and reset conversation (fresh start)",
@@ -27,26 +34,68 @@ COMMANDS = {
"/platforms": "Show gateway/messaging platform status",
"/verbose": "Cycle tool progress display: off → new → all → verbose",
"/compress": "Manually compress conversation context (flush memories + summarize)",
"/title": "Set a title for the current session (usage: /title My Session Name)",
"/usage": "Show token usage for the current session",
"/insights": "Show usage insights and analytics (last 30 days)",
"/paste": "Check clipboard for an image and attach it",
"/reload-mcp": "Reload MCP servers from config.yaml",
"/quit": "Exit the CLI (also: /exit, /q)",
}


class SlashCommandCompleter(Completer):
"""Autocomplete for /commands in the input area."""
"""Autocomplete for built-in slash commands and optional skill commands."""

def __init__(
self,
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
) -> None:
self._skill_commands_provider = skill_commands_provider

def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
if self._skill_commands_provider is None:
return {}
try:
return self._skill_commands_provider() or {}
except Exception:
return {}

@staticmethod
def _completion_text(cmd_name: str, word: str) -> str:
"""Return replacement text for a completion.

When the user has already typed the full command exactly (``/help``),
returning ``help`` would be a no-op and prompt_toolkit suppresses the
menu. Appending a trailing space keeps the dropdown visible and makes
backspacing retrigger it naturally.
"""
return f"{cmd_name} " if cmd_name == word else cmd_name

def get_completions(self, document, complete_event):
text = document.text_before_cursor
if not text.startswith("/"):
return

word = text[1:]

for cmd, desc in COMMANDS.items():
cmd_name = cmd[1:]
if cmd_name.startswith(word):
yield Completion(
cmd_name,
self._completion_text(cmd_name, word),
start_position=-len(word),
display=cmd,
display_meta=desc,
)

for cmd, info in self._iter_skill_commands().items():
cmd_name = cmd[1:]
if cmd_name.startswith(word):
description = str(info.get("description", "Skill command"))
short_desc = description[:50] + ("..." if len(description) > 50 else "")
yield Completion(
self._completion_text(cmd_name, word),
start_position=-len(word),
display=cmd,
display_meta=f"⚡ {short_desc}",
)
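Wiring the completer into a prompt_toolkit session follows the usual pattern; a minimal sketch with a hypothetical skill-command provider:

from prompt_toolkit import PromptSession

def my_skill_commands():
    # Hypothetical provider: maps "/deploy" to its metadata dict.
    return {"/deploy": {"description": "Deploy the current project"}}

session = PromptSession(completer=SlashCommandCompleter(my_skill_commands))
# text = session.prompt("> ")  # typing "/" offers the built-ins plus /deploy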
@@ -141,9 +141,13 @@ DEFAULT_CONFIG = {
# (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
"honcho": {},

# IANA timezone (e.g. "Asia/Kolkata", "America/New_York").
# Empty string means use server-local time.
"timezone": "",

# Permanently allowed dangerous command patterns (added via "always" approval)
"command_allowlist": [],


# Config schema version - bump this when adding new required fields
"_config_version": 5,
}
@@ -152,6 +156,15 @@ DEFAULT_CONFIG = {
# Config Migration System
# =============================================================================

# Track which env vars were introduced in each config version.
# Migration only mentions vars new since the user's previous version.
ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
3: ["FIRECRAWL_API_KEY", "BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "FAL_KEY"],
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
}

# Required environment variables with metadata for migration prompts.
# LLM provider is required but handled in the setup wizard's provider
# selection step (Nous Portal / OpenRouter / Custom endpoint), so this
@@ -170,6 +183,86 @@ OPTIONAL_ENV_VARS = {
"category": "provider",
"advanced": True,
},
"GLM_API_KEY": {
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
"prompt": "Z.AI / GLM API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"ZAI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"Z_AI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"GLM_BASE_URL": {
"description": "Z.AI / GLM base URL override",
"prompt": "Z.AI / GLM base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"KIMI_API_KEY": {
"description": "Kimi / Moonshot API key",
"prompt": "Kimi API key",
"url": "https://platform.moonshot.cn/",
"password": True,
"category": "provider",
"advanced": True,
},
"KIMI_BASE_URL": {
"description": "Kimi / Moonshot base URL override",
"prompt": "Kimi base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_API_KEY": {
"description": "MiniMax API key (international)",
"prompt": "MiniMax API key",
"url": "https://www.minimax.io/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_BASE_URL": {
"description": "MiniMax base URL override",
"prompt": "MiniMax base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_API_KEY": {
"description": "MiniMax API key (China endpoint)",
"prompt": "MiniMax (China) API key",
"url": "https://www.minimaxi.com/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_BASE_URL": {
"description": "MiniMax (China) base URL override",
"prompt": "MiniMax (China) base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},

# ── Tool API keys ──
"FIRECRAWL_API_KEY": {
@@ -189,7 +282,7 @@ OPTIONAL_ENV_VARS = {
"advanced": True,
},
"BROWSERBASE_API_KEY": {
"description": "Browserbase API key for browser automation",
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
"prompt": "Browserbase API key",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -197,7 +290,7 @@ OPTIONAL_ENV_VARS = {
"category": "tool",
},
"BROWSERBASE_PROJECT_ID": {
"description": "Browserbase project ID",
"description": "Browserbase project ID (optional — only needed for cloud browser)",
"prompt": "Browserbase project ID",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -485,6 +578,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
if not quiet:
print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}")

# ── Version 4 → 5: add timezone field ──
if current_ver < 5:
config = load_config()
if "timezone" not in config:
old_tz = os.getenv("HERMES_TIMEZONE", "")
if old_tz and old_tz.strip():
config["timezone"] = old_tz.strip()
results["config_added"].append(f"timezone={old_tz.strip()} (from HERMES_TIMEZONE)")
else:
config["timezone"] = ""
results["config_added"].append("timezone= (empty, uses server-local)")
save_config(config)
if not quiet:
tz_display = config["timezone"] or "(server-local)"
print(f" ✓ Added timezone to config.yaml: {tz_display}")

if current_ver < latest_ver and not quiet:
print(f"Config version: {current_ver} → {latest_ver}")

@@ -525,34 +634,47 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
if v["name"] not in required_names and not v.get("advanced")
]

if interactive and missing_optional:
print(" Would you like to configure any optional keys now?")
try:
answer = input(" Configure optional keys? [y/N]: ").strip().lower()
except (EOFError, KeyboardInterrupt):
answer = "n"

if answer in ("y", "yes"):
# Only offer to configure env vars that are NEW since the user's previous version
new_var_names = set()
for ver in range(current_ver + 1, latest_ver + 1):
new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))

if new_var_names and interactive and not quiet:
new_and_unset = [
(name, OPTIONAL_ENV_VARS[name])
for name in sorted(new_var_names)
if not get_env_value(name) and name in OPTIONAL_ENV_VARS
]
if new_and_unset:
print(f"\n {len(new_and_unset)} new optional key(s) in this update:")
for name, info in new_and_unset:
print(f" • {name} — {info.get('description', '')}")
print()
for var in missing_optional:
desc = var.get("description", "")
if var.get("url"):
print(f" {desc}")
print(f" Get your key at: {var['url']}")
else:
print(f" {desc}")

if var.get("password"):
import getpass
value = getpass.getpass(f" {var['prompt']} (Enter to skip): ")
else:
|
||||
value = input(f" {var['prompt']} (Enter to skip): ").strip()
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
results["env_added"].append(var["name"])
|
||||
print(f" ✓ Saved {var['name']}")
|
||||
try:
|
||||
answer = input(" Configure new keys? [y/N]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
answer = "n"
|
||||
|
||||
if answer in ("y", "yes"):
|
||||
print()
|
||||
for name, info in new_and_unset:
|
||||
if info.get("url"):
|
||||
print(f" {info.get('description', name)}")
|
||||
print(f" Get your key at: {info['url']}")
|
||||
else:
|
||||
print(f" {info.get('description', name)}")
|
||||
if info.get("password"):
|
||||
import getpass
|
||||
value = getpass.getpass(f" {info.get('prompt', name)} (Enter to skip): ")
|
||||
else:
|
||||
value = input(f" {info.get('prompt', name)} (Enter to skip): ").strip()
|
||||
if value:
|
||||
save_env_value(name, value)
|
||||
results["env_added"].append(name)
|
||||
print(f" ✓ Saved {name}")
|
||||
print()
|
||||
else:
|
||||
print(" Set later with: hermes config set KEY VALUE")
|
||||
|
||||
# Check for missing config fields
|
||||
missing_config = get_missing_config_fields()
|
||||
@@ -772,6 +894,15 @@ def show_config():
|
||||
print(f" SSH host: {ssh_host or '(not set)'}")
|
||||
print(f" SSH user: {ssh_user or '(not set)'}")
|
||||
|
||||
# Timezone
|
||||
print()
|
||||
print(color("◆ Timezone", Colors.CYAN, Colors.BOLD))
|
||||
tz = config.get('timezone', '')
|
||||
if tz:
|
||||
print(f" Timezone: {tz}")
|
||||
else:
|
||||
print(f" Timezone: {color('(server-local)', Colors.DIM)}")
|
||||
|
||||
# Compression
|
||||
print()
|
||||
print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
|
||||
@@ -895,6 +1026,7 @@ def set_config_value(key: str, value: str):
|
||||
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
"terminal.cwd": "TERMINAL_CWD",
|
||||
"terminal.timeout": "TERMINAL_TIMEOUT",
|
||||
"terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
}
|
||||
if key in _config_to_env_sync:
|
||||
save_env_value(_config_to_env_sync[key], str(value))
|
||||
|
||||
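For illustration, a minimal sketch of how the version-gated migration above composes. The dict values are taken from the ENV_VARS_BY_VERSION table in this diff; the upgrade path 3 → 5 is an assumed example:

# Sketch: which env vars count as "new" for a user upgrading from config
# version 3 to 5. Mirrors the range() loop in migrate_config above.
current_ver, latest_ver = 3, 5
new_var_names = set()
for ver in range(current_ver + 1, latest_ver + 1):
    new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))
# new_var_names now holds only the version-4 and version-5 keys, so the
# migration prompt skips everything the user was already asked at version 3.
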
@@ -33,6 +33,26 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
from hermes_cli.colors import Colors, color
from hermes_constants import OPENROUTER_MODELS_URL


_PROVIDER_ENV_HINTS = (
    "OPENROUTER_API_KEY",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "OPENAI_BASE_URL",
    "GLM_API_KEY",
    "ZAI_API_KEY",
    "Z_AI_API_KEY",
    "KIMI_API_KEY",
    "MINIMAX_API_KEY",
    "MINIMAX_CN_API_KEY",
)


def _has_provider_env_config(content: str) -> bool:
    """Return True when ~/.hermes/.env contains provider auth/base URL settings."""
    return any(key in content for key in _PROVIDER_ENV_HINTS)


def check_ok(text: str, detail: str = ""):
    print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))

@@ -132,8 +152,8 @@ def run_doctor(args):

    # Check for common issues
    content = env_path.read_text()
    if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
        check_ok("API key configured")
    if _has_provider_env_config(content):
        check_ok("API key or custom endpoint configured")
    else:
        check_warn("No API key found in ~/.hermes/.env")
        issues.append("Run 'hermes setup' to configure API keys")
@@ -468,7 +488,48 @@ def run_doctor(args):
                print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
        except Exception as e:
            print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(f'({e})', Colors.DIM)} ")


    # -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
    _apikey_providers = [
        ("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
        ("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
        ("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
        ("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
    ]
    for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
        _key = ""
        for _ev in _env_vars:
            _key = os.getenv(_ev, "")
            if _key:
                break
        if _key:
            _label = _pname.ljust(20)
            print(f" Checking {_pname} API...", end="", flush=True)
            try:
                import httpx
                _base = os.getenv(_base_env, "")
                # Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
                if not _base and _key.startswith("sk-kimi-"):
                    _base = "https://api.kimi.com/coding/v1"
                _url = (_base.rstrip("/") + "/models") if _base else _default_url
                _headers = {"Authorization": f"Bearer {_key}"}
                if "api.kimi.com" in _url.lower():
                    _headers["User-Agent"] = "KimiCLI/1.0"
                _resp = httpx.get(
                    _url,
                    headers=_headers,
                    timeout=10,
                )
                if _resp.status_code == 200:
                    print(f"\r {color('✓', Colors.GREEN)} {_label} ")
                elif _resp.status_code == 401:
                    print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
                    issues.append(f"Check {_env_vars[0]} in .env")
                else:
                    print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
            except Exception as _e:
                print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")

    # =========================================================================
    # Check: Submodules
    # =========================================================================

@@ -154,19 +154,33 @@ def get_hermes_cli_path() -> str:
# =============================================================================

def generate_systemd_unit() -> str:
    import shutil
    python_path = get_python_path()
    working_dir = str(PROJECT_ROOT)
    venv_dir = str(PROJECT_ROOT / "venv")
    venv_bin = str(PROJECT_ROOT / "venv" / "bin")
    node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")

    # Build a PATH that includes the venv, node_modules, and standard system dirs
    sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

    hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
    return f"""[Unit]
Description={SERVICE_DESCRIPTION}
After=network.target

[Service]
Type=simple
ExecStart={python_path} -m hermes_cli.main gateway run
ExecStart={python_path} -m hermes_cli.main gateway run --replace
ExecStop={hermes_cli} gateway stop
WorkingDirectory={working_dir}
Environment="PATH={sane_path}"
Environment="VIRTUAL_ENV={venv_dir}"
Restart=on-failure
RestartSec=10
KillMode=mixed
KillSignal=SIGTERM
TimeoutStopSec=15
StandardOutput=journal
StandardError=journal

@@ -377,8 +391,15 @@ def launchd_status(deep: bool = False):
# Gateway Runner
# =============================================================================

def run_gateway(verbose: bool = False):
    """Run the gateway in foreground."""
def run_gateway(verbose: bool = False, replace: bool = False):
    """Run the gateway in foreground.

    Args:
        verbose: Enable verbose logging output.
        replace: If True, kill any existing gateway instance before starting.
            This prevents systemd restart loops when the old process
            hasn't fully exited yet.
    """
    sys.path.insert(0, str(PROJECT_ROOT))

    from gateway.run import start_gateway
@@ -393,7 +414,7 @@ def run_gateway(verbose: bool = False):

    # Exit with code 1 if gateway fails to connect any platform,
    # so systemd Restart=on-failure will retry on transient errors
    success = asyncio.run(start_gateway())
    success = asyncio.run(start_gateway(replace=replace))
    if not success:
        sys.exit(1)

@@ -765,7 +786,8 @@ def gateway_command(args):
    # Default to run if no subcommand
    if subcmd is None or subcmd == "run":
        verbose = getattr(args, 'verbose', False)
        run_gateway(verbose)
        replace = getattr(args, 'replace', False)
        run_gateway(verbose, replace=replace)
        return

    if subcmd == "setup":

@@ -64,7 +64,13 @@ def _has_any_provider_configured() -> bool:
    # Check env vars (may be set by .env or shell).
    # OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.)
    # often don't require an API key.
    provider_env_vars = ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL")
    from hermes_cli.auth import PROVIDER_REGISTRY

    # Collect all provider env vars
    provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"}
    for pconfig in PROVIDER_REGISTRY.values():
        if pconfig.auth_type == "api_key":
            provider_env_vars.update(pconfig.api_key_env_vars)
    if any(os.getenv(v) for v in provider_env_vars):
        return True

@@ -114,16 +120,63 @@ def _resolve_last_cli_session() -> Optional[str]:
    return None


def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
    """Resolve a session name (title) or ID to a session ID.

    - If it looks like a session ID (contains underscore + hex), try direct lookup first.
    - Otherwise, treat it as a title and use resolve_session_by_title (auto-latest).
    - Falls back to the other method if the first doesn't match.
    """
    try:
        from hermes_state import SessionDB
        db = SessionDB()

        # Try as exact session ID first
        session = db.get_session(name_or_id)
        if session:
            db.close()
            return session["id"]

        # Try as title (with auto-latest for lineage)
        session_id = db.resolve_session_by_title(name_or_id)
        db.close()
        return session_id
    except Exception:
        pass
    return None

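A quick sketch of the resolution order this helper implements. The session ID and title below are placeholders for illustration, not real sessions:

sid = _resolve_session_by_name_or_id("cli_20250101_ab12cd")  # tried as an exact ID first
sid = _resolve_session_by_name_or_id("my project")  # falls through to title lookup,
                                                    # returning the latest in that lineage
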
def cmd_chat(args):
    """Run interactive chat CLI."""
    # Resolve --continue into --resume with the latest CLI session
    if getattr(args, "continue_last", False) and not getattr(args, "resume", None):
        last_id = _resolve_last_cli_session()
        if last_id:
            args.resume = last_id
    # Resolve --continue into --resume with the latest CLI session or by name
    continue_val = getattr(args, "continue_last", None)
    if continue_val and not getattr(args, "resume", None):
        if isinstance(continue_val, str):
            # -c "session name" — resolve by title or ID
            resolved = _resolve_session_by_name_or_id(continue_val)
            if resolved:
                args.resume = resolved
            else:
                print(f"No session found matching '{continue_val}'.")
                print("Use 'hermes sessions list' to see available sessions.")
                sys.exit(1)
        else:
            print("No previous CLI session found to continue.")
            sys.exit(1)
            # -c with no argument — continue the most recent session
            last_id = _resolve_last_cli_session()
            if last_id:
                args.resume = last_id
            else:
                print("No previous CLI session found to continue.")
                sys.exit(1)

    # Resolve --resume by title if it's not a direct session ID
    resume_val = getattr(args, "resume", None)
    if resume_val:
        resolved = _resolve_session_by_name_or_id(resume_val)
        if resolved:
            args.resume = resolved
        # If resolution fails, keep the original value — _init_agent will
        # report "Session not found" with the original input

    # First-run guard: check if any provider is configured before launching
    if not _has_any_provider_configured():
@@ -150,6 +203,10 @@ def cmd_chat(args):
        except Exception:
            pass

    # --pass-session-id: include session ID in the agent's system prompt
    if getattr(args, "pass_session_id", False):
        os.environ["HERMES_PASS_SESSION_ID"] = "1"

    # Import and run the CLI
    from cli import main as cli_main

@@ -161,6 +218,7 @@ def cmd_chat(args):
        "verbose": args.verbose,
        "query": args.query,
        "resume": getattr(args, "resume", None),
        "worktree": getattr(args, "worktree", False),
    }
    # Filter out None values
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
@@ -411,6 +469,10 @@ def cmd_model(args):
        "openrouter": "OpenRouter",
        "nous": "Nous Portal",
        "openai-codex": "OpenAI Codex",
        "zai": "Z.AI / GLM",
        "kimi-coding": "Kimi / Moonshot",
        "minimax": "MiniMax",
        "minimax-cn": "MiniMax (China)",
        "custom": "Custom endpoint",
    }
    active_label = provider_labels.get(active, active)
@@ -425,11 +487,16 @@ def cmd_model(args):
        ("openrouter", "OpenRouter (100+ models, pay-per-use)"),
        ("nous", "Nous Portal (Nous Research subscription)"),
        ("openai-codex", "OpenAI Codex"),
        ("zai", "Z.AI / GLM (Zhipu AI direct API)"),
        ("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
        ("minimax", "MiniMax (global direct API)"),
        ("minimax-cn", "MiniMax China (domestic direct API)"),
        ("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
    ]

    # Reorder so the active provider is at the top
    active_key = active if active in ("openrouter", "nous", "openai-codex") else "custom"
    known_keys = {k for k, _ in providers}
    active_key = active if active in known_keys else "custom"
    ordered = []
    for key, label in providers:
        if key == active_key:
@@ -454,6 +521,8 @@ def cmd_model(args):
        _model_flow_openai_codex(config, current_model)
    elif selected_provider == "custom":
        _model_flow_custom(config)
    elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
        _model_flow_api_key_provider(config, selected_provider, current_model)


def _prompt_provider_choice(choices):
@@ -723,6 +792,117 @@ def _model_flow_custom(config):
    print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")


# Curated model lists for direct API-key providers
_PROVIDER_MODELS = {
    "zai": [
        "glm-5",
        "glm-4.7",
        "glm-4.5",
        "glm-4.5-flash",
    ],
    "kimi-coding": [
        "kimi-k2.5",
        "kimi-k2-thinking",
        "kimi-k2-turbo-preview",
        "kimi-k2-0905-preview",
    ],
    "minimax": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
    "minimax-cn": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
}


def _model_flow_api_key_provider(config, provider_id, current_model=""):
    """Generic flow for API-key providers (z.ai, Kimi, MiniMax)."""
    from hermes_cli.auth import (
        PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
        _update_config_for_provider, deactivate_provider,
    )
    from hermes_cli.config import get_env_value, save_env_value, load_config, save_config

    pconfig = PROVIDER_REGISTRY[provider_id]
    key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else ""
    base_url_env = pconfig.base_url_env_var or ""

    # Check / prompt for API key
    existing_key = ""
    for ev in pconfig.api_key_env_vars:
        existing_key = get_env_value(ev) or os.getenv(ev, "")
        if existing_key:
            break

    if not existing_key:
        print(f"No {pconfig.name} API key configured.")
        if key_env:
            try:
                new_key = input(f"{key_env} (or Enter to cancel): ").strip()
            except (KeyboardInterrupt, EOFError):
                print()
                return
            if not new_key:
                print("Cancelled.")
                return
            save_env_value(key_env, new_key)
            print("API key saved.")
            print()
    else:
        print(f" {pconfig.name} API key: {existing_key[:8]}... ✓")
        print()

    # Optional base URL override
    current_base = ""
    if base_url_env:
        current_base = get_env_value(base_url_env) or os.getenv(base_url_env, "")
    effective_base = current_base or pconfig.inference_base_url

    try:
        override = input(f"Base URL [{effective_base}]: ").strip()
    except (KeyboardInterrupt, EOFError):
        print()
        override = ""
    if override and base_url_env:
        save_env_value(base_url_env, override)
        effective_base = override

    # Model selection
    model_list = _PROVIDER_MODELS.get(provider_id, [])
    if model_list:
        selected = _prompt_model_selection(model_list, current_model=current_model)
    else:
        try:
            selected = input("Model name: ").strip()
        except (KeyboardInterrupt, EOFError):
            selected = None

    if selected:
        # Clear custom endpoint if set (avoid confusion)
        if get_env_value("OPENAI_BASE_URL"):
            save_env_value("OPENAI_BASE_URL", "")
            save_env_value("OPENAI_API_KEY", "")

        _save_model_choice(selected)

        # Update config with provider and base URL
        cfg = load_config()
        model = cfg.get("model")
        if isinstance(model, dict):
            model["provider"] = provider_id
            model["base_url"] = effective_base
        save_config(cfg)
        deactivate_provider()

        print(f"Default model set to: {selected} (via {pconfig.name})")
    else:
        print("No change.")

def cmd_login(args):
    """Authenticate Hermes CLI with a provider."""
    from hermes_cli.auth import login_command
@@ -864,6 +1044,8 @@ def _update_via_zip(args):
        print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
    if result.get("updated"):
        print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
    if result.get("user_modified"):
        print(f" ~ {len(result['user_modified'])} user-modified (kept)")
    if result.get("cleaned"):
        print(f" − {len(result['cleaned'])} removed from manifest")
    if not result["copied"] and not result.get("updated"):
@@ -982,6 +1164,8 @@ def cmd_update(args):
        print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
    if result.get("updated"):
        print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
    if result.get("user_modified"):
        print(f" ~ {len(result['user_modified'])} user-modified (kept)")
    if result.get("cleaned"):
        print(f" − {len(result['cleaned'])} removed from manifest")
    if not result["copied"] and not result.get("updated"):
@@ -1076,8 +1260,9 @@ def main():
Examples:
  hermes                          Start interactive chat
  hermes chat -q "Hello"          Single query mode
  hermes --continue               Resume the most recent session
  hermes --resume <session_id>    Resume a specific session
  hermes -c                       Resume the most recent session
  hermes -c "my project"          Resume a session by name (latest in lineage)
  hermes --resume <session_id>    Resume a specific session by ID
  hermes setup                    Run setup wizard
  hermes logout                   Clear stored authentication
  hermes model                    Select default model
@@ -1085,8 +1270,10 @@ Examples:
  hermes config edit              Edit config in $EDITOR
  hermes config set model gpt-4   Set a config value
  hermes gateway                  Run messaging gateway
  hermes -w                       Start in isolated git worktree
  hermes gateway install          Install as system service
  hermes sessions list            List past sessions
  hermes sessions rename ID T     Rename/title a session
  hermes update                   Update to latest version

For more help on a command:
@@ -1101,16 +1288,30 @@ For more help on a command:
    )
    parser.add_argument(
        "--resume", "-r",
        metavar="SESSION_ID",
        metavar="SESSION",
        default=None,
        help="Resume a previous session by ID (shortcut for: hermes chat --resume ID)"
        help="Resume a previous session by ID or title"
    )
    parser.add_argument(
        "--continue", "-c",
        dest="continue_last",
        nargs="?",
        const=True,
        default=None,
        metavar="SESSION_NAME",
        help="Resume a session by name, or the most recent if no name given"
    )
    parser.add_argument(
        "--worktree", "-w",
        action="store_true",
        default=False,
        help="Resume the most recent CLI session"
        help="Run in an isolated git worktree (for parallel agents)"
    )
    parser.add_argument(
        "--pass-session-id",
        action="store_true",
        default=False,
        help="Include the session ID in the agent's system prompt"
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to run")
@@ -1137,7 +1338,7 @@ For more help on a command:
    )
    chat_parser.add_argument(
        "--provider",
        choices=["auto", "openrouter", "nous", "openai-codex"],
        choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"],
        default=None,
        help="Inference provider (default: auto)"
    )
@@ -1154,9 +1355,23 @@ For more help on a command:
    chat_parser.add_argument(
        "--continue", "-c",
        dest="continue_last",
        nargs="?",
        const=True,
        default=None,
        metavar="SESSION_NAME",
        help="Resume a session by name, or the most recent if no name given"
    )
    chat_parser.add_argument(
        "--worktree", "-w",
        action="store_true",
        default=False,
        help="Resume the most recent CLI session"
        help="Run in an isolated git worktree (for parallel agents on the same repo)"
    )
    chat_parser.add_argument(
        "--pass-session-id",
        action="store_true",
        default=False,
        help="Include the session ID in the agent's system prompt"
    )
    chat_parser.set_defaults(func=cmd_chat)

@@ -1183,6 +1398,8 @@ For more help on a command:
    # gateway run (default)
    gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
    gateway_run.add_argument("-v", "--verbose", action="store_true")
    gateway_run.add_argument("--replace", action="store_true",
                             help="Replace any existing gateway instance (useful for systemd)")

    # gateway start
    gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
@@ -1215,7 +1432,15 @@ For more help on a command:
    setup_parser = subparsers.add_parser(
        "setup",
        help="Interactive setup wizard",
        description="Configure Hermes Agent with an interactive wizard"
        description="Configure Hermes Agent with an interactive wizard. "
                    "Run a specific section: hermes setup model|terminal|gateway|tools|agent"
    )
    setup_parser.add_argument(
        "section",
        nargs="?",
        choices=["model", "terminal", "gateway", "tools", "agent"],
        default=None,
        help="Run a specific setup section instead of the full wizard"
    )
    setup_parser.add_argument(
        "--non-interactive",
@@ -1515,7 +1740,7 @@ For more help on a command:
    # =========================================================================
    sessions_parser = subparsers.add_parser(
        "sessions",
        help="Manage session history (list, export, prune, delete)",
        help="Manage session history (list, rename, export, prune, delete)",
        description="View and manage the SQLite session store"
    )
    sessions_subparsers = sessions_parser.add_subparsers(dest="sessions_action")
@@ -1540,6 +1765,10 @@ For more help on a command:

    sessions_stats = sessions_subparsers.add_parser("stats", help="Show session store statistics")

    sessions_rename = sessions_subparsers.add_parser("rename", help="Set or change a session's title")
    sessions_rename.add_argument("session_id", help="Session ID to rename")
    sessions_rename.add_argument("title", nargs="+", help="New title for the session")

    def cmd_sessions(args):
        import json as _json
        try:
@@ -1552,18 +1781,51 @@ For more help on a command:
        action = args.sessions_action

        if action == "list":
            sessions = db.search_sessions(source=args.source, limit=args.limit)
            sessions = db.list_sessions_rich(source=args.source, limit=args.limit)
            if not sessions:
                print("No sessions found.")
                return
            print(f"{'ID':<30} {'Source':<12} {'Model':<30} {'Messages':>8} {'Started'}")
            print("─" * 100)
            from datetime import datetime
            import time as _time

            def _relative_time(ts):
                """Format a timestamp as relative time (e.g., '2h ago', 'yesterday')."""
                if not ts:
                    return "?"
                delta = _time.time() - ts
                if delta < 60:
                    return "just now"
                elif delta < 3600:
                    mins = int(delta / 60)
                    return f"{mins}m ago"
                elif delta < 86400:
                    hours = int(delta / 3600)
                    return f"{hours}h ago"
                elif delta < 172800:
                    return "yesterday"
                elif delta < 604800:
                    days = int(delta / 86400)
                    return f"{days}d ago"
                else:
                    return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")

            has_titles = any(s.get("title") for s in sessions)
            if has_titles:
                print(f"{'Title':<22} {'Preview':<40} {'Last Active':<13} {'ID'}")
                print("─" * 100)
            else:
                print(f"{'Preview':<50} {'Last Active':<13} {'Src':<6} {'ID'}")
                print("─" * 90)
            for s in sessions:
                started = datetime.fromtimestamp(s["started_at"]).strftime("%Y-%m-%d %H:%M") if s["started_at"] else "?"
                model = (s.get("model") or "?")[:28]
                ended = " (ended)" if s.get("ended_at") else ""
                print(f"{s['id']:<30} {s['source']:<12} {model:<30} {s['message_count']:>8} {started}{ended}")
                last_active = _relative_time(s.get("last_active"))
                preview = s.get("preview", "")[:38] if has_titles else s.get("preview", "")[:48]
                if has_titles:
                    title = (s.get("title") or "—")[:20]
                    sid = s["id"][:20]
                    print(f"{title:<22} {preview:<40} {last_active:<13} {sid}")
                else:
                    sid = s["id"][:20]
                    print(f"{preview:<50} {last_active:<13} {s['source']:<6} {sid}")

        elif action == "export":
            if args.session_id:
@@ -1603,6 +1865,16 @@ For more help on a command:
            count = db.prune_sessions(older_than_days=days, source=args.source)
            print(f"Pruned {count} session(s).")

        elif action == "rename":
            title = " ".join(args.title)
            try:
                if db.set_session_title(args.session_id, title):
                    print(f"Session '{args.session_id}' renamed to: {title}")
                else:
                    print(f"Session '{args.session_id}' not found.")
            except ValueError as e:
                print(f"Error: {e}")

        elif action == "stats":
            total = db.session_count()
            msgs = db.message_count()
@@ -1708,6 +1980,8 @@ For more help on a command:
        args.provider = None
        args.toolsets = None
        args.verbose = False
        if not hasattr(args, "worktree"):
            args.worktree = False
        cmd_chat(args)
        return

@@ -1719,7 +1993,9 @@ For more help on a command:
        args.toolsets = None
        args.verbose = False
        args.resume = None
        args.continue_last = False
        args.continue_last = None
        if not hasattr(args, "worktree"):
            args.worktree = False
        cmd_chat(args)
        return

@@ -1,10 +1,18 @@
"""
Canonical list of OpenRouter models offered in CLI and setup wizards.
Canonical model catalogs and lightweight validation helpers.

Add, remove, or reorder entries here — both `hermes setup` and
`hermes` provider-selection will pick up the change automatically.
"""

from __future__ import annotations

import json
import urllib.request
import urllib.error
from difflib import get_close_matches
from typing import Any, Optional

# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
    ("anthropic/claude-opus-4.6", "recommended"),
@@ -14,17 +22,64 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
    ("openai/gpt-5.3-codex", ""),
    ("google/gemini-3-pro-preview", ""),
    ("google/gemini-3-flash-preview", ""),
    ("qwen/qwen3.5-plus-02-15", ""),
    ("qwen/qwen3.5-35b-a3b", ""),
    ("qwen/qwen3.5-plus-02-15", ""),
    ("qwen/qwen3.5-35b-a3b", ""),
    ("stepfun/step-3.5-flash", ""),
    ("z-ai/glm-5", ""),
    ("moonshotai/kimi-k2.5", ""),
    ("minimax/minimax-m2.5", ""),
]

_PROVIDER_MODELS: dict[str, list[str]] = {
    "zai": [
        "glm-5",
        "glm-4.7",
        "glm-4.5",
        "glm-4.5-flash",
    ],
    "kimi-coding": [
        "kimi-k2.5",
        "kimi-k2-thinking",
        "kimi-k2-turbo-preview",
        "kimi-k2-0905-preview",
    ],
    "minimax": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
    "minimax-cn": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
}

_PROVIDER_LABELS = {
    "openrouter": "OpenRouter",
    "openai-codex": "OpenAI Codex",
    "nous": "Nous Portal",
    "zai": "Z.AI / GLM",
    "kimi-coding": "Kimi / Moonshot",
    "minimax": "MiniMax",
    "minimax-cn": "MiniMax (China)",
    "custom": "custom endpoint",
}

_PROVIDER_ALIASES = {
    "glm": "zai",
    "z-ai": "zai",
    "z.ai": "zai",
    "zhipu": "zai",
    "kimi": "kimi-coding",
    "moonshot": "kimi-coding",
    "minimax-china": "minimax-cn",
    "minimax_cn": "minimax-cn",
}


def model_ids() -> list[str]:
    """Return just the model-id strings (convenience helper)."""
    """Return just the OpenRouter model-id strings."""
    return [mid for mid, _ in OPENROUTER_MODELS]


@@ -34,3 +89,231 @@ def menu_labels() -> list[str]:
    for mid, desc in OPENROUTER_MODELS:
        labels.append(f"{mid} ({desc})" if desc else mid)
    return labels


# All provider IDs and aliases that are valid for the provider:model syntax.
_KNOWN_PROVIDER_NAMES: set[str] = (
    set(_PROVIDER_LABELS.keys())
    | set(_PROVIDER_ALIASES.keys())
    | {"openrouter", "custom"}
)


def list_available_providers() -> list[dict[str, str]]:
    """Return info about all providers the user could use with ``provider:model``.

    Each dict has ``id``, ``label``, and ``aliases``.
    Checks which providers have valid credentials configured.
    """
    # Canonical providers in display order
    _PROVIDER_ORDER = [
        "openrouter", "nous", "openai-codex",
        "zai", "kimi-coding", "minimax", "minimax-cn",
    ]
    # Build reverse alias map
    aliases_for: dict[str, list[str]] = {}
    for alias, canonical in _PROVIDER_ALIASES.items():
        aliases_for.setdefault(canonical, []).append(alias)

    result = []
    for pid in _PROVIDER_ORDER:
        label = _PROVIDER_LABELS.get(pid, pid)
        alias_list = aliases_for.get(pid, [])
        # Check if this provider has credentials available
        has_creds = False
        try:
            from hermes_cli.runtime_provider import resolve_runtime_provider
            runtime = resolve_runtime_provider(requested=pid)
            has_creds = bool(runtime.get("api_key"))
        except Exception:
            pass
        result.append({
            "id": pid,
            "label": label,
            "aliases": alias_list,
            "authenticated": has_creds,
        })
    return result


def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
    """Parse ``/model`` input into ``(provider, model)``.

    Supports ``provider:model`` syntax to switch providers at runtime::

        openrouter:anthropic/claude-sonnet-4.5 → ("openrouter", "anthropic/claude-sonnet-4.5")
        nous:hermes-3                          → ("nous", "hermes-3")
        anthropic/claude-sonnet-4.5            → (current_provider, "anthropic/claude-sonnet-4.5")
        gpt-5.4                                → (current_provider, "gpt-5.4")

    The colon is only treated as a provider delimiter if the left side is a
    recognized provider name or alias. This avoids misinterpreting model names
    that happen to contain colons (e.g. ``anthropic/claude-3.5-sonnet:beta``).

    Returns ``(provider, model)`` where *provider* is either the explicit
    provider from the input or *current_provider* if none was specified.
    """
    stripped = raw.strip()
    colon = stripped.find(":")
    if colon > 0:
        provider_part = stripped[:colon].strip().lower()
        model_part = stripped[colon + 1:].strip()
        if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
            return (normalize_provider(provider_part), model_part)
    return (current_provider, stripped)

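A doctest-style sketch of the docstring cases above; the assertions simply restate the documented behaviour:

assert parse_model_input("openrouter:anthropic/claude-sonnet-4.5", "nous") == (
    "openrouter", "anthropic/claude-sonnet-4.5")
assert parse_model_input("glm:glm-5", "openrouter") == ("zai", "glm-5")  # alias normalized
# Left side is not a known provider name, so the colon stays in the model:
assert parse_model_input("anthropic/claude-3.5-sonnet:beta", "openrouter") == (
    "openrouter", "anthropic/claude-3.5-sonnet:beta")
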
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
    """Return ``(model_id, description)`` tuples for a provider's curated list."""
    normalized = normalize_provider(provider)
    if normalized == "openrouter":
        return list(OPENROUTER_MODELS)
    models = _PROVIDER_MODELS.get(normalized, [])
    return [(m, "") for m in models]


def normalize_provider(provider: Optional[str]) -> str:
    """Normalize provider aliases to Hermes' canonical provider ids.

    Note: ``"auto"`` passes through unchanged — use
    ``hermes_cli.auth.resolve_provider()`` to resolve it to a concrete
    provider based on credentials and environment.
    """
    normalized = (provider or "openrouter").strip().lower()
    return _PROVIDER_ALIASES.get(normalized, normalized)


def provider_model_ids(provider: Optional[str]) -> list[str]:
    """Return the best known model catalog for a provider."""
    normalized = normalize_provider(provider)
    if normalized == "openrouter":
        return model_ids()
    if normalized == "openai-codex":
        from hermes_cli.codex_models import get_codex_model_ids

        return get_codex_model_ids()
    return list(_PROVIDER_MODELS.get(normalized, []))


def fetch_api_models(
    api_key: Optional[str],
    base_url: Optional[str],
    timeout: float = 5.0,
) -> Optional[list[str]]:
    """Fetch the list of available model IDs from the provider's ``/models`` endpoint.

    Returns a list of model ID strings, or ``None`` if the endpoint could not
    be reached (network error, timeout, auth failure, etc.).
    """
    if not base_url:
        return None

    url = base_url.rstrip("/") + "/models"
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    req = urllib.request.Request(url, headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data = json.loads(resp.read().decode())
            # Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
            return [m.get("id", "") for m in data.get("data", [])]
    except Exception:
        return None

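Usage sketch. The key is a placeholder; the base URL is the Moonshot endpoint probed elsewhere in this diff, and any OpenAI-compatible /models endpoint behaves the same:

# Hypothetical call; api_key value is illustrative.
models = fetch_api_models(api_key="sk-...", base_url="https://api.moonshot.ai/v1")
if models is None:
    print("endpoint unreachable; fall back to the curated catalog")
else:
    print(f"{len(models)} models available, e.g. {models[:3]}")
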
def validate_requested_model(
    model_name: str,
    provider: Optional[str],
    *,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
) -> dict[str, Any]:
    """
    Validate a ``/model`` value for the active provider.

    Performs format checks first, then probes the live API to confirm
    the model actually exists.

    Returns a dict with:
    - accepted: whether the CLI should switch to the requested model now
    - persist: whether it is safe to save to config
    - recognized: whether it matched a known provider catalog
    - message: optional warning / guidance for the user
    """
    requested = (model_name or "").strip()
    normalized = normalize_provider(provider)
    if normalized == "openrouter" and base_url and "openrouter.ai" not in base_url:
        normalized = "custom"

    if not requested:
        return {
            "accepted": False,
            "persist": False,
            "recognized": False,
            "message": "Model name cannot be empty.",
        }

    if any(ch.isspace() for ch in requested):
        return {
            "accepted": False,
            "persist": False,
            "recognized": False,
            "message": "Model names cannot contain spaces.",
        }

    # Probe the live API to check if the model actually exists
    api_models = fetch_api_models(api_key, base_url)

    if api_models is not None:
        if requested in set(api_models):
            # API confirmed the model exists
            return {
                "accepted": True,
                "persist": True,
                "recognized": True,
                "message": None,
            }
        else:
            # API responded but model is not listed
            suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
            suggestion_text = ""
            if suggestions:
                suggestion_text = "\n Did you mean: " + ", ".join(f"`{s}`" for s in suggestions)

            return {
                "accepted": False,
                "persist": False,
                "recognized": False,
                "message": (
                    f"Error: `{requested}` is not a valid model for this provider."
                    f"{suggestion_text}"
                ),
            }

    # api_models is None — couldn't reach API, fall back to catalog check
    provider_label = _PROVIDER_LABELS.get(normalized, normalized)
    known_models = provider_model_ids(normalized)

    if requested in known_models:
        return {
            "accepted": True,
            "persist": True,
            "recognized": True,
            "message": None,
        }

    # Can't validate — accept for session only
    suggestion = get_close_matches(requested, known_models, n=1, cutoff=0.6)
    suggestion_text = f" Did you mean `{suggestion[0]}`?" if suggestion else ""
    return {
        "accepted": True,
        "persist": False,
        "recognized": False,
        "message": (
            f"Could not validate `{requested}` against the live {provider_label} API. "
            "Using it for this session only; config unchanged."
            f"{suggestion_text}"
        ),
    }

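A sketch of how a caller might act on the returned dict. The key is a placeholder; the Z.AI base URL matches the default probed in doctor.py above:

result = validate_requested_model(
    "glm-5", "zai",
    api_key="sk-...", base_url="https://api.z.ai/api/paas/v4",
)
if result["message"]:
    print(result["message"])  # warning or guidance, if any
if result["accepted"]:
    pass                      # switch the live session to the model
if result["persist"]:
    pass                      # also safe to write to config.yaml
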
@@ -7,10 +7,12 @@ from typing import Any, Dict, Optional

from hermes_cli.auth import (
    AuthError,
    PROVIDER_REGISTRY,
    format_auth_error,
    resolve_provider,
    resolve_nous_runtime_credentials,
    resolve_codex_runtime_credentials,
    resolve_api_key_provider_credentials,
)
from hermes_cli.config import load_config
from hermes_constants import OPENROUTER_BASE_URL
@@ -72,12 +74,26 @@ def _resolve_openrouter_runtime(
        or OPENROUTER_BASE_URL
    ).rstrip("/")

    api_key = (
        explicit_api_key
        or os.getenv("OPENROUTER_API_KEY")
        or os.getenv("OPENAI_API_KEY")
        or ""
    )
    # Choose API key based on whether the resolved base_url targets OpenRouter.
    # When hitting OpenRouter, prefer OPENROUTER_API_KEY (issue #289).
    # When hitting a custom endpoint (e.g. Z.ai, local LLM), prefer
    # OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
    # provider (issues #420, #560).
    _is_openrouter_url = "openrouter.ai" in base_url
    if _is_openrouter_url:
        api_key = (
            explicit_api_key
            or os.getenv("OPENROUTER_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or ""
        )
    else:
        api_key = (
            explicit_api_key
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("OPENROUTER_API_KEY")
            or ""
        )

    source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"

@@ -132,6 +148,19 @@ def resolve_runtime_provider(
            "requested_provider": requested_provider,
        }

    # API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
    pconfig = PROVIDER_REGISTRY.get(provider)
    if pconfig and pconfig.auth_type == "api_key":
        creds = resolve_api_key_provider_credentials(provider)
        return {
            "provider": provider,
            "api_mode": "chat_completions",
            "base_url": creds.get("base_url", "").rstrip("/"),
            "api_key": creds.get("api_key", ""),
            "source": creds.get("source", "env"),
            "requested_provider": requested_provider,
        }

    runtime = _resolve_openrouter_runtime(
        requested_provider=requested_provider,
        explicit_api_key=explicit_api_key,

hermes_cli/setup.py: 1675 lines changed (diff suppressed because it is too large)
@@ -79,8 +79,12 @@ def show_status(args):
        "OpenRouter": "OPENROUTER_API_KEY",
        "Anthropic": "ANTHROPIC_API_KEY",
        "OpenAI": "OPENAI_API_KEY",
        "Z.AI/GLM": "GLM_API_KEY",
        "Kimi": "KIMI_API_KEY",
        "MiniMax": "MINIMAX_API_KEY",
        "MiniMax-CN": "MINIMAX_CN_API_KEY",
        "Firecrawl": "FIRECRAWL_API_KEY",
        "Browserbase": "BROWSERBASE_API_KEY",
        "Browserbase": "BROWSERBASE_API_KEY",  # Optional — local browser works without this
        "FAL": "FAL_KEY",
        "Tinker": "TINKER_API_KEY",
        "WandB": "WANDB_API_KEY",
@@ -137,6 +141,28 @@ def show_status(args):
    if codex_status.get("error") and not codex_logged_in:
        print(f" Error: {codex_status.get('error')}")

    # =========================================================================
    # API-Key Providers
    # =========================================================================
    print()
    print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))

    apikey_providers = {
        "Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
        "Kimi / Moonshot": ("KIMI_API_KEY",),
        "MiniMax": ("MINIMAX_API_KEY",),
        "MiniMax (China)": ("MINIMAX_CN_API_KEY",),
    }
    for pname, env_vars in apikey_providers.items():
        key_val = ""
        for ev in env_vars:
            key_val = get_env_value(ev) or ""
            if key_val:
                break
        configured = bool(key_val)
        label = "configured" if configured else "not configured (run: hermes model)"
        print(f" {pname:<16} {check_mark(configured)} {label}")

    # =========================================================================
    # Terminal Configuration
    # =========================================================================

@@ -1,7 +1,10 @@
"""
Interactive tool configuration for Hermes Agent.
Unified tool configuration for Hermes Agent.

`hermes tools` and `hermes setup tools` both enter this module.
Select a platform → toggle toolsets on/off → for newly enabled tools
that need API keys, run through provider-aware configuration.

`hermes tools` — select a platform, then toggle toolsets on/off via checklist.
Saves per-platform tool configuration to ~/.hermes/config.yaml under
the `platform_toolsets` key.
"""
@@ -12,9 +15,63 @@ from typing import Dict, List, Set

import os

from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
from hermes_cli.config import (
    load_config, save_config, get_env_value, save_env_value,
    get_hermes_home,
)
from hermes_cli.colors import Colors, color

PROJECT_ROOT = Path(__file__).parent.parent.resolve()


# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────

def _print_info(text: str):
    print(color(f" {text}", Colors.DIM))

def _print_success(text: str):
    print(color(f"✓ {text}", Colors.GREEN))

def _print_warning(text: str):
    print(color(f"⚠ {text}", Colors.YELLOW))

def _print_error(text: str):
    print(color(f"✗ {text}", Colors.RED))

def _prompt(question: str, default: str = None, password: bool = False) -> str:
    if default:
        display = f"{question} [{default}]: "
    else:
        display = f"{question}: "
    try:
        if password:
            import getpass
            value = getpass.getpass(color(display, Colors.YELLOW))
        else:
            value = input(color(display, Colors.YELLOW))
        return value.strip() or default or ""
    except (KeyboardInterrupt, EOFError):
        print()
        return default or ""

def _prompt_yes_no(question: str, default: bool = True) -> bool:
    default_str = "Y/n" if default else "y/N"
    while True:
        try:
            value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
        except (KeyboardInterrupt, EOFError):
            print()
            return default
        if not value:
            return default
        if value in ('y', 'yes'):
            return True
        if value in ('n', 'no'):
            return False


# ─── Toolset Registry ─────────────────────────────────────────────────────────

# Toolsets shown in the configurator, grouped for display.
# Each entry: (toolset_name, label, description)
# These map to keys in toolsets.py TOOLSETS dict.
@@ -49,6 +106,187 @@ PLATFORMS = {
}


# ─── Tool Categories (provider-aware configuration) ──────────────────────────
# Maps toolset keys to their provider options. When a toolset is newly enabled,
# we use this to show provider selection and prompt for the right API keys.
# Toolsets not in this map either need no config or use the simple fallback.

TOOL_CATEGORIES = {
    "tts": {
        "name": "Text-to-Speech",
        "icon": "🔊",
        "providers": [
            {
                "name": "Microsoft Edge TTS",
                "tag": "Free - no API key needed",
                "env_vars": [],
                "tts_provider": "edge",
            },
            {
                "name": "OpenAI TTS",
                "tag": "Premium - high quality voices",
                "env_vars": [
                    {"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
                ],
                "tts_provider": "openai",
            },
            {
                "name": "ElevenLabs",
                "tag": "Premium - most natural voices",
                "env_vars": [
                    {"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
                ],
                "tts_provider": "elevenlabs",
            },
        ],
    },
    "web": {
        "name": "Web Search & Extract",
        "icon": "🔍",
        "providers": [
            {
                "name": "Firecrawl Cloud",
                "tag": "Recommended - hosted service",
                "env_vars": [
                    {"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
                ],
            },
            {
                "name": "Firecrawl Self-Hosted",
                "tag": "Free - run your own instance",
                "env_vars": [
                    {"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
                ],
            },
        ],
    },
    "image_gen": {
        "name": "Image Generation",
        "icon": "🎨",
        "providers": [
            {
                "name": "FAL.ai",
                "tag": "FLUX 2 Pro with auto-upscaling",
                "env_vars": [
                    {"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
                ],
            },
        ],
    },
    "browser": {
        "name": "Browser Automation",
        "icon": "🌐",
        "providers": [
            {
                "name": "Local Browser",
                "tag": "Free headless Chromium (no API key needed)",
                "env_vars": [],
                "post_setup": "browserbase",  # Same npm install for agent-browser
            },
            {
                "name": "Browserbase",
                "tag": "Cloud browser with stealth & proxies",
                "env_vars": [
                    {"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
                    {"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
                ],
                "post_setup": "browserbase",
            },
        ],
    },
    "homeassistant": {
        "name": "Smart Home",
        "icon": "🏠",
        "providers": [
            {
                "name": "Home Assistant",
                "tag": "REST API integration",
                "env_vars": [
                    {"key": "HASS_TOKEN", "prompt": "Home Assistant Long-Lived Access Token"},
                    {"key": "HASS_URL", "prompt": "Home Assistant URL", "default": "http://homeassistant.local:8123"},
                ],
            },
        ],
    },
    "rl": {
        "name": "RL Training",
        "icon": "🧪",
        "requires_python": (3, 11),
        "providers": [
            {
                "name": "Tinker / Atropos",
                "tag": "RL training platform",
                "env_vars": [
                    {"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
                    {"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
                ],
                "post_setup": "rl_training",
            },
        ],
    },
}

# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
# Used as a fallback for tools like vision/moa that just need an API key.
TOOLSET_ENV_REQUIREMENTS = {
    "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
    "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
}

# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
|
||||
|
||||
def _run_post_setup(post_setup_key: str):
|
||||
"""Run post-setup hooks for tools that need extra installation steps."""
|
||||
import shutil
|
||||
if post_setup_key == "browserbase":
|
||||
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
|
||||
if not node_modules.exists() and shutil.which("npm"):
|
||||
_print_info(" Installing Node.js dependencies for browser tools...")
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--silent"],
|
||||
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" Node.js dependencies installed")
|
||||
else:
|
||||
_print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
|
||||
elif not node_modules.exists():
|
||||
_print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")
|
||||
|
||||
elif post_setup_key == "rl_training":
|
||||
try:
|
||||
__import__("tinker_atropos")
|
||||
except ImportError:
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
_print_info(" Installing tinker-atropos submodule...")
|
||||
import subprocess
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" tinker-atropos installed")
|
||||
else:
|
||||
_print_warning(" tinker-atropos install failed - run manually:")
|
||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
||||
else:
|
||||
_print_warning(" tinker-atropos submodule not found - run:")
|
||||
_print_info(" git submodule update --init --recursive")
|
||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
||||
|
||||
|
||||
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
|
||||
|
||||
def _get_enabled_platforms() -> List[str]:
|
||||
"""Return platform keys that are configured (have tokens or are CLI)."""
|
||||
enabled = ["cli"]
|
||||
@@ -70,7 +308,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
    platform_toolsets = config.get("platform_toolsets", {})
    toolset_names = platform_toolsets.get(platform)

-   if not toolset_names or not isinstance(toolset_names, list):
+   if toolset_names is None or not isinstance(toolset_names, list):
        default_ts = PLATFORMS[platform]["default_toolset"]
        toolset_names = [default_ts]
@@ -97,61 +335,117 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
    save_config(config)


def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
    """Single-select menu (arrow keys)."""
    print(color(question, Colors.YELLOW))

    try:
        from simple_term_menu import TerminalMenu
        menu = TerminalMenu(
            [f" {c}" for c in choices],
            cursor_index=default,
            menu_cursor="→ ",
            menu_cursor_style=("fg_green", "bold"),
            menu_highlight_style=("fg_green",),
            cycle_cursor=True,
            clear_screen=False,
        )
        idx = menu.show()
        if idx is None:
            sys.exit(0)
        print()
        return idx
    except (ImportError, NotImplementedError):
        for i, c in enumerate(choices):
            marker = "●" if i == default else "○"
            style = Colors.GREEN if i == default else ""
            print(color(f" {marker} {c}", style) if style else f" {marker} {c}")
        while True:
            try:
                val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
                if not val:
                    return default
                idx = int(val) - 1
                if 0 <= idx < len(choices):
                    return idx
            except (ValueError, KeyboardInterrupt, EOFError):
                print()
                sys.exit(0)


def _toolset_has_keys(ts_key: str) -> bool:
    """Check if a toolset's required API keys are configured."""
    # Check TOOL_CATEGORIES first (provider-aware)
    cat = TOOL_CATEGORIES.get(ts_key)
    if cat:
        for provider in cat["providers"]:
            env_vars = provider.get("env_vars", [])
            if not env_vars:
                return True  # Free provider (e.g., Edge TTS)
            if all(get_env_value(v["key"]) for v in env_vars):
                return True
        return False

    # Fallback to simple requirements
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return True
    return all(get_env_value(var) for var, _ in requirements)


# ─── Menu Helpers ─────────────────────────────────────────────────────────────

def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
    """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
    rendering bugs in tmux, iTerm, and other non-standard terminals."""

    # Curses-based single-select — works in tmux, iTerm, and standard terminals
    try:
        import curses
        result_holder = [default]

        def _curses_menu(stdscr):
            curses.curs_set(0)
            if curses.has_colors():
                curses.start_color()
                curses.use_default_colors()
                curses.init_pair(1, curses.COLOR_GREEN, -1)
                curses.init_pair(2, curses.COLOR_YELLOW, -1)
            cursor = default

            while True:
                stdscr.clear()
                max_y, max_x = stdscr.getmaxyx()
                try:
                    stdscr.addnstr(0, 0, question, max_x - 1,
                                   curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
                except curses.error:
                    pass

                for i, c in enumerate(choices):
                    y = i + 2
                    if y >= max_y - 1:
                        break
                    arrow = "→" if i == cursor else " "
                    line = f" {arrow} {c}"
                    attr = curses.A_NORMAL
                    if i == cursor:
                        attr = curses.A_BOLD
                        if curses.has_colors():
                            attr |= curses.color_pair(1)
                    try:
                        stdscr.addnstr(y, 0, line, max_x - 1, attr)
                    except curses.error:
                        pass

                stdscr.refresh()
                key = stdscr.getch()

                if key in (curses.KEY_UP, ord('k')):
                    cursor = (cursor - 1) % len(choices)
                elif key in (curses.KEY_DOWN, ord('j')):
                    cursor = (cursor + 1) % len(choices)
                elif key in (curses.KEY_ENTER, 10, 13):
                    result_holder[0] = cursor
                    return
                elif key in (27, ord('q')):
                    return

        curses.wrapper(_curses_menu)
        return result_holder[0]

    except Exception:
        pass

    # Fallback: numbered input (Windows without curses, etc.)
    print(color(question, Colors.YELLOW))
    for i, c in enumerate(choices):
        marker = "●" if i == default else "○"
        style = Colors.GREEN if i == default else ""
        print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
    while True:
        try:
            val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
            if not val:
                return default
            idx = int(val) - 1
            if 0 <= idx < len(choices):
                return idx
        except (ValueError, KeyboardInterrupt, EOFError):
            print()
            return default


def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
    """Multi-select checklist of toolsets. Returns set of selected toolset keys."""
    import platform as _platform

    labels = []
    for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
        suffix = ""
-       if not _toolset_has_keys(ts_key) and TOOLSET_ENV_REQUIREMENTS.get(ts_key):
-           suffix = " ⚠ no API key"
+       if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
+           suffix = " [no API key]"
        labels.append(f"{ts_label} ({ts_desc}){suffix}")

    pre_selected_indices = [
@@ -159,48 +453,8 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str
        if ts_key in enabled
    ]

-   # simple_term_menu multi-select has rendering bugs on macOS terminals,
-   # so we use a curses-based fallback there.
-   use_term_menu = _platform.system() != "Darwin"
-
-   if use_term_menu:
-       try:
-           from simple_term_menu import TerminalMenu
-
-           print(color(f"Tools for {platform_label}", Colors.YELLOW))
-           print(color(" SPACE to toggle, ENTER to confirm.", Colors.DIM))
-           print()
-
-           menu_items = [f" {label}" for label in labels]
-           menu = TerminalMenu(
-               menu_items,
-               multi_select=True,
-               show_multi_select_hint=False,
-               multi_select_cursor="[✓] ",
-               multi_select_select_on_accept=False,
-               multi_select_empty_ok=True,
-               preselected_entries=pre_selected_indices if pre_selected_indices else None,
-               menu_cursor="→ ",
-               menu_cursor_style=("fg_green", "bold"),
-               menu_highlight_style=("fg_green",),
-               cycle_cursor=True,
-               clear_screen=False,
-               clear_menu_on_exit=False,
-           )
-
-           menu.show()
-
-           if menu.chosen_menu_entries is None:
-               return enabled
-
-           selected_indices = list(menu.chosen_menu_indices or [])
-           return {CONFIGURABLE_TOOLSETS[i][0] for i in selected_indices}
-
-       except (ImportError, NotImplementedError):
-           pass  # fall through to curses/numbered fallback
-
    # Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
-   # Used on macOS (where simple_term_menu ghosts) and as a fallback.
+   # simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
    try:
        import curses
        selected = set(pre_selected_indices)
@@ -302,77 +556,294 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str
        return {CONFIGURABLE_TOOLSETS[i][0] for i in selected}


-# Map toolset keys to the env vars they require and where to get them
-TOOLSET_ENV_REQUIREMENTS = {
-    "web": [("FIRECRAWL_API_KEY", "https://firecrawl.dev/")],
-    "browser": [("BROWSERBASE_API_KEY", "https://browserbase.com/"),
-                ("BROWSERBASE_PROJECT_ID", None)],
-    "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
-    "image_gen": [("FAL_KEY", "https://fal.ai/")],
-    "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
-    "tts": [],  # Edge TTS is free, no key needed
-    "rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
-           ("WANDB_API_KEY", "https://wandb.ai/authorize")],
-    "homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
-                      ("HASS_URL", None)],
-}

# ─── Provider-Aware Configuration ────────────────────────────────────────────

def _configure_toolset(ts_key: str, config: dict):
    """Configure a toolset - provider selection + API keys.

    Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
    env var prompts for toolsets not in TOOL_CATEGORIES.
    """
    cat = TOOL_CATEGORIES.get(ts_key)

    if cat:
        _configure_tool_category(ts_key, cat, config)
    else:
        # Simple fallback for vision, moa, etc.
        _configure_simple_requirements(ts_key)


-def _check_and_prompt_requirements(newly_enabled: Set[str]):
-    """Check if newly enabled toolsets have missing API keys and offer to set them up."""
-    for ts_key in sorted(newly_enabled):
-        requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
-        if not requirements:
-            continue
def _configure_tool_category(ts_key: str, cat: dict, config: dict):
    """Configure a tool category with provider selection."""
    icon = cat.get("icon", "")
    name = cat["name"]
    providers = cat["providers"]

-        missing = [(var, url) for var, url in requirements if not get_env_value(var)]
-        if not missing:
-            continue
-
-        ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
-        print()
-        print(color(f" ⚠ {ts_label} requires configuration:", Colors.YELLOW))
-
-        for var, url in missing:
-            if url:
-                print(color(f" {var}", Colors.CYAN) + color(f" ({url})", Colors.DIM))
-            else:
-                print(color(f" {var}", Colors.CYAN))
-
-        print()
-        try:
-            response = input(color(" Set up now? [Y/n] ", Colors.YELLOW)).strip().lower()
-        except (KeyboardInterrupt, EOFError):
    # Check Python version requirement
    if cat.get("requires_python"):
        req = cat["requires_python"]
        if sys.version_info < req:
-            print()
-            continue
            _print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
            _print_info(" Upgrade Python and reinstall to enable this tool.")
            return

-        if response in ("", "y", "yes"):
-            for var, url in missing:
-                if url:
-                    print(color(f" Get key at: {url}", Colors.DIM))
-                try:
-                    import getpass
-                    value = getpass.getpass(color(f" {var}: ", Colors.YELLOW))
-                except (KeyboardInterrupt, EOFError):
-                    print()
-                    break
-                if value.strip():
-                    save_env_value(var, value.strip())
-                    print(color(f" ✓ Saved", Colors.GREEN))
    if len(providers) == 1:
        # Single provider - configure directly
        provider = providers[0]
        print()
        print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
        if provider.get("tag"):
            _print_info(f" {provider['tag']}")
        _configure_provider(provider, config)
    else:
        # Multiple providers - let user choose
        print()
        print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
        print()

        # Plain text labels only (no ANSI codes in menu items)
        provider_choices = []
        for p in providers:
            tag = f" ({p['tag']})" if p.get("tag") else ""
            configured = ""
            env_vars = p.get("env_vars", [])
            if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
                if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                    configured = " [active]"
                elif not env_vars:
                    configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
                else:
-                    print(color(f" Skipped", Colors.DIM))
                    configured = " [configured]"
            provider_choices.append(f"{p['name']}{tag}{configured}")

        # Detect current provider as default
        default_idx = 0
        for i, p in enumerate(providers):
            if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                default_idx = i
                break
            env_vars = p.get("env_vars", [])
            if env_vars and all(get_env_value(v["key"]) for v in env_vars):
                default_idx = i
                break

        provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
        _configure_provider(providers[provider_idx], config)


def _configure_provider(provider: dict, config: dict):
    """Configure a single provider - prompt for API keys and set config."""
    env_vars = provider.get("env_vars", [])

    # Set TTS provider in config if applicable
    if provider.get("tts_provider"):
        config.setdefault("tts", {})["provider"] = provider["tts_provider"]

    if not env_vars:
        _print_success(f" {provider['name']} - no configuration needed!")
        return

    # Prompt for each required env var
    all_configured = True
    for var in env_vars:
        existing = get_env_value(var["key"])
        if existing:
            _print_success(f" {var['key']}: already configured")
            # Don't ask to update - this is a new enable flow.
            # Reconfigure is handled separately.
        else:
-            print(color(" Skipped — configure later with 'hermes setup'", Colors.DIM))
            url = var.get("url", "")
            if url:
                _print_info(f" Get yours at: {url}")

            default_val = var.get("default", "")
            if default_val:
                value = _prompt(f" {var.get('prompt', var['key'])}", default_val)
            else:
                value = _prompt(f" {var.get('prompt', var['key'])}", password=True)

            if value:
                save_env_value(var["key"], value)
                _print_success(f" Saved")
            else:
                _print_warning(f" Skipped")
                all_configured = False

    # Run post-setup hooks if needed
    if provider.get("post_setup") and all_configured:
        _run_post_setup(provider["post_setup"])

    if all_configured:
        _print_success(f" {provider['name']} configured!")


-def tools_command(args):
-    """Entry point for `hermes tools`."""
def _configure_simple_requirements(ts_key: str):
    """Simple fallback for toolsets that just need env vars (no provider selection)."""
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return

    missing = [(var, url) for var, url in requirements if not get_env_value(var)]
    if not missing:
        return

    ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
    print()
    print(color(f" {ts_label} requires configuration:", Colors.YELLOW))

    for var, url in missing:
        if url:
            _print_info(f" Get key at: {url}")
        value = _prompt(f" {var}", password=True)
        if value and value.strip():
            save_env_value(var, value.strip())
            _print_success(f" Saved")
        else:
            _print_warning(f" Skipped")


def _reconfigure_tool(config: dict):
    """Let user reconfigure an existing tool's provider or API key."""
    # Build list of configurable tools that are currently set up
    configurable = []
    for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
        cat = TOOL_CATEGORIES.get(ts_key)
        reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
        if cat or reqs:
            if _toolset_has_keys(ts_key):
                configurable.append((ts_key, ts_label))

    if not configurable:
        _print_info("No configured tools to reconfigure.")
        return

    choices = [label for _, label in configurable]
    choices.append("Cancel")

    idx = _prompt_choice(" Which tool would you like to reconfigure?", choices, len(choices) - 1)

    if idx >= len(configurable):
        return  # Cancel

    ts_key, ts_label = configurable[idx]
    cat = TOOL_CATEGORIES.get(ts_key)

    if cat:
        _configure_tool_category_for_reconfig(ts_key, cat, config)
    else:
        _reconfigure_simple_requirements(ts_key)

    save_config(config)


def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
    """Reconfigure a tool category - provider selection + API key update."""
    icon = cat.get("icon", "")
    name = cat["name"]
    providers = cat["providers"]

    if len(providers) == 1:
        provider = providers[0]
        print()
        print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
        _reconfigure_provider(provider, config)
    else:
        print()
        print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
        print()

        provider_choices = []
        for p in providers:
            tag = f" ({p['tag']})" if p.get("tag") else ""
            configured = ""
            env_vars = p.get("env_vars", [])
            if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
                if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                    configured = " [active]"
                elif not env_vars:
                    configured = ""
                else:
                    configured = " [configured]"
            provider_choices.append(f"{p['name']}{tag}{configured}")

        default_idx = 0
        for i, p in enumerate(providers):
            if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                default_idx = i
                break
            env_vars = p.get("env_vars", [])
            if env_vars and all(get_env_value(v["key"]) for v in env_vars):
                default_idx = i
                break

        provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
        _reconfigure_provider(providers[provider_idx], config)


def _reconfigure_provider(provider: dict, config: dict):
    """Reconfigure a provider - update API keys."""
    env_vars = provider.get("env_vars", [])

    if provider.get("tts_provider"):
        config.setdefault("tts", {})["provider"] = provider["tts_provider"]
        _print_success(f" TTS provider set to: {provider['tts_provider']}")

    if not env_vars:
        _print_success(f" {provider['name']} - no configuration needed!")
        return

    for var in env_vars:
        existing = get_env_value(var["key"])
        if existing:
            _print_info(f" {var['key']}: configured ({existing[:8]}...)")
        url = var.get("url", "")
        if url:
            _print_info(f" Get yours at: {url}")
        default_val = var.get("default", "")
        value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
        if value and value.strip():
            save_env_value(var["key"], value.strip())
            _print_success(f" Updated")
        else:
            _print_info(f" Kept current")


def _reconfigure_simple_requirements(ts_key: str):
    """Reconfigure simple env var requirements."""
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return

    ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
    print()
    print(color(f" {ts_label}:", Colors.CYAN))

    for var, url in requirements:
        existing = get_env_value(var)
        if existing:
            _print_info(f" {var}: configured ({existing[:8]}...)")
        if url:
            _print_info(f" Get key at: {url}")
        value = _prompt(f" {var} (Enter to keep current)", password=True)
        if value and value.strip():
            save_env_value(var, value.strip())
            _print_success(f" Updated")
        else:
            _print_info(f" Kept current")


# ─── Main Entry Point ─────────────────────────────────────────────────────────

def tools_command(args=None):
    """Entry point for `hermes tools` and `hermes setup tools`."""
    config = load_config()
    enabled_platforms = _get_enabled_platforms()

    print()
    print(color("⚕ Hermes Tool Configuration", Colors.CYAN, Colors.BOLD))
    print(color(" Enable or disable tools per platform.", Colors.DIM))
    print(color(" Tools that need API keys will be configured when enabled.", Colors.DIM))
    print()

    # Build platform choices
@@ -380,22 +851,28 @@ def tools_command(args):
    platform_keys = []
    for pkey in enabled_platforms:
        pinfo = PLATFORMS[pkey]
        # Count currently enabled toolsets
        current = _get_platform_tools(config, pkey)
        count = len(current)
        total = len(CONFIGURABLE_TOOLSETS)
        platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
        platform_keys.append(pkey)

-   platform_choices.append("Done — save and exit")
+   platform_choices.append("Reconfigure an existing tool's provider or API key")
+   platform_choices.append("Done")

    while True:
-       idx = _prompt_choice("Select a platform to configure:", platform_choices, default=0)
+       idx = _prompt_choice("Select an option:", platform_choices, default=0)

        # "Done" selected
-       if idx == len(platform_keys):
+       if idx == len(platform_keys) + 1:
            break

+       # "Reconfigure" selected
+       if idx == len(platform_keys):
+           _reconfigure_tool(config)
+           print()
+           continue

        pkey = platform_keys[idx]
        pinfo = PLATFORMS[pkey]
@@ -418,11 +895,15 @@ def tools_command(args):
            label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
            print(color(f" - {label}", Colors.RED))

-       # Prompt for missing API keys on newly enabled toolsets
+       # Configure newly enabled toolsets that need API keys
        if added:
-           _check_and_prompt_requirements(added)
+           for ts_key in sorted(added):
+               if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
+                   if not _toolset_has_keys(ts_key):
+                       _configure_toolset(ts_key, config)

        _save_platform_tools(config, pkey, new_enabled)
        save_config(config)
        print(color(f" ✓ Saved {pinfo['label']} configuration", Colors.GREEN))
    else:
        print(color(f" No changes to {pinfo['label']}", Colors.DIM))
hermes_state.py (233 changed lines)
@@ -24,7 +24,7 @@ from typing import Dict, Any, List, Optional

DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"

-SCHEMA_VERSION = 2
+SCHEMA_VERSION = 4

SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version (
@@ -46,6 +46,7 @@ CREATE TABLE IF NOT EXISTS sessions (
    tool_call_count INTEGER DEFAULT 0,
    input_tokens INTEGER DEFAULT 0,
    output_tokens INTEGER DEFAULT 0,
+   title TEXT,
    FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);

@@ -133,7 +134,33 @@ class SessionDB:
            except sqlite3.OperationalError:
                pass  # Column already exists
            cursor.execute("UPDATE schema_version SET version = 2")
        if current_version < 3:
            # v3: add title column to sessions
            try:
                cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
            except sqlite3.OperationalError:
                pass  # Column already exists
            cursor.execute("UPDATE schema_version SET version = 3")
        if current_version < 4:
            # v4: add unique index on title (NULLs allowed, only non-NULL must be unique)
            try:
                cursor.execute(
                    "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
                    "ON sessions(title) WHERE title IS NOT NULL"
                )
            except sqlite3.OperationalError:
                pass  # Index already exists
            cursor.execute("UPDATE schema_version SET version = 4")

        # Unique title index — always ensure it exists (safe to run after migrations
        # since the title column is guaranteed to exist at this point)
        try:
            cursor.execute(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
                "ON sessions(title) WHERE title IS NOT NULL"
            )
        except sqlite3.OperationalError:
            pass  # Index already exists
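        # Illustration of the partial-unique-index semantics above, as a
        # standalone sqlite3 demo (not part of hermes_state.py): NULL titles
        # may repeat freely, while non-NULL titles must be unique.
        #
        #   import sqlite3
        #   conn = sqlite3.connect(":memory:")
        #   conn.execute("CREATE TABLE sessions (id TEXT PRIMARY KEY, title TEXT)")
        #   conn.execute("CREATE UNIQUE INDEX idx_sessions_title_unique "
        #                "ON sessions(title) WHERE title IS NOT NULL")
        #   conn.execute("INSERT INTO sessions VALUES ('a', NULL)")
        #   conn.execute("INSERT INTO sessions VALUES ('b', NULL)")        # OK: NULLs not indexed
        #   conn.execute("INSERT INTO sessions VALUES ('c', 'daily log')")
        #   conn.execute("INSERT INTO sessions VALUES ('d', 'daily log')")
        #   # -> sqlite3.IntegrityError: UNIQUE constraint failed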
        # FTS5 setup (separate because CREATE VIRTUAL TABLE can't be in executescript with IF NOT EXISTS reliably)
        try:
@@ -219,6 +246,210 @@ class SessionDB:
        row = cursor.fetchone()
        return dict(row) if row else None

    # Maximum length for session titles
    MAX_TITLE_LENGTH = 100

    @staticmethod
    def sanitize_title(title: Optional[str]) -> Optional[str]:
        """Validate and sanitize a session title.

        - Strips leading/trailing whitespace
        - Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic
          Unicode control chars (zero-width, RTL/LTR overrides, etc.)
        - Collapses internal whitespace runs to single spaces
        - Normalizes empty/whitespace-only strings to None
        - Enforces MAX_TITLE_LENGTH

        Returns the cleaned title string or None.
        Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning.
        """
        if not title:
            return None

        import re

        # Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
        # whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
        # normalized to spaces by the whitespace collapsing step below
        cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)

        # Remove problematic Unicode control characters:
        # - Zero-width chars (U+200B-U+200F, U+FEFF)
        # - Directional overrides (U+202A-U+202E, U+2066-U+2069)
        # - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
        cleaned = re.sub(
            r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
            '', cleaned,
        )

        # Collapse internal whitespace runs and strip
        cleaned = re.sub(r'\s+', ' ', cleaned).strip()

        if not cleaned:
            return None

        if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
            raise ValueError(
                f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
            )

        return cleaned
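    # Illustrative behavior (hypothetical inputs, following the rules above):
    #   SessionDB.sanitize_title("  my\u200b session \n notes ")  # -> "my session notes"
    #   SessionDB.sanitize_title("\x07\x00")                      # -> None (only control chars)
    #   SessionDB.sanitize_title("x" * 101)                       # -> raises ValueError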
    def set_session_title(self, session_id: str, title: str) -> bool:
        """Set or update a session's title.

        Returns True if session was found and title was set.
        Raises ValueError if title is already in use by another session,
        or if the title fails validation (too long, invalid characters).
        Empty/whitespace-only strings are normalized to None (clearing the title).
        """
        title = self.sanitize_title(title)
        if title:
            # Check uniqueness (allow the same session to keep its own title)
            cursor = self._conn.execute(
                "SELECT id FROM sessions WHERE title = ? AND id != ?",
                (title, session_id),
            )
            conflict = cursor.fetchone()
            if conflict:
                raise ValueError(
                    f"Title '{title}' is already in use by session {conflict['id']}"
                )
        cursor = self._conn.execute(
            "UPDATE sessions SET title = ? WHERE id = ?",
            (title, session_id),
        )
        self._conn.commit()
        return cursor.rowcount > 0

    def get_session_title(self, session_id: str) -> Optional[str]:
        """Get the title for a session, or None."""
        cursor = self._conn.execute(
            "SELECT title FROM sessions WHERE id = ?", (session_id,)
        )
        row = cursor.fetchone()
        return row["title"] if row else None

    def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
        """Look up a session by exact title. Returns session dict or None."""
        cursor = self._conn.execute(
            "SELECT * FROM sessions WHERE title = ?", (title,)
        )
        row = cursor.fetchone()
        return dict(row) if row else None

    def resolve_session_by_title(self, title: str) -> Optional[str]:
        """Resolve a title to a session ID, preferring the latest in a lineage.

        If the exact title exists, returns that session's ID.
        If not, searches for "title #N" variants and returns the latest one.
        If the exact title exists AND numbered variants exist, returns the
        latest numbered variant (the most recent continuation).
        """
        # First try exact match
        exact = self.get_session_by_title(title)

        # Also search for numbered variants: "title #2", "title #3", etc.
        # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
        escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        cursor = self._conn.execute(
            "SELECT id, title, started_at FROM sessions "
            "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
            (f"{escaped} #%",),
        )
        numbered = cursor.fetchall()

        if numbered:
            # Return the most recent numbered variant
            return numbered[0]["id"]
        elif exact:
            return exact["id"]
        return None

    def get_next_title_in_lineage(self, base_title: str) -> str:
        """Generate the next title in a lineage (e.g., "my session" → "my session #2").

        Strips any existing " #N" suffix to find the base name, then finds
        the highest existing number and increments.
        """
        import re
        # Strip existing #N suffix to find the true base
        match = re.match(r'^(.*?) #(\d+)$', base_title)
        if match:
            base = match.group(1)
        else:
            base = base_title

        # Find all existing numbered variants
        # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
        escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        cursor = self._conn.execute(
            "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
            (base, f"{escaped} #%"),
        )
        existing = [row["title"] for row in cursor.fetchall()]

        if not existing:
            return base  # No conflict, use the base name as-is

        # Find the highest number
        max_num = 1  # The unnumbered original counts as #1
        for t in existing:
            m = re.match(r'^.* #(\d+)$', t)
            if m:
                max_num = max(max_num, int(m.group(1)))

        return f"{base} #{max_num + 1}"
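    # Example flow combining the two helpers above (hypothetical session IDs;
    # assumes an open SessionDB instance `db`):
    #   db.set_session_title("sess-1", "standup notes")
    #   db.get_next_title_in_lineage("standup notes")   # -> "standup notes #2"
    #   db.resolve_session_by_title("standup notes")    # -> "sess-1" (no numbered variants yet)
    #   db.set_session_title("sess-2", "standup notes #2")
    #   db.resolve_session_by_title("standup notes")    # -> "sess-2" (latest in the lineage)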
    def list_sessions_rich(
        self,
        source: str = None,
        limit: int = 20,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List sessions with preview (first user message) and last active timestamp.

        Returns dicts with keys: id, source, model, title, started_at, ended_at,
        message_count, preview (first 60 chars of first user message),
        last_active (timestamp of last message).

        Uses a single query with correlated subqueries instead of N+2 queries.
        """
        source_clause = "WHERE s.source = ?" if source else ""
        query = f"""
            SELECT s.*,
                   COALESCE(
                       (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
                        FROM messages m
                        WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
                        ORDER BY m.timestamp, m.id LIMIT 1),
                       ''
                   ) AS _preview_raw,
                   COALESCE(
                       (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
                       s.started_at
                   ) AS last_active
            FROM sessions s
            {source_clause}
            ORDER BY s.started_at DESC
            LIMIT ? OFFSET ?
        """
        params = (source, limit, offset) if source else (limit, offset)
        cursor = self._conn.execute(query, params)
        sessions = []
        for row in cursor.fetchall():
            s = dict(row)
            # Build the preview from the raw substring
            raw = s.pop("_preview_raw", "").strip()
            if raw:
                text = raw[:60]
                s["preview"] = text + ("..." if len(raw) > 60 else "")
            else:
                s["preview"] = ""
            sessions.append(s)

        return sessions
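    # Hypothetical usage showing the returned shape:
    #   for s in db.list_sessions_rich(source="cli", limit=5):
    #       print(s["id"], s["title"], s["preview"], s["last_active"])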
    # =========================================================================
    # Message storage
    # =========================================================================
hermes_time.py (new file, 119 lines)
@@ -0,0 +1,119 @@
"""
Timezone-aware clock for Hermes.

Provides a single ``now()`` helper that returns a timezone-aware datetime
based on the user's configured IANA timezone (e.g. ``Asia/Kolkata``).

Resolution order:
1. ``HERMES_TIMEZONE`` environment variable
2. ``timezone`` key in ``~/.hermes/config.yaml``
3. Falls back to the server's local time (``datetime.now().astimezone()``)

Invalid timezone values log a warning and fall back safely — Hermes never
crashes due to a bad timezone string.
"""

import logging
import os
from datetime import datetime, timezone as _tz
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

try:
    from zoneinfo import ZoneInfo
except ImportError:
    # Python 3.8 fallback (shouldn't be needed — Hermes requires 3.9+)
    from backports.zoneinfo import ZoneInfo  # type: ignore[no-redef]

# Cached state — resolved once, reused on every call.
# Call reset_cache() to force re-resolution (e.g. after config changes).
_cached_tz: Optional[ZoneInfo] = None
_cached_tz_name: Optional[str] = None
_cache_resolved: bool = False


def _resolve_timezone_name() -> str:
    """Read the configured IANA timezone string (or empty string).

    This does file I/O when falling through to config.yaml, so callers
    should cache the result rather than calling on every ``now()``.
    """
    # 1. Environment variable (highest priority — set by Supervisor, etc.)
    tz_env = os.getenv("HERMES_TIMEZONE", "").strip()
    if tz_env:
        return tz_env

    # 2. config.yaml ``timezone`` key
    try:
        import yaml
        hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
        config_path = hermes_home / "config.yaml"
        if config_path.exists():
            with open(config_path) as f:
                cfg = yaml.safe_load(f) or {}
            tz_cfg = cfg.get("timezone", "")
            if isinstance(tz_cfg, str) and tz_cfg.strip():
                return tz_cfg.strip()
    except Exception:
        pass

    return ""


def _get_zoneinfo(name: str) -> Optional[ZoneInfo]:
    """Validate and return a ZoneInfo, or None if invalid."""
    if not name:
        return None
    try:
        return ZoneInfo(name)
    except (KeyError, Exception) as exc:
        logger.warning(
            "Invalid timezone '%s': %s. Falling back to server local time.",
            name, exc,
        )
        return None


def get_timezone() -> Optional[ZoneInfo]:
    """Return the user's configured ZoneInfo, or None (meaning server-local).

    Resolved once and cached. Call ``reset_cache()`` after config changes.
    """
    global _cached_tz, _cached_tz_name, _cache_resolved
    if not _cache_resolved:
        _cached_tz_name = _resolve_timezone_name()
        _cached_tz = _get_zoneinfo(_cached_tz_name)
        _cache_resolved = True
    return _cached_tz


def get_timezone_name() -> str:
    """Return the IANA name of the configured timezone, or empty string."""
    global _cached_tz_name, _cache_resolved
    if not _cache_resolved:
        get_timezone()  # populates cache
    return _cached_tz_name or ""


def now() -> datetime:
    """
    Return the current time as a timezone-aware datetime.

    If a valid timezone is configured, returns wall-clock time in that zone.
    Otherwise returns the server's local time (via ``astimezone()``).
    """
    tz = get_timezone()
    if tz is not None:
        return datetime.now(tz)
    # No timezone configured — use server-local (still tz-aware)
    return datetime.now().astimezone()


def reset_cache() -> None:
    """Clear the cached timezone. Used by tests and after config changes."""
    global _cached_tz, _cached_tz_name, _cache_resolved
    _cached_tz = None
    _cached_tz_name = None
    _cache_resolved = False
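# Usage sketch for hermes_time (standalone demo; "Asia/Kolkata" is just an
# example zone, and reset_cache() forces re-resolution after changing the env):
#
#   import os
#   os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
#   import hermes_time
#   hermes_time.reset_cache()
#   print(hermes_time.now().isoformat())    # tz-aware, e.g. ...+05:30
#   print(hermes_time.get_timezone_name())  # "Asia/Kolkata"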
@@ -149,7 +149,7 @@ class MiniSWERunner:

    def __init__(
        self,
-       model: str = "anthropic/claude-sonnet-4-20250514",
+       model: str = "anthropic/claude-sonnet-4.6",
        base_url: str = None,
        api_key: str = None,
        env_type: str = "local",
@@ -200,13 +200,7 @@ class MiniSWERunner:
        else:
            client_kwargs["base_url"] = "https://openrouter.ai/api/v1"

-       if base_url and "api.anthropic.com" in base_url.strip().lower():
-           raise ValueError(
-               "Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
-               "Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
-               "To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
-               "or any OpenAI-compatible proxy that wraps the Anthropic API."
-           )

        # Handle API key - OpenRouter is the primary provider
        if api_key:
@@ -225,6 +225,18 @@ def get_tool_definitions(
    # Ask the registry for schemas (only returns tools whose check_fn passes)
    filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)

+   # Rebuild execute_code schema to only list sandbox tools that are actually
+   # enabled. Without this, the model sees "web_search is available in
+   # execute_code" even when the user disabled the web toolset (#560-discord).
+   if "execute_code" in tools_to_include:
+       from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
+       sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
+       dynamic_schema = build_execute_code_schema(sandbox_enabled)
+       for i, td in enumerate(filtered_tools):
+           if td.get("function", {}).get("name") == "execute_code":
+               filtered_tools[i] = {"type": "function", "function": dynamic_schema}
+               break

    if not quiet_mode:
        if filtered_tools:
            tool_names = [t["function"]["name"] for t in filtered_tools]
optional-skills/research/qmd/SKILL.md (new file, 441 lines)
@@ -0,0 +1,441 @@
---
name: qmd
description: Search personal knowledge bases, notes, docs, and meeting transcripts locally using qmd — a hybrid retrieval engine with BM25, vector search, and LLM reranking. Supports CLI and MCP integration.
version: 1.0.0
author: Hermes Agent + Teknium
license: MIT
platforms: [macos, linux]
metadata:
  hermes:
    tags: [Search, Knowledge-Base, RAG, Notes, MCP, Local-AI]
    related_skills: [obsidian, native-mcp, arxiv]
---

# QMD — Query Markup Documents

Local, on-device search engine for personal knowledge bases. Indexes markdown
notes, meeting transcripts, documentation, and any text-based files, then
provides hybrid search combining keyword matching, semantic understanding, and
LLM-powered reranking — all running locally with no cloud dependencies.

Created by [Tobi Lütke](https://github.com/tobi/qmd). MIT licensed.

## When to Use

- User asks to search their notes, docs, knowledge base, or meeting transcripts
- User wants to find something across a large collection of markdown/text files
- User wants semantic search ("find notes about X concept") not just keyword grep
- User has already set up qmd collections and wants to query them
- User asks to set up a local knowledge base or document search system
- Keywords: "search my notes", "find in my docs", "knowledge base", "qmd"

## Prerequisites

### Node.js >= 22 (required)

```bash
# Check version
node --version  # must be >= 22

# macOS — install or upgrade via Homebrew
brew install node@22

# Linux — use NodeSource or nvm
curl -fsSL https://deb.nodesource.com/setup_22.x | sudo -E bash -
sudo apt-get install -y nodejs
# or with nvm:
nvm install 22 && nvm use 22
```

### SQLite with Extension Support (macOS only)

macOS system SQLite lacks extension loading. Install via Homebrew:

```bash
brew install sqlite
```

### Install qmd

```bash
npm install -g @tobilu/qmd
# or with Bun:
bun install -g @tobilu/qmd
```

First run auto-downloads 3 local GGUF models (~2GB total):

| Model | Purpose | Size |
|-------|---------|------|
| embeddinggemma-300M-Q8_0 | Vector embeddings | ~300MB |
| qwen3-reranker-0.6b-q8_0 | Result reranking | ~640MB |
| qmd-query-expansion-1.7B | Query expansion | ~1.1GB |

### Verify Installation

```bash
qmd --version
qmd status
```

## Quick Reference

| Command | What It Does | Speed |
|---------|-------------|-------|
| `qmd search "query"` | BM25 keyword search (no models) | ~0.2s |
| `qmd vsearch "query"` | Semantic vector search (1 model) | ~3s |
| `qmd query "query"` | Hybrid + reranking (all 3 models) | ~2-3s warm, ~19s cold |
| `qmd get <docid>` | Retrieve full document content | instant |
| `qmd multi-get "glob"` | Retrieve multiple files | instant |
| `qmd collection add <path> --name <n>` | Add a directory as a collection | instant |
| `qmd context add <path> "description"` | Add context metadata to improve retrieval | instant |
| `qmd embed` | Generate/update vector embeddings | varies |
| `qmd status` | Show index health and collection info | instant |
| `qmd mcp` | Start MCP server (stdio) | persistent |
| `qmd mcp --http --daemon` | Start MCP server (HTTP, warm models) | persistent |
## Setup Workflow

### 1. Add Collections

Point qmd at directories containing your documents:

```bash
# Add a notes directory
qmd collection add ~/notes --name notes

# Add project docs
qmd collection add ~/projects/myproject/docs --name project-docs

# Add meeting transcripts
qmd collection add ~/meetings --name meetings

# List all collections
qmd collection list
```

### 2. Add Context Descriptions

Context metadata helps the search engine understand what each collection
contains. This significantly improves retrieval quality:

```bash
qmd context add qmd://notes "Personal notes, ideas, and journal entries"
qmd context add qmd://project-docs "Technical documentation for the main project"
qmd context add qmd://meetings "Meeting transcripts and action items from team syncs"
```

### 3. Generate Embeddings

```bash
qmd embed
```

This processes all documents in all collections and generates vector
embeddings. Re-run after adding new documents or collections.

### 4. Verify

```bash
qmd status  # shows index health, collection stats, model info
```
## Search Patterns

### Fast Keyword Search (BM25)

Best for: exact terms, code identifiers, names, known phrases.
No models loaded — near-instant results.

```bash
qmd search "authentication middleware"
qmd search "handleError async"
```

### Semantic Vector Search

Best for: natural language questions, conceptual queries.
Loads embedding model (~3s first query).

```bash
qmd vsearch "how does the rate limiter handle burst traffic"
qmd vsearch "ideas for improving onboarding flow"
```

### Hybrid Search with Reranking (Best Quality)

Best for: important queries where quality matters most.
Uses all 3 models — query expansion, parallel BM25+vector, reranking.

```bash
qmd query "what decisions were made about the database migration"
```

### Structured Multi-Mode Queries

Combine different search types in a single query for precision:

```bash
# BM25 for exact term + vector for concept
qmd query $'lex: rate limiter\nvec: how does throttling work under load'

# With query expansion
qmd query $'expand: database migration plan\nlex: "schema change"'
```

### Query Syntax (lex/BM25 mode)

| Syntax | Effect | Example |
|--------|--------|---------|
| `term` | Prefix match | `perf` matches "performance" |
| `"phrase"` | Exact phrase | `"rate limiter"` |
| `-term` | Exclude term | `performance -sports` |

### HyDE (Hypothetical Document Embeddings)

For complex topics, write what you expect the answer to look like:

```bash
qmd query $'hyde: The migration plan involves three phases. First, we add the new columns without dropping the old ones. Then we backfill data. Finally we cut over and remove legacy columns.'
```

### Scoping to Collections

```bash
qmd search "query" --collection notes
qmd query "query" --collection project-docs
```

### Output Formats

```bash
qmd search "query" --json              # JSON output (best for parsing)
qmd search "query" --limit 5           # Limit results
qmd get "#abc123"                      # Get by document ID
qmd get "path/to/file.md"              # Get by file path
qmd get "file.md:50" -l 100            # Get specific line range
qmd multi-get "journals/*.md" --json   # Batch retrieve by glob
```
## MCP Integration (Recommended)

qmd exposes an MCP server that provides search tools directly to
Hermes Agent via the native MCP client. This is the preferred
integration — once configured, the agent gets qmd tools automatically
without needing to load this skill.

### Option A: Stdio Mode (Simple)

Add to `~/.hermes/config.yaml`:

```yaml
mcp_servers:
  qmd:
    command: "qmd"
    args: ["mcp"]
    timeout: 30
    connect_timeout: 45
```

This registers tools: `mcp_qmd_search`, `mcp_qmd_vsearch`,
`mcp_qmd_deep_search`, `mcp_qmd_get`, `mcp_qmd_status`.

**Tradeoff:** Models load on first search call (~19s cold start),
then stay warm for the session. Acceptable for occasional use.

### Option B: HTTP Daemon Mode (Fast, Recommended for Heavy Use)

Start the qmd daemon separately — it keeps models warm in memory:

```bash
# Start daemon (persists across agent restarts)
qmd mcp --http --daemon

# Runs on http://localhost:8181 by default
```

Then configure Hermes Agent to connect via HTTP:

```yaml
mcp_servers:
  qmd:
    url: "http://localhost:8181/mcp"
    timeout: 30
```

**Tradeoff:** Uses ~2GB RAM while running, but every query is fast
(~2-3s). Best for users who search frequently.

### Keeping the Daemon Running

#### macOS (launchd)

```bash
cat > ~/Library/LaunchAgents/com.qmd.daemon.plist << 'EOF'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
  "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>com.qmd.daemon</string>
    <key>ProgramArguments</key>
    <array>
        <string>qmd</string>
        <string>mcp</string>
        <string>--http</string>
        <string>--daemon</string>
    </array>
    <key>RunAtLoad</key>
    <true/>
    <key>KeepAlive</key>
    <true/>
    <key>StandardOutPath</key>
    <string>/tmp/qmd-daemon.log</string>
    <key>StandardErrorPath</key>
    <string>/tmp/qmd-daemon.log</string>
</dict>
</plist>
EOF

launchctl load ~/Library/LaunchAgents/com.qmd.daemon.plist
```

#### Linux (systemd user service)

```bash
mkdir -p ~/.config/systemd/user

cat > ~/.config/systemd/user/qmd-daemon.service << 'EOF'
[Unit]
Description=QMD MCP Daemon
After=network.target

[Service]
ExecStart=qmd mcp --http --daemon
Restart=on-failure
RestartSec=10
Environment=PATH=/usr/local/bin:/usr/bin:/bin

[Install]
WantedBy=default.target
EOF

systemctl --user daemon-reload
systemctl --user enable --now qmd-daemon
systemctl --user status qmd-daemon
```

### MCP Tools Reference

Once connected, these tools are available as `mcp_qmd_*`:

| MCP Tool | Maps To | Description |
|----------|---------|-------------|
| `mcp_qmd_search` | `qmd search` | BM25 keyword search |
| `mcp_qmd_vsearch` | `qmd vsearch` | Semantic vector search |
| `mcp_qmd_deep_search` | `qmd query` | Hybrid search + reranking |
| `mcp_qmd_get` | `qmd get` | Retrieve document by ID or path |
| `mcp_qmd_status` | `qmd status` | Index health and stats |

The MCP tools accept structured JSON queries for multi-mode search:

```json
{
  "searches": [
    {"type": "lex", "query": "authentication middleware"},
    {"type": "vec", "query": "how user login is verified"}
  ],
  "collections": ["project-docs"],
  "limit": 10
}
```

## CLI Usage (Without MCP)

When MCP is not configured, use qmd directly via terminal:

```
terminal(command="qmd query 'what was decided about the API redesign' --json", timeout=30)
```

For setup and management tasks, always use terminal:

```
terminal(command="qmd collection add ~/Documents/notes --name notes")
terminal(command="qmd context add qmd://notes 'Personal research notes and ideas'")
terminal(command="qmd embed")
terminal(command="qmd status")
```
## How the Search Pipeline Works

Understanding the internals helps choose the right search mode:

1. **Query Expansion** — A fine-tuned 1.7B model generates 2 alternative
   queries. The original gets 2x weight in fusion.
2. **Parallel Retrieval** — BM25 (SQLite FTS5) and vector search run
   simultaneously across all query variants.
3. **RRF Fusion** — Reciprocal Rank Fusion (k=60) merges results.
   Top-rank bonus: #1 gets +0.05, #2-3 get +0.02 (see the sketch after this list).
4. **LLM Reranking** — qwen3-reranker scores top 30 candidates (0.0-1.0).
5. **Position-Aware Blending** — Ranks 1-3: 75% retrieval / 25% reranker.
   Ranks 4-10: 60/40. Ranks 11+: 40/60 (trusts reranker more for long tail).
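A minimal sketch of step 3, assuming plain RRF over ranked doc-id lists plus
the stated top-rank bonus (illustrative only, not qmd's actual source):

```python
def rrf_fuse(rankings, k=60):
    """Merge ranked doc-id lists with Reciprocal Rank Fusion (k=60)."""
    scores = {}
    for ranking in rankings:  # one list per retriever / query variant
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    fused = sorted(scores, key=scores.get, reverse=True)
    bonus = {0: 0.05, 1: 0.02, 2: 0.02}  # top-rank bonus: #1 +0.05, #2-3 +0.02
    return [(d, scores[d] + bonus.get(i, 0.0)) for i, d in enumerate(fused)]

print(rrf_fuse([["a", "b", "c"], ["b", "c"]]))  # -> b first, then c, then a
```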
**Smart Chunking:** Documents are split at natural break points (headings,
code blocks, blank lines) targeting ~900 tokens with 15% overlap. Code
blocks are never split mid-block.
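A toy version of that policy, using paragraphs as the break points and words
as a stand-in for tokens (illustrative only):

```python
def chunk(text, target=900, overlap=0.15):
    """Greedy paragraph packer: ~`target` words per chunk, ~15% overlap."""
    pieces = [p for p in text.split("\n\n") if p.strip()]  # natural break points
    chunks, cur, fresh = [], [], 0
    for p in pieces:
        cur.append(p)
        fresh += 1
        if sum(len(x.split()) for x in cur) >= target:
            chunks.append("\n\n".join(cur))
            cur = cur[-max(1, int(len(cur) * overlap)):]   # carry ~15% as overlap
            fresh = 0
    if fresh:                                              # flush unemitted tail
        chunks.append("\n\n".join(cur))
    return chunks
```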
## Best Practices

1. **Always add context descriptions** — `qmd context add` dramatically
   improves retrieval accuracy. Describe what each collection contains.
2. **Re-embed after adding documents** — `qmd embed` must be re-run when
   new files are added to collections.
3. **Use `qmd search` for speed** — when you need fast keyword lookup
   (code identifiers, exact names), BM25 is instant and needs no models.
4. **Use `qmd query` for quality** — when the question is conceptual or
   the user needs the best possible results, use hybrid search.
5. **Prefer MCP integration** — once configured, the agent gets native
   tools without needing to load this skill each time.
6. **Daemon mode for frequent users** — if the user searches their
   knowledge base regularly, recommend the HTTP daemon setup.
7. **First query in structured search gets 2x weight** — put the most
   important/certain query first when combining lex and vec.

## Troubleshooting

### "Models downloading on first run"
Normal — qmd auto-downloads ~2GB of GGUF models on first use.
This is a one-time operation.

### Cold start latency (~19s)
This happens when models aren't loaded in memory. Solutions:
- Use HTTP daemon mode (`qmd mcp --http --daemon`) to keep warm
- Use `qmd search` (BM25 only) when models aren't needed
- MCP stdio mode loads models on first search, stays warm for session

### macOS: "unable to load extension"
Install Homebrew SQLite: `brew install sqlite`
Then ensure it's on PATH before system SQLite.

### "No collections found"
Run `qmd collection add <path> --name <name>` to add directories,
then `qmd embed` to index them.

### Embedding model override (CJK/multilingual)
Set `QMD_EMBED_MODEL` environment variable for non-English content:
```bash
export QMD_EMBED_MODEL="your-multilingual-model"
```

## Data Storage

- **Index & vectors:** `~/.cache/qmd/index.sqlite`
- **Models:** Auto-downloaded to local cache on first run
- **No cloud dependencies** — everything runs locally

## References

- [GitHub: tobi/qmd](https://github.com/tobi/qmd)
- [QMD Changelog](https://github.com/tobi/qmd/blob/main/CHANGELOG.md)
@@ -50,6 +50,7 @@ pty = ["ptyprocess>=0.7.0"]
honcho = ["honcho-ai>=2.0.1"]
mcp = ["mcp>=1.2.0"]
homeassistant = ["aiohttp>=3.9.0"]
+yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git"]
all = [
    "hermes-agent[modal]",
    "hermes-agent[daytona]",
run_agent.py (248 changed lines)
@@ -99,6 +99,46 @@ from agent.trajectory import (
)


class IterationBudget:
    """Thread-safe shared iteration counter for parent and child agents.

    Tracks total LLM-call iterations consumed across a parent agent and all
    its subagents. A single ``IterationBudget`` is created by the parent
    and passed to every child so they share the same cap.

    ``execute_code`` (programmatic tool calling) iterations are refunded via
    :meth:`refund` so they don't eat into the budget.
    """

    def __init__(self, max_total: int):
        self.max_total = max_total
        self._used = 0
        self._lock = threading.Lock()

    def consume(self) -> bool:
        """Try to consume one iteration. Returns True if allowed."""
        with self._lock:
            if self._used >= self.max_total:
                return False
            self._used += 1
            return True

    def refund(self) -> None:
        """Give back one iteration (e.g. for execute_code turns)."""
        with self._lock:
            if self._used > 0:
                self._used -= 1

    @property
    def used(self) -> int:
        return self._used

    @property
    def remaining(self) -> int:
        with self._lock:
            return max(0, self.max_total - self._used)
class AIAgent:
|
||||
"""
|
||||
AI Agent with tool calling capabilities.
|
||||
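
A minimal usage sketch of the budget contract above (hypothetical driver code, not part of the diff; only `IterationBudget` itself comes from it):

```python
budget = IterationBudget(max_total=90)  # parent creates the shared cap

if budget.consume():                    # called before every LLM turn,
    pass                                # by the parent and by each subagent

budget.refund()                         # e.g. after an execute_code-only turn
print(budget.used, budget.remaining)    # -> 0 90 (the refund gave it back)
```
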
@@ -114,7 +154,7 @@ class AIAgent:
         provider: str = None,
         api_mode: str = None,
         model: str = "anthropic/claude-opus-4.6",  # OpenRouter format
-        max_iterations: int = 60,  # Default tool-calling iterations
+        max_iterations: int = 90,  # Default tool-calling iterations (shared with subagents)
         tool_delay: float = 1.0,
         enabled_toolsets: List[str] = None,
         disabled_toolsets: List[str] = None,
@@ -142,6 +182,7 @@ class AIAgent:
         skip_memory: bool = False,
         session_db=None,
         honcho_session_key: str = None,
+        iteration_budget: "IterationBudget" = None,
     ):
         """
         Initialize the AI Agent.
 
@@ -152,7 +193,7 @@ class AIAgent:
             provider (str): Provider identifier (optional; used for telemetry/routing hints)
             api_mode (str): API mode override: "chat_completions" or "codex_responses"
             model (str): Model name to use (default: "anthropic/claude-opus-4.6")
-            max_iterations (int): Maximum number of tool calling iterations (default: 60)
+            max_iterations (int): Maximum number of tool calling iterations (default: 90)
             tool_delay (float): Delay between tool calls in seconds (default: 1.0)
             enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
             disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
@@ -172,7 +213,7 @@ class AIAgent:
                 Provided by the platform layer (CLI or gateway). If None, the clarify tool returns an error.
             max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
             reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
-                If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
+                If None, defaults to {"enabled": True, "effort": "medium"} for OpenRouter. Set to disable/customize reasoning.
             prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
                 Useful for injecting a few-shot example or priming the model's response style.
                 Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
@@ -186,6 +227,9 @@ class AIAgent:
         """
         self.model = model
         self.max_iterations = max_iterations
+        # Shared iteration budget — parent creates, children inherit.
+        # Consumed by every LLM turn across parent + all subagents.
+        self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
         self.tool_delay = tool_delay
         self.save_trajectories = save_trajectories
         self.verbose_logging = verbose_logging
@@ -209,13 +253,7 @@ class AIAgent:
             self.provider = "openai-codex"
         else:
             self.api_mode = "chat_completions"
-            if base_url and "api.anthropic.com" in base_url.strip().lower():
-                raise ValueError(
-                    "Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
-                    "Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
-                    "To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
-                    "or any OpenAI-compatible proxy that wraps the Anthropic API."
-                )
 
         self.tool_progress_callback = tool_progress_callback
         self.clarify_callback = clarify_callback
         self.step_callback = step_callback
@@ -243,7 +281,7 @@ class AIAgent:
 
         # Model response configuration
         self.max_tokens = max_tokens  # None = use model default
-        self.reasoning_config = reasoning_config  # None = use default (xhigh for OpenRouter)
+        self.reasoning_config = reasoning_config  # None = use default (medium for OpenRouter)
         self.prefill_messages = prefill_messages or []  # Prefilled conversation turns
 
         # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
@@ -345,6 +383,12 @@ class AIAgent:
                 "X-OpenRouter-Title": "Hermes Agent",
                 "X-OpenRouter-Categories": "productivity,cli-agent",
             }
+        elif "api.kimi.com" in effective_base.lower():
+            # Kimi Code API requires a recognized coding-agent User-Agent
+            # (see https://github.com/MoonshotAI/kimi-cli)
+            client_kwargs["default_headers"] = {
+                "User-Agent": "KimiCLI/1.0",
+            }
 
         self._client_kwargs = client_kwargs  # stored for rebuilding after interrupt
         try:
@@ -1363,10 +1407,12 @@ class AIAgent:
         if context_files_prompt:
             prompt_parts.append(context_files_prompt)
 
-        now = datetime.now()
-        prompt_parts.append(
-            f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
-        )
+        from hermes_time import now as _hermes_now
+        now = _hermes_now()
+        timestamp_line = f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
+        if self.session_id and os.getenv("HERMES_PASS_SESSION_ID"):
+            timestamp_line += f"\nSession ID: {self.session_id}"
+        prompt_parts.append(timestamp_line)
 
         platform_key = (self.platform or "").lower().strip()
         if platform_key in PLATFORM_HINTS:
@@ -2018,6 +2064,49 @@ class AIAgent:
 
         return True
 
+    def _try_refresh_nous_client_credentials(self, *, force: bool = True) -> bool:
+        if self.api_mode != "chat_completions" or self.provider != "nous":
+            return False
+
+        try:
+            from hermes_cli.auth import resolve_nous_runtime_credentials
+
+            creds = resolve_nous_runtime_credentials(
+                min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
+                timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
+                force_mint=force,
+            )
+        except Exception as exc:
+            logger.debug("Nous credential refresh failed: %s", exc)
+            return False
+
+        api_key = creds.get("api_key")
+        base_url = creds.get("base_url")
+        if not isinstance(api_key, str) or not api_key.strip():
+            return False
+        if not isinstance(base_url, str) or not base_url.strip():
+            return False
+
+        self.api_key = api_key.strip()
+        self.base_url = base_url.strip().rstrip("/")
+        self._client_kwargs["api_key"] = self.api_key
+        self._client_kwargs["base_url"] = self.base_url
+        # Nous requests should not inherit OpenRouter-only attribution headers.
+        self._client_kwargs.pop("default_headers", None)
+
+        try:
+            self.client.close()
+        except Exception:
+            pass
+
+        try:
+            self.client = OpenAI(**self._client_kwargs)
+        except Exception as exc:
+            logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
+            return False
+
+        return True
+
     def _interruptible_api_call(self, api_kwargs: dict):
         """
         Run the API call in a background thread so the main conversation loop
@@ -2069,8 +2158,8 @@ class AIAgent:
         if not instructions:
             instructions = DEFAULT_AGENT_IDENTITY
 
-        # Resolve reasoning effort: config > default (xhigh)
-        reasoning_effort = "xhigh"
+        # Resolve reasoning effort: config > default (medium)
+        reasoning_effort = "medium"
         reasoning_enabled = True
         if self.reasoning_config and isinstance(self.reasoning_config, dict):
             if self.reasoning_config.get("enabled") is False:
@@ -2136,7 +2225,7 @@ class AIAgent:
         else:
             extra_body["reasoning"] = {
                 "enabled": True,
-                "effort": "xhigh"
+                "effort": "medium"
             }
 
         # Nous Portal product attribution
@@ -2396,6 +2485,8 @@ class AIAgent:
 
         if self._session_db:
             try:
+                # Propagate title to the new session with auto-numbering
+                old_title = self._session_db.get_session_title(self.session_id)
                 self._session_db.end_session(self.session_id, "compression")
                 old_session_id = self.session_id
                 self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
@@ -2405,6 +2496,13 @@ class AIAgent:
                     model=self.model,
                     parent_session_id=old_session_id,
                 )
+                # Auto-number the title for the continuation session
+                if old_title:
+                    try:
+                        new_title = self._session_db.get_next_title_in_lineage(old_title)
+                        self._session_db.set_session_title(self.session_id, new_title)
+                    except (ValueError, Exception) as e:
+                        logger.debug("Could not propagate title on compression: %s", e)
                 self._session_db.update_system_prompt(self.session_id, new_system_prompt)
             except Exception as e:
                 logger.debug("Session DB compression split failed: %s", e)
@@ -2531,7 +2629,6 @@ class AIAgent:
                         context=function_args.get("context"),
                         toolsets=function_args.get("toolsets"),
                         tasks=tasks_arg,
                         model=function_args.get("model"),
                         max_iterations=function_args.get("max_iterations"),
                         parent_agent=self,
                     )
@@ -2680,7 +2777,7 @@ class AIAgent:
         else:
             summary_extra_body["reasoning"] = {
                 "enabled": True,
-                "effort": "xhigh"
+                "effort": "medium"
             }
         if _is_nous:
             summary_extra_body["tags"] = ["product=hermes-agent"]
@@ -2792,13 +2889,15 @@ class AIAgent:
         # Generate unique task_id if not provided to isolate VMs between concurrent tasks
         effective_task_id = task_id or str(uuid.uuid4())
 
-        # Reset retry counters at the start of each conversation to prevent state leakage
+        # Reset retry counters and iteration budget at the start of each turn
+        # so subagent usage from a previous turn doesn't eat into the next one.
         self._invalid_tool_retries = 0
         self._invalid_json_retries = 0
         self._empty_content_retries = 0
         self._last_content_with_tools = None
         self._turns_since_memory = 0
         self._iters_since_skill = 0
+        self.iteration_budget = IterationBudget(self.max_iterations)
 
         # Initialize conversation (copy to avoid mutating the caller's list)
         messages = list(conversation_history) if conversation_history else []
@@ -2930,7 +3029,7 @@ class AIAgent:
         # Clear any stale interrupt state at start
         self.clear_interrupt()
 
-        while api_call_count < self.max_iterations:
+        while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
             # Check for interrupt request (e.g., user sent new message)
             if self._interrupt_requested:
                 interrupted = True
@@ -2939,6 +3038,10 @@ class AIAgent:
                 break
 
             api_call_count += 1
+            if not self.iteration_budget.consume():
+                if not self.quiet_mode:
+                    print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)")
+                break
 
             # Fire step_callback for gateway hooks (agent:step event)
             if self.step_callback is not None:
@@ -3015,6 +3118,13 @@ class AIAgent:
             if self._use_prompt_caching:
                 api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
 
+            # Safety net: strip orphaned tool results / add stubs for missing
+            # results before sending to the API. The compressor handles this
+            # during compression, but orphans can also sneak in from session
+            # loading or manual message manipulation.
+            if hasattr(self, 'context_compressor') and self.context_compressor:
+                api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
+
             # Calculate approximate request size for logging
             total_chars = sum(len(str(msg)) for msg in api_messages)
             approx_tokens = total_chars // 4  # Rough estimate: 4 chars per token
@@ -3043,9 +3153,13 @@ class AIAgent:
             api_start_time = time.time()
             retry_count = 0
             max_retries = 6  # Increased to allow longer backoff periods
+            compression_attempts = 0
+            max_compression_attempts = 3
             codex_auth_retry_attempted = False
+            nous_auth_retry_attempted = False
 
             finish_reason = "stop"
+            response = None  # Guard against UnboundLocalError if all retries fail
 
             while retry_count < max_retries:
                 try:
@@ -3293,6 +3407,16 @@ class AIAgent:
                         if self._try_refresh_codex_client_credentials(force=True):
                             print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
                             continue
+                    if (
+                        self.api_mode == "chat_completions"
+                        and self.provider == "nous"
+                        and status_code == 401
+                        and not nous_auth_retry_attempted
+                    ):
+                        nous_auth_retry_attempted = True
+                        if self._try_refresh_nous_client_credentials(force=True):
+                            print(f"{self.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
+                            continue
 
                     retry_count += 1
                     elapsed_time = time.time() - api_start_time
@@ -3331,7 +3455,19 @@ class AIAgent:
                     )
 
                     if is_payload_too_large:
-                        print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
+                        compression_attempts += 1
+                        if compression_attempts > max_compression_attempts:
+                            print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.")
+                            logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.")
+                            self._persist_session(messages, conversation_history)
+                            return {
+                                "messages": messages,
+                                "completed": False,
+                                "api_calls": api_call_count,
+                                "error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
+                                "partial": True
+                            }
+                        print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
 
                         original_len = len(messages)
                         messages, active_system_prompt = self._compress_context(
@@ -3340,6 +3476,7 @@ class AIAgent:
 
                         if len(messages) < original_len:
                             print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
+                            time.sleep(2)  # Brief pause between compression retries
                             continue  # Retry with compressed messages
                         else:
                             print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
@@ -3385,6 +3522,20 @@ class AIAgent:
                         else:
                             print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
 
+                        compression_attempts += 1
+                        if compression_attempts > max_compression_attempts:
+                            print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.")
+                            logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
+                            self._persist_session(messages, conversation_history)
+                            return {
+                                "messages": messages,
+                                "completed": False,
+                                "api_calls": api_call_count,
+                                "error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
+                                "partial": True
+                            }
+                        print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")
 
                         original_len = len(messages)
                         messages, active_system_prompt = self._compress_context(
                             messages, system_message, approx_tokens=approx_tokens
@@ -3393,6 +3544,7 @@ class AIAgent:
                         if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
                             if len(messages) < original_len:
                                 print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
+                            time.sleep(2)  # Brief pause between compression retries
                             continue  # Retry with compressed messages or new tier
                         else:
                             # Can't compress further and already at minimum tier
@@ -3471,6 +3623,14 @@ class AIAgent:
             if interrupted:
                 break
 
+            # Guard: if all retries exhausted without a successful response
+            # (e.g. repeated context-length errors that exhausted retry_count),
+            # the `response` variable is still None. Break out cleanly.
+            if response is None:
+                print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
+                self._persist_session(messages, conversation_history)
+                break
+
             try:
                 if self.api_mode == "codex_responses":
                     assistant_message, finish_reason = self._normalize_codex_response(response)
@@ -3687,6 +3847,13 @@ class AIAgent:
                     self._log_msg_to_db(assistant_msg)
 
                 self._execute_tool_calls(assistant_message, messages, effective_task_id)
+
+                # Refund the iteration if the ONLY tool(s) called were
+                # execute_code (programmatic tool calling). These are
+                # cheap RPC-style calls that shouldn't eat the budget.
+                _tc_names = {tc.function.name for tc in assistant_message.tool_calls}
+                if _tc_names == {"execute_code"}:
+                    self.iteration_budget.refund()
 
                 if self.compression_enabled and self.context_compressor.should_compress():
                     messages, active_system_prompt = self._compress_context(
@@ -3707,13 +3874,33 @@ class AIAgent:
 
                 # Check if response only has think block with no actual content after it
                 if not self._has_content_after_think_block(final_response):
-                    # Track retries for empty-after-think responses
+                    # If the previous turn already delivered real content alongside
+                    # tool calls (e.g. "You're welcome!" + memory save), the model
+                    # has nothing more to say. Use the earlier content immediately
+                    # instead of wasting API calls on retries that won't help.
+                    fallback = getattr(self, '_last_content_with_tools', None)
+                    if fallback:
+                        logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
+                        self._last_content_with_tools = None
+                        self._empty_content_retries = 0
+                        for i in range(len(messages) - 1, -1, -1):
+                            msg = messages[i]
+                            if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                                tool_names = []
+                                for tc in msg["tool_calls"]:
+                                    fn = tc.get("function", {})
+                                    tool_names.append(fn.get("name", "unknown"))
+                                msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
+                                break
+                        final_response = self._strip_think_blocks(fallback).strip()
+                        break
 
+                    # No fallback available — this is a genuine empty response.
+                    # Retry in case the model just had a bad generation.
                     if not hasattr(self, '_empty_content_retries'):
                         self._empty_content_retries = 0
                     self._empty_content_retries += 1
 
                     # Show the reasoning/thinking content so the user can see
                     # what the model was thinking even though content is empty
                     reasoning_text = self._extract_reasoning(assistant_message)
                     print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
                     if reasoning_text:
@@ -3869,7 +4056,12 @@ class AIAgent:
                 final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
                 break
 
-        if api_call_count >= self.max_iterations and final_response is None:
+        if final_response is None and (
+            api_call_count >= self.max_iterations
+            or self.iteration_budget.remaining <= 0
+        ):
+            if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
+                print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)")
             final_response = self._handle_max_iterations(messages, api_call_count)
 
         # Determine if conversation completed successfully
@@ -3940,7 +4132,7 @@ def main(
 
     Args:
         query (str): Natural language query for the agent. Defaults to Python 3.13 example.
-        model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4-20250514.
+        model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4.6.
        api_key (str): API key for authentication. Uses OPENROUTER_API_KEY env var if not provided.
        base_url (str): Base URL for the model API. Defaults to https://openrouter.ai/api/v1
        max_turns (int): Maximum number of API call iterations. Defaults to 10.
@@ -829,6 +829,33 @@ install_node_deps() {
         log_warn "npm install failed (browser tools may not work)"
     }
     log_success "Node.js dependencies installed"
+
+    # Install Playwright browser + system dependencies.
+    # Playwright's install-deps only supports apt/dnf/zypper natively.
+    # For Arch/Manjaro we install the system libs via pacman first.
+    log_info "Installing browser engine (Playwright Chromium)..."
+    case "$DISTRO" in
+        arch|manjaro)
+            if command -v pacman &> /dev/null; then
+                log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
+                if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
+                    sudo NEEDRESTART_MODE=a pacman -S --noconfirm --needed \
+                        nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
+                elif [ "$(id -u)" -eq 0 ]; then
+                    pacman -S --noconfirm --needed \
+                        nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
+                else
+                    log_warn "Cannot install browser deps without sudo. Run manually:"
+                    log_warn "  sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
+                fi
+            fi
+            cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
+            ;;
+        *)
+            cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
+            ;;
+    esac
+    log_success "Browser engine installed"
 fi
 
 # Install WhatsApp bridge dependencies

skills/apple/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Apple/macOS-specific skills — iMessage, Reminders, Notes, FindMy, and macOS automation. These skills only load on macOS systems.
---

skills/apple/apple-notes/SKILL.md (new file)
@@ -0,0 +1,88 @@
---
name: apple-notes
description: Manage Apple Notes via the memo CLI on macOS (create, view, search, edit).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [Notes, Apple, macOS, note-taking]
    related_skills: [obsidian]
---

# Apple Notes

Use `memo` to manage Apple Notes directly from the terminal. Notes sync across all Apple devices via iCloud.

## Prerequisites

- **macOS** with Notes.app
- Install: `brew tap antoniorodr/memo && brew install antoniorodr/memo/memo`
- Grant Automation access to Notes.app when prompted (System Settings → Privacy → Automation)

## When to Use

- User asks to create, view, or search Apple Notes
- Saving information to Notes.app for cross-device access
- Organizing notes into folders
- Exporting notes to Markdown/HTML

## When NOT to Use

- Obsidian vault management → use the `obsidian` skill
- Bear Notes → separate app (not supported here)
- Quick agent-only notes → use the `memory` tool instead

## Quick Reference

### View Notes

```bash
memo notes                   # List all notes
memo notes -f "Folder Name"  # Filter by folder
memo notes -s "query"        # Search notes (fuzzy)
```

### Create Notes

```bash
memo notes -a                # Interactive editor
memo notes -a "Note Title"   # Quick add with title
```

### Edit Notes

```bash
memo notes -e                # Interactive selection to edit
```

### Delete Notes

```bash
memo notes -d                # Interactive selection to delete
```

### Move Notes

```bash
memo notes -m                # Move note to folder (interactive)
```

### Export Notes

```bash
memo notes -ex               # Export to HTML/Markdown
```

## Limitations

- Cannot edit notes containing images or attachments
- Interactive prompts require terminal access (use pty=true if needed)
- macOS only — requires Apple Notes.app

## Rules

1. Prefer Apple Notes when the user wants cross-device sync (iPhone/iPad/Mac)
2. Use the `memory` tool for agent-internal notes that don't need to sync
3. Use the `obsidian` skill for Markdown-native knowledge management

skills/apple/apple-reminders/SKILL.md (new file)
@@ -0,0 +1,96 @@
---
name: apple-reminders
description: Manage Apple Reminders via remindctl CLI (list, add, complete, delete).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [Reminders, tasks, todo, macOS, Apple]
---

# Apple Reminders

Use `remindctl` to manage Apple Reminders directly from the terminal. Tasks sync across all Apple devices via iCloud.

## Prerequisites

- **macOS** with Reminders.app
- Install: `brew install steipete/tap/remindctl`
- Grant Reminders permission when prompted
- Check: `remindctl status` / Request: `remindctl authorize`

## When to Use

- User mentions "reminder" or "Reminders app"
- Creating personal to-dos with due dates that sync to iOS
- Managing Apple Reminders lists
- User wants tasks to appear on their iPhone/iPad

## When NOT to Use

- Scheduling agent alerts → use the cronjob tool instead
- Calendar events → use Apple Calendar or Google Calendar
- Project task management → use GitHub Issues, Notion, etc.
- If the user says "remind me" but means an agent alert → clarify first

## Quick Reference

### View Reminders

```bash
remindctl                 # Today's reminders
remindctl today           # Today
remindctl tomorrow        # Tomorrow
remindctl week            # This week
remindctl overdue         # Past due
remindctl all             # Everything
remindctl 2026-01-04      # Specific date
```

### Manage Lists

```bash
remindctl list                    # List all lists
remindctl list Work               # Show specific list
remindctl list Projects --create  # Create list
remindctl list Work --delete      # Delete list
```

### Create Reminders

```bash
remindctl add "Buy milk"
remindctl add --title "Call mom" --list Personal --due tomorrow
remindctl add --title "Meeting prep" --due "2026-02-15 09:00"
```

### Complete / Delete

```bash
remindctl complete 1 2 3       # Complete by ID
remindctl delete 4A83 --force  # Delete by ID
```

### Output Formats

```bash
remindctl today --json    # JSON for scripting
remindctl today --plain   # TSV format
remindctl today --quiet   # Counts only
```

## Date Formats

Accepted by `--due` and date filters:

- `today`, `tomorrow`, `yesterday`
- `YYYY-MM-DD`
- `YYYY-MM-DD HH:mm`
- ISO 8601 (`2026-01-04T12:34:56Z`)

## Rules

1. When the user says "remind me", clarify: Apple Reminders (syncs to phone) vs an agent cronjob alert
2. Always confirm reminder content and due date before creating
3. Use `--json` for programmatic parsing

skills/apple/findmy/SKILL.md (new file)
@@ -0,0 +1,131 @@
---
name: findmy
description: Track Apple devices and AirTags via FindMy.app on macOS using AppleScript and screen capture.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [FindMy, AirTag, location, tracking, macOS, Apple]
---

# Find My (Apple)

Track Apple devices and AirTags via the FindMy.app on macOS. Since Apple doesn't provide a CLI for FindMy, this skill uses AppleScript to open the app and screen capture to read device locations.

## Prerequisites

- **macOS** with the Find My app and iCloud signed in
- Devices/AirTags already registered in Find My
- Screen Recording permission for the terminal (System Settings → Privacy → Screen Recording)
- **Optional but recommended**: Install `peekaboo` for better UI automation:
  `brew install steipete/tap/peekaboo`

## When to Use

- User asks "where is my [device/cat/keys/bag]?"
- Tracking AirTag locations
- Checking device locations (iPhone, iPad, Mac, AirPods)
- Monitoring pet or item movement over time (AirTag patrol routes)

## Method 1: AppleScript + Screenshot (Basic)

### Open FindMy and Navigate

```bash
# Open Find My app
osascript -e 'tell application "FindMy" to activate'

# Wait for it to load
sleep 3

# Take a screenshot of the Find My window
screencapture -w -o /tmp/findmy.png
```

Then use `vision_analyze` to read the screenshot:

```
vision_analyze(image_url="/tmp/findmy.png", question="What devices/items are shown and what are their locations?")
```

### Switch Between Tabs

```bash
# Switch to Devices tab
osascript -e '
tell application "System Events"
    tell process "FindMy"
        click button "Devices" of toolbar 1 of window 1
    end tell
end tell'

# Switch to Items tab (AirTags)
osascript -e '
tell application "System Events"
    tell process "FindMy"
        click button "Items" of toolbar 1 of window 1
    end tell
end tell'
```

## Method 2: Peekaboo UI Automation (Recommended)

If `peekaboo` is installed, use it for more reliable UI interaction:

```bash
# Open Find My
osascript -e 'tell application "FindMy" to activate'
sleep 3

# Capture and annotate the UI
peekaboo see --app "FindMy" --annotate --path /tmp/findmy-ui.png

# Click on a specific device/item by element ID
peekaboo click --on B3 --app "FindMy"

# Capture the detail view
peekaboo image --app "FindMy" --path /tmp/findmy-detail.png
```

Then analyze with vision:

```
vision_analyze(image_url="/tmp/findmy-detail.png", question="What is the location shown for this device/item? Include address and coordinates if visible.")
```

## Workflow: Track AirTag Location Over Time

For monitoring an AirTag (e.g., tracking a cat's patrol route):

```bash
# 1. Open FindMy to the Items tab
osascript -e 'tell application "FindMy" to activate'
sleep 3

# 2. Click on the AirTag item (stay on the page — the AirTag only updates while the page is open)

# 3. Periodically capture the location
while true; do
    screencapture -w -o /tmp/findmy-$(date +%H%M%S).png
    sleep 300  # Every 5 minutes
done
```

Analyze each screenshot with vision to extract coordinates, then compile a route.

## Limitations

- FindMy has **no CLI or API** — must use UI automation
- AirTags only update location while the FindMy page is actively displayed
- Location accuracy depends on nearby Apple devices in the FindMy network
- Screen Recording permission is required for screenshots
- AppleScript UI automation may break across macOS versions

## Rules

1. Keep the FindMy app in the foreground when tracking AirTags (updates stop when minimized)
2. Use `vision_analyze` to read screenshot content — don't try to parse pixels
3. For ongoing tracking, use a cronjob to periodically capture and log locations
4. Respect privacy — only track devices/items the user owns

skills/apple/imessage/SKILL.md (new file)
@@ -0,0 +1,100 @@
---
name: imessage
description: Send and receive iMessages/SMS via the imsg CLI on macOS.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [iMessage, SMS, messaging, macOS, Apple]
---

# iMessage

Use `imsg` to read and send iMessage/SMS via macOS Messages.app.

## Prerequisites

- **macOS** with Messages.app signed in
- Install: `brew install steipete/tap/imsg`
- Grant Full Disk Access to the terminal (System Settings → Privacy → Full Disk Access)
- Grant Automation permission for Messages.app when prompted

## When to Use

- User asks to send an iMessage or text message
- Reading iMessage conversation history
- Checking recent Messages.app chats
- Sending to phone numbers or Apple IDs

## When NOT to Use

- Telegram/Discord/Slack/WhatsApp messages → use the appropriate gateway channel
- Group chat management (adding/removing members) → not supported
- Bulk/mass messaging → always confirm with the user first

## Quick Reference

### List Chats

```bash
imsg chats --limit 10 --json
```

### View History

```bash
# By chat ID
imsg history --chat-id 1 --limit 20 --json

# With attachments info
imsg history --chat-id 1 --limit 20 --attachments --json
```

### Send Messages

```bash
# Text only
imsg send --to "+14155551212" --text "Hello!"

# With attachment
imsg send --to "+14155551212" --text "Check this out" --file /path/to/image.jpg

# Force iMessage or SMS
imsg send --to "+14155551212" --text "Hi" --service imessage
imsg send --to "+14155551212" --text "Hi" --service sms
```

### Watch for New Messages

```bash
imsg watch --chat-id 1 --attachments
```

## Service Options

- `--service imessage` — Force iMessage (requires that the recipient has iMessage)
- `--service sms` — Force SMS (green bubble)
- `--service auto` — Let Messages.app decide (default)

## Rules

1. **Always confirm recipient and message content** before sending
2. **Never send to unknown numbers** without explicit user approval
3. **Verify file paths** exist before attaching
4. **Don't spam** — rate-limit yourself

## Example Workflow

User: "Text mom that I'll be late"

```bash
# 1. Find mom's chat
imsg chats --limit 20 --json | jq '.[] | select(.displayName | contains("Mom"))'

# 2. Confirm with user: "Found Mom at +1555123456. Send 'I'll be late' via iMessage?"

# 3. Send after confirmation
imsg send --to "+1555123456" --text "I'll be late"
```

skills/market-data/polymarket/SKILL.md (new file)
@@ -0,0 +1,76 @@
---
name: polymarket
description: Query Polymarket prediction market data — search markets, get prices, orderbooks, and price history. Read-only via public REST APIs, no API key needed.
version: 1.0.0
author: Hermes Agent + Teknium
tags: [polymarket, prediction-markets, market-data, trading]
---

# Polymarket — Prediction Market Data

Query prediction market data from Polymarket using their public REST APIs. All endpoints are read-only and require zero authentication.

See `references/api-endpoints.md` for the full endpoint reference with curl examples.

## When to Use

- User asks about prediction markets, betting odds, or event probabilities
- User wants to know "what are the odds of X happening?"
- User asks about Polymarket specifically
- User wants market prices, orderbook data, or price history
- User asks to monitor or track prediction market movements

## Key Concepts

- **Events** contain one or more **Markets** (1:many relationship)
- **Markets** are binary outcomes with Yes/No prices between 0.00 and 1.00
- Prices ARE probabilities: a price of 0.65 means the market thinks the outcome is 65% likely
- `outcomePrices` field: JSON-encoded array like `["0.80", "0.20"]`
- `clobTokenIds` field: JSON-encoded array of two token IDs [Yes, No] for price/book queries
- `conditionId` field: hex string used for price history queries
- Volume is in USDC (US dollars)

## Three Public APIs

1. **Gamma API** at `gamma-api.polymarket.com` — Discovery, search, browsing
2. **CLOB API** at `clob.polymarket.com` — Real-time prices, orderbooks, history
3. **Data API** at `data-api.polymarket.com` — Trades, open interest

## Typical Workflow

When a user asks about prediction market odds:

1. **Search** using the Gamma API public-search endpoint with their query
2. **Parse** the response — extract events and their nested markets
3. **Present** the market question, current prices as percentages, and volume
4. **Deep dive** if asked — use clobTokenIds for the orderbook, conditionId for history

## Presenting Results

Format prices as percentages for readability:

- outcomePrices `["0.652", "0.348"]` becomes "Yes: 65.2%, No: 34.8%"
- Always show the market question and probability
- Include volume when available

Example: `"Will X happen?" — 65.2% Yes ($1.2M volume)`

## Parsing Double-Encoded Fields

The Gamma API returns `outcomePrices`, `outcomes`, and `clobTokenIds` as JSON strings inside JSON responses (double-encoded). When processing with Python, parse them with `json.loads(market['outcomePrices'])` to get the actual array.
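
A minimal sketch of the decode step (the `market` dict mirrors the Gamma response shape; the values are illustrative):

```python
import json

market = {
    "question": "Will X happen?",
    "outcomes": "[\"Yes\", \"No\"]",            # JSON string, not a list
    "outcomePrices": "[\"0.652\", \"0.348\"]",  # JSON string, not a list
}

outcomes = json.loads(market["outcomes"])
prices = [float(p) for p in json.loads(market["outcomePrices"])]

for label, p in zip(outcomes, prices):
    print(f"{label}: {p * 100:.1f}%")  # Yes: 65.2% / No: 34.8%
```
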
## Rate Limits

Generous — unlikely to be hit in normal usage:

- Gamma: 4,000 requests per 10 seconds (general)
- CLOB: 9,000 requests per 10 seconds (general)
- Data: 1,000 requests per 10 seconds (general)

## Limitations

- This skill is read-only — it does not support placing trades
- Trading requires wallet-based crypto authentication (EIP-712 signatures)
- Some new markets may have empty price history
- Geographic restrictions apply to trading, but read-only data is globally accessible

skills/market-data/polymarket/references/api-endpoints.md (new file)
@@ -0,0 +1,220 @@
# Polymarket API Endpoints Reference

All endpoints are public REST (GET), return JSON, and need no authentication.

## Gamma API — gamma-api.polymarket.com

### Search Markets

```
GET /public-search?q=QUERY
```

Response structure:

```json
{
  "events": [
    {
      "id": "12345",
      "title": "Event title",
      "slug": "event-slug",
      "volume": 1234567.89,
      "markets": [
        {
          "question": "Will X happen?",
          "outcomePrices": "[\"0.65\", \"0.35\"]",
          "outcomes": "[\"Yes\", \"No\"]",
          "clobTokenIds": "[\"TOKEN_YES\", \"TOKEN_NO\"]",
          "conditionId": "0xabc...",
          "volume": 500000
        }
      ]
    }
  ],
  "pagination": {"hasMore": true, "totalResults": 100}
}
```

### List Events

```
GET /events?limit=N&active=true&closed=false&order=volume&ascending=false
```

Parameters:

- `limit` — max results (default varies)
- `offset` — pagination offset
- `active` — true/false
- `closed` — true/false
- `order` — sort field: `volume`, `createdAt`, `updatedAt`
- `ascending` — true/false
- `tag` — filter by tag slug
- `slug` — get a specific event by slug

Response: array of event objects. Each event includes a `markets` array.

Event fields: `id`, `title`, `slug`, `description`, `volume`, `liquidity`, `openInterest`, `active`, `closed`, `category`, `startDate`, `endDate`, `markets` (array of market objects).

### List Markets

```
GET /markets?limit=N&active=true&closed=false&order=volume&ascending=false
```

Same filter parameters as events, plus:

- `slug` — get a specific market by slug

Market fields: `id`, `question`, `conditionId`, `slug`, `description`, `outcomes`, `outcomePrices`, `volume`, `liquidity`, `active`, `closed`, `marketType`, `clobTokenIds`, `endDate`, `category`, `createdAt`.

Important: `outcomePrices`, `outcomes`, and `clobTokenIds` are JSON strings (double-encoded). Parse with `json.loads()` in Python.

### List Tags

```
GET /tags
```

Returns an array of tag objects: `id`, `label`, `slug`. Use the `slug` value when filtering events/markets by tag.

---

## CLOB API — clob.polymarket.com

All CLOB price endpoints use `token_id` from the market's `clobTokenIds` field. Index 0 = Yes outcome, index 1 = No outcome.

### Current Price

```
GET /price?token_id=TOKEN_ID&side=buy
```

Response: `{"price": "0.650"}`

The `side` parameter: `buy` or `sell`.

### Midpoint Price

```
GET /midpoint?token_id=TOKEN_ID
```

Response: `{"mid": "0.645"}`

### Spread

```
GET /spread?token_id=TOKEN_ID
```

Response: `{"spread": "0.02"}`

### Orderbook

```
GET /book?token_id=TOKEN_ID
```

Response:

```json
{
  "market": "condition_id",
  "asset_id": "token_id",
  "bids": [{"price": "0.64", "size": "500"}, ...],
  "asks": [{"price": "0.66", "size": "300"}, ...],
  "min_order_size": "5",
  "tick_size": "0.01",
  "last_trade_price": "0.65"
}
```

Bids and asks are sorted by price. Size is in shares (USDC-denominated).
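
A small sketch of reading the top of the book from that response (the `book` dict is a trimmed sample; computing best bid/ask explicitly avoids relying on any particular sort direction):

```python
book = {  # parsed JSON from GET /book?token_id=...
    "bids": [{"price": "0.64", "size": "500"}, {"price": "0.63", "size": "200"}],
    "asks": [{"price": "0.66", "size": "300"}, {"price": "0.67", "size": "100"}],
}

best_bid = max(book["bids"], key=lambda lvl: float(lvl["price"]))  # highest buy
best_ask = min(book["asks"], key=lambda lvl: float(lvl["price"]))  # lowest sell
spread = float(best_ask["price"]) - float(best_bid["price"])

print(f"best bid {best_bid['price']} / best ask {best_ask['price']} (spread {spread:.2f})")
```
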
### Price History

```
GET /prices-history?market=CONDITION_ID&interval=INTERVAL&fidelity=N
```

Parameters:

- `market` — the conditionId (hex string with 0x prefix)
- `interval` — time range: `all`, `1d`, `1w`, `1m`, `3m`, `6m`, `1y`
- `fidelity` — number of data points to return

Response:

```json
{
  "history": [
    {"t": 1709000000, "p": "0.55"},
    {"t": 1709100000, "p": "0.58"}
  ]
}
```

`t` is a Unix timestamp, `p` is the price (probability).

Note: Very new markets may return empty history.

### CLOB Markets List

```
GET /markets?limit=N
```

Response:

```json
{
  "data": [
    {
      "condition_id": "0xabc...",
      "question": "Will X?",
      "tokens": [
        {"token_id": "123...", "outcome": "Yes", "price": 0.65},
        {"token_id": "456...", "outcome": "No", "price": 0.35}
      ],
      "active": true,
      "closed": false
    }
  ],
  "next_cursor": "cursor_string",
  "limit": 100,
  "count": 1000
}
```

---

## Data API — data-api.polymarket.com

### Recent Trades

```
GET /trades?limit=N
GET /trades?market=CONDITION_ID&limit=N
```

Trade fields: `side` (BUY/SELL), `size`, `price`, `timestamp`, `title`, `slug`, `outcome`, `transactionHash`, `conditionId`.

### Open Interest

```
GET /oi?market=CONDITION_ID
```

---

## Field Cross-Reference

To go from a Gamma market to CLOB data (a short Python sketch follows this list):

1. Get the market from Gamma: it has `clobTokenIds` and `conditionId`
2. Parse `clobTokenIds` (JSON string): `["YES_TOKEN", "NO_TOKEN"]`
3. Use YES_TOKEN with `/price`, `/book`, `/midpoint`, `/spread`
4. Use `conditionId` with `/prices-history` and Data API endpoints
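
A minimal end-to-end sketch of that chain (standard library only, matching the helper script below; the slug is hypothetical and error handling is omitted):

```python
import json
import urllib.parse
import urllib.request

def get_json(url: str):
    # Plain GET returning parsed JSON; these APIs need no auth.
    with urllib.request.urlopen(url, timeout=15) as resp:
        return json.loads(resp.read().decode())

# 1. Gamma market by slug
slug = "will-x-happen"  # hypothetical slug for illustration
market = get_json(f"https://gamma-api.polymarket.com/markets?slug={urllib.parse.quote(slug)}")[0]

# 2. Double-encoded token IDs: index 0 = Yes, index 1 = No
yes_token, no_token = json.loads(market["clobTokenIds"])

# 3. CLOB midpoint for the Yes token
mid = get_json(f"https://clob.polymarket.com/midpoint?token_id={yes_token}")

# 4. Price history via the conditionId
hist = get_json(f"https://clob.polymarket.com/prices-history?market={market['conditionId']}&interval=1w&fidelity=50")
print(mid["mid"], len(hist.get("history", [])))
```
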
skills/market-data/polymarket/scripts/polymarket.py (new file)
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""Polymarket CLI helper — query prediction market data.

Usage:
    python3 polymarket.py search "bitcoin"
    python3 polymarket.py trending [--limit 10]
    python3 polymarket.py market <slug>
    python3 polymarket.py event <slug>
    python3 polymarket.py price <token_id>
    python3 polymarket.py book <token_id>
    python3 polymarket.py history <condition_id> [--interval all] [--fidelity 50]
    python3 polymarket.py trades [--limit 10] [--market CONDITION_ID]
"""

import json
import sys
import urllib.request
import urllib.parse
import urllib.error

GAMMA = "https://gamma-api.polymarket.com"
CLOB = "https://clob.polymarket.com"
DATA = "https://data-api.polymarket.com"


def _get(url: str) -> dict | list:
    """GET request, return parsed JSON."""
    req = urllib.request.Request(url, headers={"User-Agent": "hermes-agent/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            return json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        print(f"HTTP {e.code}: {e.reason}", file=sys.stderr)
        sys.exit(1)
    except urllib.error.URLError as e:
        print(f"Connection error: {e.reason}", file=sys.stderr)
        sys.exit(1)


def _parse_json_field(val):
    """Parse double-encoded JSON fields (outcomePrices, outcomes, clobTokenIds)."""
    if isinstance(val, str):
        try:
            return json.loads(val)
        except (json.JSONDecodeError, TypeError):
            return val
    return val


def _fmt_pct(price_str: str) -> str:
    """Format price string as percentage."""
    try:
        return f"{float(price_str) * 100:.1f}%"
    except (ValueError, TypeError):
        return price_str


def _fmt_volume(vol) -> str:
    """Format volume as human-readable."""
    try:
        v = float(vol)
        if v >= 1_000_000:
            return f"${v / 1_000_000:.1f}M"
        if v >= 1_000:
            return f"${v / 1_000:.1f}K"
        return f"${v:.0f}"
    except (ValueError, TypeError):
        return str(vol)


def _print_market(m: dict, indent: str = ""):
    """Print a market summary."""
    question = m.get("question", "?")
    prices = _parse_json_field(m.get("outcomePrices", "[]"))
    outcomes = _parse_json_field(m.get("outcomes", "[]"))
    vol = _fmt_volume(m.get("volume", 0))
    closed = m.get("closed", False)
    status = " [CLOSED]" if closed else ""

    if isinstance(prices, list) and len(prices) >= 2:
        outcome_labels = outcomes if isinstance(outcomes, list) else ["Yes", "No"]
        price_str = " / ".join(
            f"{outcome_labels[i]}: {_fmt_pct(prices[i])}"
            for i in range(min(len(prices), len(outcome_labels)))
        )
        print(f"{indent}{question}{status}")
        print(f"{indent}  {price_str} | Volume: {vol}")
    else:
        print(f"{indent}{question}{status} | Volume: {vol}")

    slug = m.get("slug", "")
    if slug:
        print(f"{indent}  slug: {slug}")


def cmd_search(query: str):
    """Search for markets."""
    q = urllib.parse.quote(query)
    data = _get(f"{GAMMA}/public-search?q={q}")
    events = data.get("events", [])
    total = data.get("pagination", {}).get("totalResults", len(events))
    print(f"Found {total} results for \"{query}\":\n")
    for evt in events[:10]:
        print(f"=== {evt['title']} ===")
        print(f"  Volume: {_fmt_volume(evt.get('volume', 0))} | slug: {evt.get('slug', '')}")
        markets = evt.get("markets", [])
        for m in markets[:5]:
            _print_market(m, indent="  ")
        if len(markets) > 5:
            print(f"  ... and {len(markets) - 5} more markets")
        print()


def cmd_trending(limit: int = 10):
    """Show trending events by volume."""
    events = _get(f"{GAMMA}/events?limit={limit}&active=true&closed=false&order=volume&ascending=false")
    print(f"Top {len(events)} trending events:\n")
    for i, evt in enumerate(events, 1):
        print(f"{i}. {evt['title']}")
        print(f"   Volume: {_fmt_volume(evt.get('volume', 0))} | Markets: {len(evt.get('markets', []))}")
        print(f"   slug: {evt.get('slug', '')}")
        markets = evt.get("markets", [])
        for m in markets[:3]:
            _print_market(m, indent="   ")
        if len(markets) > 3:
            print(f"   ... and {len(markets) - 3} more markets")
        print()


def cmd_market(slug: str):
    """Get market details by slug."""
    markets = _get(f"{GAMMA}/markets?slug={urllib.parse.quote(slug)}")
    if not markets:
        print(f"No market found with slug: {slug}")
        return
    m = markets[0]
    print(f"Market: {m.get('question', '?')}")
    print(f"Status: {'CLOSED' if m.get('closed') else 'ACTIVE'}")
    _print_market(m)
    print(f"\n  conditionId: {m.get('conditionId', 'N/A')}")
    tokens = _parse_json_field(m.get("clobTokenIds", "[]"))
    if isinstance(tokens, list):
        outcomes = _parse_json_field(m.get("outcomes", "[]"))
        for i, t in enumerate(tokens):
            label = outcomes[i] if isinstance(outcomes, list) and i < len(outcomes) else f"Outcome {i}"
            print(f"  token ({label}): {t}")
    desc = m.get("description", "")
    if desc:
        print(f"\n  Description: {desc[:500]}")


def cmd_event(slug: str):
    """Get event details by slug."""
    events = _get(f"{GAMMA}/events?slug={urllib.parse.quote(slug)}")
    if not events:
        print(f"No event found with slug: {slug}")
        return
    evt = events[0]
    print(f"Event: {evt['title']}")
    print(f"Volume: {_fmt_volume(evt.get('volume', 0))}")
    print(f"Status: {'CLOSED' if evt.get('closed') else 'ACTIVE'}")
    print(f"Markets: {len(evt.get('markets', []))}\n")
    for m in evt.get("markets", []):
        _print_market(m, indent="  ")
        print()


def cmd_price(token_id: str):
    """Get current price for a token."""
    buy = _get(f"{CLOB}/price?token_id={token_id}&side=buy")
    mid = _get(f"{CLOB}/midpoint?token_id={token_id}")
    spread = _get(f"{CLOB}/spread?token_id={token_id}")
    print(f"Token: {token_id[:30]}...")
    print(f"  Buy price: {_fmt_pct(buy.get('price', '?'))}")
    print(f"  Midpoint: {_fmt_pct(mid.get('mid', '?'))}")
    print(f"  Spread: {spread.get('spread', '?')}")


def cmd_book(token_id: str):
    """Get orderbook for a token."""
    book = _get(f"{CLOB}/book?token_id={token_id}")
    bids = book.get("bids", [])
    asks = book.get("asks", [])
    last = book.get("last_trade_price", "?")
    print(f"Orderbook for {token_id[:30]}...")
    print(f"Last trade: {_fmt_pct(last)} | Tick size: {book.get('tick_size', '?')}")
    print(f"\n  Top bids ({len(bids)} total):")
    # Show bids sorted by price descending (best bids first)
    sorted_bids = sorted(bids, key=lambda x: float(x.get("price", 0)), reverse=True)
    for b in sorted_bids[:10]:
        print(f"    {_fmt_pct(b['price']):>7} | Size: {float(b['size']):>10.2f}")
    print(f"\n  Top asks ({len(asks)} total):")
    sorted_asks = sorted(asks, key=lambda x: float(x.get("price", 0)))
    for a in sorted_asks[:10]:
        print(f"    {_fmt_pct(a['price']):>7} | Size: {float(a['size']):>10.2f}")


def cmd_history(condition_id: str, interval: str = "all", fidelity: int = 50):
    """Get price history for a market."""
    data = _get(f"{CLOB}/prices-history?market={condition_id}&interval={interval}&fidelity={fidelity}")
    history = data.get("history", [])
    if not history:
        print("No price history available for this market.")
        return
    print(f"Price history ({len(history)} points, interval={interval}):\n")
    from datetime import datetime, timezone
    for pt in history:
        ts = datetime.fromtimestamp(pt["t"], tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
        price = _fmt_pct(pt["p"])
        bar = "█" * int(float(pt["p"]) * 40)
        print(f"  {ts}  {price:>7}  {bar}")


def cmd_trades(limit: int = 10, market: str = None):
    """Get recent trades."""
    url = f"{DATA}/trades?limit={limit}"
    if market:
        url += f"&market={market}"
    trades = _get(url)
    if not isinstance(trades, list):
        print(f"Unexpected response: {trades}")
        return
    print(f"Recent trades ({len(trades)}):\n")
    for t in trades:
        side = t.get("side", "?")
        price = _fmt_pct(t.get("price", "?"))
        size = t.get("size", "?")
        outcome = t.get("outcome", "?")
        title = t.get("title", "?")[:50]
        ts = t.get("timestamp", "")
        print(f"  {side:4} {price:>7} x{float(size):>8.2f} [{outcome}] {title}")


def main():
    args = sys.argv[1:]
    if not args or args[0] in ("-h", "--help", "help"):
        print(__doc__)
        return

    cmd = args[0]

    if cmd == "search" and len(args) >= 2:
        cmd_search(" ".join(args[1:]))
    elif cmd == "trending":
        limit = 10
        if "--limit" in args:
            idx = args.index("--limit")
            limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
        cmd_trending(limit)
    elif cmd == "market" and len(args) >= 2:
        cmd_market(args[1])
    elif cmd == "event" and len(args) >= 2:
        cmd_event(args[1])
    elif cmd == "price" and len(args) >= 2:
        cmd_price(args[1])
    elif cmd == "book" and len(args) >= 2:
        cmd_book(args[1])
    elif cmd == "history" and len(args) >= 2:
        interval = "all"
        fidelity = 50
        if "--interval" in args:
            idx = args.index("--interval")
            interval = args[idx + 1] if idx + 1 < len(args) else "all"
        if "--fidelity" in args:
            idx = args.index("--fidelity")
            fidelity = int(args[idx + 1]) if idx + 1 < len(args) else 50
        cmd_history(args[1], interval, fidelity)
    elif cmd == "trades":
        limit = 10
        market = None
        if "--limit" in args:
            idx = args.index("--limit")
            limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
        if "--market" in args:
            idx = args.index("--market")
            market = args[idx + 1] if idx + 1 < len(args) else None
        cmd_trades(limit, market)
    else:
        print(f"Unknown command: {cmd}")
        print(__doc__)


if __name__ == "__main__":
    main()
@@ -176,3 +176,93 @@ class TestCompressWithClient:
        contents = [m.get("content", "") for m in result]
        assert any("CONTEXT SUMMARY" in c for c in contents)
        assert len(result) < len(msgs)

    def test_summarization_does_not_split_tool_call_pairs(self):
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
            c = ContextCompressor(
                model="test",
                quiet_mode=True,
                protect_first_n=3,
                protect_last_n=4,
            )

            msgs = [
                {"role": "user", "content": "Could you address the reviewer comments in PR#71"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {"id": "call_a", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
                        {"id": "call_b", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
                    ],
                },
                {"role": "tool", "tool_call_id": "call_a", "content": "output a"},
                {"role": "tool", "tool_call_id": "call_b", "content": "output b"},
                {"role": "user", "content": "later 1"},
                {"role": "assistant", "content": "later 2"},
                {"role": "tool", "tool_call_id": "call_x", "content": "later output"},
                {"role": "assistant", "content": "later 3"},
                {"role": "user", "content": "later 4"},
            ]

            result = c.compress(msgs)

        answered_ids = {
            msg.get("tool_call_id")
            for msg in result
            if msg.get("role") == "tool" and msg.get("tool_call_id")
        }
        for msg in result:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    assert tc["id"] in answered_ids

    def test_summarization_does_not_start_tail_with_tool_outputs(self):
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
            c = ContextCompressor(
                model="test",
                quiet_mode=True,
                protect_first_n=2,
                protect_last_n=3,
            )

            msgs = [
                {"role": "user", "content": "earlier 1"},
                {"role": "assistant", "content": "earlier 2"},
                {"role": "user", "content": "earlier 3"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {"id": "call_c", "type": "function", "function": {"name": "search_files", "arguments": "{}"}},
                    ],
                },
                {"role": "tool", "tool_call_id": "call_c", "content": "output c"},
                {"role": "user", "content": "latest user"},
            ]

            result = c.compress(msgs)

        called_ids = {
            tc["id"]
            for msg in result
            if msg.get("role") == "assistant" and msg.get("tool_calls")
            for tc in msg["tool_calls"]
        }
        for msg in result:
            if msg.get("role") == "tool" and msg.get("tool_call_id"):
                assert msg["tool_call_id"] in called_ids
@@ -165,6 +165,52 @@ class TestBuildSkillsSystemPrompt:
        # "search" should appear only once per category
        assert result.count("- search") == 1

    def test_excludes_incompatible_platform_skills(self, monkeypatch, tmp_path):
        """Skills with platforms: [macos] should not appear on Linux."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        skills_dir = tmp_path / "skills" / "apple"
        skills_dir.mkdir(parents=True)

        # macOS-only skill
        mac_skill = skills_dir / "imessage"
        mac_skill.mkdir()
        (mac_skill / "SKILL.md").write_text(
            "---\nname: imessage\ndescription: Send iMessages\nplatforms: [macos]\n---\n"
        )

        # Universal skill
        uni_skill = skills_dir / "web-search"
        uni_skill.mkdir()
        (uni_skill / "SKILL.md").write_text(
            "---\nname: web-search\ndescription: Search the web\n---\n"
        )

        from unittest.mock import patch
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            result = build_skills_system_prompt()

        assert "web-search" in result
        assert "imessage" not in result

    def test_includes_matching_platform_skills(self, monkeypatch, tmp_path):
        """Skills with platforms: [macos] should appear on macOS."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        skills_dir = tmp_path / "skills" / "apple"
        mac_skill = skills_dir / "imessage"
        mac_skill.mkdir(parents=True)
        (mac_skill / "SKILL.md").write_text(
            "---\nname: imessage\ndescription: Send iMessages\nplatforms: [macos]\n---\n"
        )

        from unittest.mock import patch
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            result = build_skills_system_prompt()

        assert "imessage" in result
        assert "Send iMessages" in result


# =========================================================================
# Context files prompt builder
@@ -172,7 +218,11 @@ class TestBuildSkillsSystemPrompt:
 
 class TestBuildContextFilesPrompt:
     def test_empty_dir_returns_empty(self, tmp_path):
-        result = build_context_files_prompt(cwd=str(tmp_path))
+        from unittest.mock import patch
+        fake_home = tmp_path / "fake_home"
+        fake_home.mkdir()
+        with patch("pathlib.Path.home", return_value=fake_home):
+            result = build_context_files_prompt(cwd=str(tmp_path))
         assert result == ""
 
     def test_loads_agents_md(self, tmp_path):
tests/agent/test_skill_commands.py (new file)
@@ -0,0 +1,87 @@
"""Tests for agent/skill_commands.py — skill slash command scanning and platform filtering."""

from pathlib import Path
from unittest.mock import patch

from agent.skill_commands import scan_skill_commands, build_skill_invocation_message


def _make_skill(skills_dir, name, frontmatter_extra="", body="Do the thing.", category=None):
    """Helper to create a minimal skill directory with SKILL.md."""
    if category:
        skill_dir = skills_dir / category / name
    else:
        skill_dir = skills_dir / name
    skill_dir.mkdir(parents=True, exist_ok=True)
    content = f"""\
---
name: {name}
description: Description for {name}.
{frontmatter_extra}---

# {name}

{body}
"""
    (skill_dir / "SKILL.md").write_text(content)
    return skill_dir


class TestScanSkillCommands:
    def test_finds_skills(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(tmp_path, "my-skill")
            result = scan_skill_commands()
            assert "/my-skill" in result
            assert result["/my-skill"]["name"] == "my-skill"

    def test_empty_dir(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            result = scan_skill_commands()
            assert result == {}

    def test_excludes_incompatible_platform(self, tmp_path):
        """macOS-only skills should not register slash commands on Linux."""
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            _make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n")
            _make_skill(tmp_path, "web-search")
            result = scan_skill_commands()
            assert "/web-search" in result
            assert "/imessage" not in result

    def test_includes_matching_platform(self, tmp_path):
        """macOS-only skills should register slash commands on macOS."""
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            _make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n")
            result = scan_skill_commands()
            assert "/imessage" in result

    def test_universal_skill_on_any_platform(self, tmp_path):
        """Skills without platforms field should register on any platform."""
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "win32"
            _make_skill(tmp_path, "generic-tool")
            result = scan_skill_commands()
            assert "/generic-tool" in result


class TestBuildSkillInvocationMessage:
    def test_builds_message(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            _make_skill(tmp_path, "test-skill")
            scan_skill_commands()
            msg = build_skill_invocation_message("/test-skill", "do stuff")
            assert msg is not None
            assert "test-skill" in msg
            assert "do stuff" in msg

    def test_returns_none_for_unknown(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
            scan_skill_commands()
            msg = build_skill_invocation_message("/nonexistent")
            assert msg is None
@@ -75,8 +75,9 @@ class TestParseSchedule:
         run_at_str = result["run_at"]
         assert isinstance(run_at_str, str)
         run_at = datetime.fromisoformat(run_at_str)
-        assert run_at > datetime.now()
-        assert run_at < datetime.now() + timedelta(minutes=31)
+        now = datetime.now().astimezone()
+        assert run_at > now
+        assert run_at < now + timedelta(minutes=31)
 
     def test_every_becomes_interval(self):
         result = parse_schedule("every 2h")
@@ -129,15 +130,15 @@ class TestComputeNextRun:
         result = compute_next_run(schedule)
         next_dt = datetime.fromisoformat(result)
         # Should be ~60 minutes from now
-        assert next_dt > datetime.now() + timedelta(minutes=59)
+        assert next_dt > datetime.now().astimezone() + timedelta(minutes=59)
 
     def test_interval_subsequent_run(self):
         schedule = {"kind": "interval", "minutes": 30}
-        last = datetime.now().isoformat()
+        last = datetime.now().astimezone().isoformat()
         result = compute_next_run(schedule, last_run_at=last)
         next_dt = datetime.fromisoformat(result)
         # Should be ~30 minutes from last run
-        assert next_dt > datetime.now() + timedelta(minutes=29)
+        assert next_dt > datetime.now().astimezone() + timedelta(minutes=29)
 
     def test_cron_returns_future(self):
         pytest.importorskip("croniter")
@@ -147,7 +148,7 @@
         assert len(result) > 0
         next_dt = datetime.fromisoformat(result)
         assert isinstance(next_dt, datetime)
-        assert next_dt > datetime.now()
+        assert next_dt > datetime.now().astimezone()
 
     def test_unknown_kind_returns_none(self):
         assert compute_next_run({"kind": "unknown"}) is None
tests/gateway/test_async_memory_flush.py (new file)
@@ -0,0 +1,180 @@
"""Tests for proactive memory flush on session expiry.

Verifies that:
1. _is_session_expired() works from a SessionEntry alone (no source needed)
2. The sync callback is no longer called in get_or_create_session
3. _pre_flushed_sessions tracking works correctly
4. The background watcher can detect expired sessions
"""

import pytest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch, MagicMock

from gateway.config import Platform, GatewayConfig, SessionResetPolicy
from gateway.session import SessionSource, SessionStore, SessionEntry


@pytest.fixture()
def idle_store(tmp_path):
    """SessionStore with a 60-minute idle reset policy."""
    config = GatewayConfig(
        default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60),
    )
    with patch("gateway.session.SessionStore._ensure_loaded"):
        s = SessionStore(sessions_dir=tmp_path, config=config)
        s._db = None
        s._loaded = True
        return s


@pytest.fixture()
def no_reset_store(tmp_path):
    """SessionStore with no reset policy (mode=none)."""
    config = GatewayConfig(
        default_reset_policy=SessionResetPolicy(mode="none"),
    )
    with patch("gateway.session.SessionStore._ensure_loaded"):
        s = SessionStore(sessions_dir=tmp_path, config=config)
        s._db = None
        s._loaded = True
        return s


class TestIsSessionExpired:
    """_is_session_expired should detect expiry from entry alone."""

    def test_idle_session_expired(self, idle_store):
        entry = SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id="sid_1",
            created_at=datetime.now() - timedelta(hours=3),
            updated_at=datetime.now() - timedelta(minutes=120),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )
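        # Idle for 120 minutes exceeds the 60-minute policy window, so the entry is expired.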
        assert idle_store._is_session_expired(entry) is True

    def test_active_session_not_expired(self, idle_store):
        entry = SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id="sid_2",
            created_at=datetime.now() - timedelta(hours=1),
            updated_at=datetime.now() - timedelta(minutes=10),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )
        assert idle_store._is_session_expired(entry) is False

    def test_none_mode_never_expires(self, no_reset_store):
        entry = SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id="sid_3",
            created_at=datetime.now() - timedelta(days=30),
            updated_at=datetime.now() - timedelta(days=30),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )
        assert no_reset_store._is_session_expired(entry) is False

    def test_active_processes_prevent_expiry(self, idle_store):
        """Sessions with active background processes should never expire."""
        idle_store._has_active_processes_fn = lambda key: True
        entry = SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id="sid_4",
            created_at=datetime.now() - timedelta(hours=5),
            updated_at=datetime.now() - timedelta(hours=5),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )
        assert idle_store._is_session_expired(entry) is False

    def test_daily_mode_expired(self, tmp_path):
        """Daily mode should expire sessions from before today's reset hour."""
        config = GatewayConfig(
            default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4),
        )
        with patch("gateway.session.SessionStore._ensure_loaded"):
            store = SessionStore(sessions_dir=tmp_path, config=config)
            store._db = None
            store._loaded = True

        entry = SessionEntry(
            session_key="agent:main:telegram:dm",
            session_id="sid_5",
            created_at=datetime.now() - timedelta(days=2),
            updated_at=datetime.now() - timedelta(days=2),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )
        assert store._is_session_expired(entry) is True


class TestGetOrCreateSessionNoCallback:
    """get_or_create_session should NOT call a sync flush callback."""

    def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
        """When a session auto-resets, the pre_flushed marker should be discarded."""
        source = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
        )
        # Create initial session
        entry1 = idle_store.get_or_create_session(source)
        old_sid = entry1.session_id

        # Simulate the watcher having flushed it
        idle_store._pre_flushed_sessions.add(old_sid)

        # Simulate the session going idle
        entry1.updated_at = datetime.now() - timedelta(minutes=120)
        idle_store._save()

        # Next call should auto-reset
        entry2 = idle_store.get_or_create_session(source)
        assert entry2.session_id != old_sid
        assert entry2.was_auto_reset is True

        # The old session_id should be removed from pre_flushed
        assert old_sid not in idle_store._pre_flushed_sessions

    def test_no_sync_callback_invoked(self, idle_store):
        """No synchronous callback should block during auto-reset."""
        source = SessionSource(
            platform=Platform.TELEGRAM,
            chat_id="123",
            chat_type="dm",
        )
        entry1 = idle_store.get_or_create_session(source)
        entry1.updated_at = datetime.now() - timedelta(minutes=120)
        idle_store._save()

        # Verify no _on_auto_reset attribute
        assert not hasattr(idle_store, '_on_auto_reset')

        # This should NOT block (no sync LLM call)
        entry2 = idle_store.get_or_create_session(source)
        assert entry2.was_auto_reset is True


class TestPreFlushedSessionsTracking:
    """The _pre_flushed_sessions set should prevent double-flushing."""

    def test_starts_empty(self, idle_store):
        assert len(idle_store._pre_flushed_sessions) == 0

    def test_add_and_check(self, idle_store):
        idle_store._pre_flushed_sessions.add("sid_old")
        assert "sid_old" in idle_store._pre_flushed_sessions
        assert "sid_other" not in idle_store._pre_flushed_sessions

    def test_discard_on_reset(self, idle_store):
        """discard should remove without raising if not present."""
        idle_store._pre_flushed_sessions.add("sid_a")
        idle_store._pre_flushed_sessions.discard("sid_a")
        assert "sid_a" not in idle_store._pre_flushed_sessions
        # discard on non-existent should not raise
        idle_store._pre_flushed_sessions.discard("sid_nonexistent")
tests/gateway/test_send_image_file.py (new file)
@@ -0,0 +1,335 @@
"""
Tests for send_image_file() on Telegram, Discord, and Slack platforms,
and MEDIA: .png extraction/routing in the base platform adapter.

Covers: local image file sending, file-not-found handling, fallback on error,
MEDIA: tag extraction for image extensions, and routing to send_image_file.
"""

import asyncio
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from gateway.config import PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, SendResult


# ---------------------------------------------------------------------------
# MEDIA: extraction tests for image files
# ---------------------------------------------------------------------------


class TestExtractMediaImages:
    """Test that MEDIA: tags with image extensions are correctly extracted."""

    def test_png_image_extracted(self):
        content = "Here is the screenshot:\nMEDIA:/home/user/.hermes/browser_screenshots/shot.png"
        media, cleaned = BasePlatformAdapter.extract_media(content)
        assert len(media) == 1
        assert media[0][0] == "/home/user/.hermes/browser_screenshots/shot.png"
        assert "MEDIA:" not in cleaned
        assert "Here is the screenshot" in cleaned

    def test_jpg_image_extracted(self):
        content = "MEDIA:/tmp/photo.jpg"
        media, cleaned = BasePlatformAdapter.extract_media(content)
        assert len(media) == 1
        assert media[0][0] == "/tmp/photo.jpg"

    def test_webp_image_extracted(self):
        content = "MEDIA:/tmp/image.webp"
        media, _ = BasePlatformAdapter.extract_media(content)
        assert len(media) == 1

    def test_mixed_audio_and_image(self):
        content = "MEDIA:/audio.ogg\nMEDIA:/screenshot.png"
        media, _ = BasePlatformAdapter.extract_media(content)
        assert len(media) == 2
        paths = [m[0] for m in media]
        assert "/audio.ogg" in paths
        assert "/screenshot.png" in paths


# ---------------------------------------------------------------------------
# Telegram send_image_file tests
# ---------------------------------------------------------------------------


def _ensure_telegram_mock():
    """Install mock telegram modules so TelegramAdapter can be imported."""
    if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
        return

    telegram_mod = MagicMock()
    telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
    telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
    telegram_mod.constants.ChatType.GROUP = "group"
    telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
    telegram_mod.constants.ChatType.CHANNEL = "channel"
    telegram_mod.constants.ChatType.PRIVATE = "private"
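    # setdefault() keeps any real telegram package that is already importable.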
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestTelegramSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = TelegramAdapter(config)
|
||||
a._bot = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_photo(self, adapter, tmp_path):
|
||||
"""send_image_file should call bot.send_photo with the opened file."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) # Minimal PNG-like
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 42
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "42"
|
||||
adapter._bot.send_photo.assert_awaited_once()
|
||||
|
||||
# Verify photo arg was a file object (opened in rb mode)
|
||||
call_kwargs = adapter._bot.send_photo.call_args
|
||||
assert call_kwargs.kwargs["chat_id"] == 12345
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
"""send_image_file should return error for nonexistent file."""
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/nonexistent/image.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
"""send_image_file should return error when bot is None."""
|
||||
adapter._bot = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_caption_truncated_to_1024(self, adapter, tmp_path):
|
||||
"""Telegram captions have a 1024 char limit."""
|
||||
img = tmp_path / "shot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 1
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
long_caption = "A" * 2000
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img), caption=long_caption)
|
||||
)
|
||||
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert len(call_kwargs["caption"]) == 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install mock discord module so DiscordAdapter can be imported."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
|
||||
for name in ("discord", "discord.ext", "discord.ext.commands"):
|
||||
sys.modules.setdefault(name, discord_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import discord as discord_mod_ref # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestDiscordSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_attachment(self, adapter, tmp_path):
|
||||
"""send_image_file should create discord.File and send to channel."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 99
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "99"
|
||||
mock_channel.send.assert_awaited_once()
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._client = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_handles_missing_channel(self, adapter):
|
||||
adapter._client.get_channel = MagicMock(return_value=None)
|
||||
adapter._client.fetch_channel = AsyncMock(return_value=None)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="99999", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack_bolt module so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
|
||||
slack_mod = MagicMock()
|
||||
for name in ("slack_bolt", "slack_bolt.async_app", "slack_sdk", "slack_sdk.web.async_client"):
|
||||
sys.modules.setdefault(name, slack_mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestSlackSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
a = SlackAdapter(config)
|
||||
a._app = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_via_upload(self, adapter, tmp_path):
|
||||
"""send_image_file should call files_upload_v2 with the local path."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_result = MagicMock()
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_awaited_once()
|
||||
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args.kwargs
|
||||
assert call_kwargs["file"] == str(img)
|
||||
assert call_kwargs["filename"] == "screenshot.png"
|
||||
assert call_kwargs["channel"] == "C12345"
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# browser_vision screenshot cleanup tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScreenshotCleanup:
|
||||
def test_cleanup_removes_old_screenshots(self, tmp_path):
|
||||
"""_cleanup_old_screenshots should remove files older than max_age_hours."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
|
||||
# Create a "fresh" file
|
||||
fresh = tmp_path / "browser_screenshot_fresh.png"
|
||||
fresh.write_bytes(b"new")
|
||||
|
||||
# Create an "old" file and backdate its mtime
|
||||
old = tmp_path / "browser_screenshot_old.png"
|
||||
old.write_bytes(b"old")
|
||||
old_time = time.time() - (25 * 3600) # 25 hours ago
|
||||
os.utime(str(old), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert fresh.exists(), "Fresh screenshot should not be removed"
|
||||
assert not old.exists(), "Old screenshot should be removed"
|
||||
|
||||
def test_cleanup_ignores_non_screenshot_files(self, tmp_path):
|
||||
"""Only files matching browser_screenshot_*.png should be cleaned."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
|
||||
other_file = tmp_path / "important_data.txt"
|
||||
other_file.write_bytes(b"keep me")
|
||||
old_time = time.time() - (48 * 3600)
|
||||
os.utime(str(other_file), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert other_file.exists(), "Non-screenshot files should not be touched"
|
||||
|
||||
def test_cleanup_handles_empty_dir(self, tmp_path):
|
||||
"""Cleanup should not fail on empty directory."""
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24) # Should not raise
|
||||
|
||||
def test_cleanup_handles_nonexistent_dir(self):
|
||||
"""Cleanup should not fail if directory doesn't exist."""
|
||||
from pathlib import Path
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
_cleanup_old_screenshots(Path("/nonexistent/dir"), max_age_hours=24) # Should not raise
|
||||
tests/gateway/test_session_hygiene.py (new file)
@@ -0,0 +1,159 @@
"""Tests for gateway session hygiene — auto-compression of large sessions.

Verifies that the gateway detects pathologically large transcripts and
triggers auto-compression before running the agent. (#628)
"""

import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from agent.model_metadata import estimate_messages_tokens_rough


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_history(n_messages: int, content_size: int = 100) -> list:
    """Build a fake transcript with n_messages user/assistant pairs."""
    history = []
    content = "x" * content_size
    for i in range(n_messages):
        role = "user" if i % 2 == 0 else "assistant"
        history.append({"role": role, "content": content, "timestamp": f"t{i}"})
    return history


def _make_large_history_tokens(target_tokens: int) -> list:
    """Build a history that estimates to roughly target_tokens tokens."""
    # estimate_messages_tokens_rough counts total chars in str(msg) // 4
    # Each msg dict has ~60 chars of overhead + content chars
    # So for N tokens we need roughly N * 4 total chars across all messages
    target_chars = target_tokens * 4
    # Each message as a dict string is roughly len(content) + 60 chars
    msg_overhead = 60
    # Use 50 messages with appropriately sized content
    n_msgs = 50
    content_size = max(10, (target_chars // n_msgs) - msg_overhead)
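    # e.g. target_tokens=100_000 gives 400_000 target chars, i.e. 50 messages of ~7,940 content chars each.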
    return _make_history(n_msgs, content_size=content_size)


# ---------------------------------------------------------------------------
# Detection threshold tests
# ---------------------------------------------------------------------------

class TestSessionHygieneThresholds:
    """Test that the threshold logic correctly identifies large sessions."""

    def test_small_session_below_thresholds(self):
        """A 10-message session should not trigger compression."""
        history = _make_history(10)
        msg_count = len(history)
        approx_tokens = estimate_messages_tokens_rough(history)

        compress_token_threshold = 100_000
        compress_msg_threshold = 200

        needs_compress = (
            approx_tokens >= compress_token_threshold
            or msg_count >= compress_msg_threshold
        )
        assert not needs_compress

    def test_large_message_count_triggers(self):
        """200+ messages should trigger compression even if tokens are low."""
        history = _make_history(250, content_size=10)
        msg_count = len(history)

        compress_msg_threshold = 200
        needs_compress = msg_count >= compress_msg_threshold
        assert needs_compress

    def test_large_token_count_triggers(self):
        """High token count should trigger compression even if message count is low."""
        # 50 messages with huge content to exceed 100K tokens
        history = _make_history(50, content_size=10_000)
        approx_tokens = estimate_messages_tokens_rough(history)

        compress_token_threshold = 100_000
        needs_compress = approx_tokens >= compress_token_threshold
        assert needs_compress

    def test_under_both_thresholds_no_trigger(self):
        """Session under both thresholds should not trigger."""
        history = _make_history(100, content_size=100)
        msg_count = len(history)
        approx_tokens = estimate_messages_tokens_rough(history)

        compress_token_threshold = 100_000
        compress_msg_threshold = 200

        needs_compress = (
            approx_tokens >= compress_token_threshold
            or msg_count >= compress_msg_threshold
        )
        assert not needs_compress

    def test_custom_thresholds(self):
        """Custom thresholds from config should be respected."""
        history = _make_history(60, content_size=100)
        msg_count = len(history)

        # Custom lower threshold
        compress_msg_threshold = 50
        needs_compress = msg_count >= compress_msg_threshold
        assert needs_compress

        # Custom higher threshold
        compress_msg_threshold = 100
        needs_compress = msg_count >= compress_msg_threshold
        assert not needs_compress

    def test_minimum_message_guard(self):
        """Sessions with fewer than 4 messages should never trigger."""
        history = _make_history(3, content_size=100_000)
        # Even with enormous content, < 4 messages should be skipped
        # (the gateway code checks `len(history) >= 4` before evaluating)
        assert len(history) < 4


class TestSessionHygieneWarnThreshold:
    """Test the post-compression warning threshold."""

    def test_warn_when_still_large(self):
        """If compressed result is still above warn_tokens, should warn."""
        # Simulate post-compression tokens
        warn_threshold = 200_000
        post_compress_tokens = 250_000
        assert post_compress_tokens >= warn_threshold

    def test_no_warn_when_under(self):
        """If compressed result is under warn_tokens, no warning."""
        warn_threshold = 200_000
        post_compress_tokens = 150_000
        assert post_compress_tokens < warn_threshold


class TestTokenEstimation:
    """Verify rough token estimation works as expected for hygiene checks."""

    def test_empty_history(self):
        assert estimate_messages_tokens_rough([]) == 0

    def test_proportional_to_content(self):
        small = _make_history(10, content_size=100)
        large = _make_history(10, content_size=10_000)
        assert estimate_messages_tokens_rough(large) > estimate_messages_tokens_rough(small)

    def test_proportional_to_count(self):
        few = _make_history(10, content_size=1000)
        many = _make_history(100, content_size=1000)
        assert estimate_messages_tokens_rough(many) > estimate_messages_tokens_rough(few)

    def test_pathological_session_detected(self):
        """The reported pathological case: 648 messages, ~299K tokens."""
        # Simulate a 648-message session averaging ~460 tokens per message
        history = _make_history(648, content_size=1800)
        tokens = estimate_messages_tokens_rough(history)
        # Should be well above the 100K default threshold
        assert tokens > 100_000
        assert len(history) > 200
tests/gateway/test_title_command.py (new file)
@@ -0,0 +1,207 @@
"""Tests for /title gateway slash command.

Tests the _handle_title_command handler (set/show session titles)
across all gateway messenger platforms.
"""

import os
from unittest.mock import MagicMock, patch

import pytest

from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionSource


def _make_event(text="/title", platform=Platform.TELEGRAM,
                user_id="12345", chat_id="67890"):
    """Build a MessageEvent for testing."""
    source = SessionSource(
        platform=platform,
        user_id=user_id,
        chat_id=chat_id,
        user_name="testuser",
    )
    return MessageEvent(text=text, source=source)


def _make_runner(session_db=None):
    """Create a bare GatewayRunner with a mock session_store and optional session_db."""
    from gateway.run import GatewayRunner
    runner = object.__new__(GatewayRunner)
    runner.adapters = {}
    runner._session_db = session_db

    # Mock session_store that returns a session entry with a known session_id
    mock_session_entry = MagicMock()
    mock_session_entry.session_id = "test_session_123"
    mock_session_entry.session_key = "telegram:12345:67890"
    mock_store = MagicMock()
    mock_store.get_or_create_session.return_value = mock_session_entry
    runner.session_store = mock_store

    return runner


# ---------------------------------------------------------------------------
# _handle_title_command
# ---------------------------------------------------------------------------


class TestHandleTitleCommand:
    """Tests for GatewayRunner._handle_title_command."""

    @pytest.mark.asyncio
    async def test_set_title(self, tmp_path):
        """Setting a title returns confirmation."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title My Research Project")
        result = await runner._handle_title_command(event)
        assert "My Research Project" in result
        assert "✏️" in result

        # Verify in DB
        assert db.get_session_title("test_session_123") == "My Research Project"
        db.close()

    @pytest.mark.asyncio
    async def test_show_title_when_set(self, tmp_path):
        """Showing title when one is set returns the title."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")
        db.set_session_title("test_session_123", "Existing Title")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title")
        result = await runner._handle_title_command(event)
        assert "Existing Title" in result
        assert "📌" in result
        db.close()

    @pytest.mark.asyncio
    async def test_show_title_when_not_set(self, tmp_path):
        """Showing title when none is set returns usage hint."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title")
        result = await runner._handle_title_command(event)
        assert "No title set" in result
        assert "/title" in result
        db.close()

    @pytest.mark.asyncio
    async def test_title_conflict(self, tmp_path):
        """Setting a title already used by another session returns error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("other_session", "telegram")
        db.set_session_title("other_session", "Taken Title")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title Taken Title")
        result = await runner._handle_title_command(event)
        assert "already in use" in result
        assert "⚠️" in result
        db.close()

    @pytest.mark.asyncio
    async def test_no_session_db(self):
        """Returns error when session database is not available."""
        runner = _make_runner(session_db=None)
        event = _make_event(text="/title My Title")
        result = await runner._handle_title_command(event)
        assert "not available" in result

    @pytest.mark.asyncio
    async def test_title_too_long(self, tmp_path):
        """Setting a title that exceeds max length returns error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        long_title = "A" * 150
        event = _make_event(text=f"/title {long_title}")
        result = await runner._handle_title_command(event)
        assert "too long" in result
        assert "⚠️" in result
        db.close()

    @pytest.mark.asyncio
    async def test_title_control_chars_sanitized(self, tmp_path):
        """Control characters are stripped and the sanitized title is stored."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title hello\x00world")
        result = await runner._handle_title_command(event)
        assert "helloworld" in result
        assert db.get_session_title("test_session_123") == "helloworld"
        db.close()

    @pytest.mark.asyncio
    async def test_title_only_control_chars(self, tmp_path):
        """A title made up of only control chars returns an empty-title error."""
        from hermes_state import SessionDB
        db = SessionDB(db_path=tmp_path / "state.db")
        db.create_session("test_session_123", "telegram")

        runner = _make_runner(session_db=db)
        event = _make_event(text="/title \x00\x01\x02")
        result = await runner._handle_title_command(event)
        assert "empty after cleanup" in result
        db.close()

    @pytest.mark.asyncio
    async def test_works_across_platforms(self, tmp_path):
        """The /title command also works on non-Telegram platforms (exercised here with Discord)."""
        from hermes_state import SessionDB
        for platform in [Platform.DISCORD, Platform.TELEGRAM]:
            db = SessionDB(db_path=tmp_path / f"state_{platform.value}.db")
            db.create_session("test_session_123", platform.value)

            runner = _make_runner(session_db=db)
            event = _make_event(text="/title Cross-Platform Test", platform=platform)
            result = await runner._handle_title_command(event)
            assert "Cross-Platform Test" in result
            assert db.get_session_title("test_session_123") == "Cross-Platform Test"
            db.close()


# ---------------------------------------------------------------------------
# /title in help and known_commands
# ---------------------------------------------------------------------------


class TestTitleInHelp:
    """Verify /title appears in help text and known commands."""

    @pytest.mark.asyncio
    async def test_title_in_help_output(self):
        """The /help output includes /title."""
        runner = _make_runner()
        event = _make_event(text="/help")
        # Need hooks for help command
        from gateway.hooks import HookRegistry
        runner.hooks = HookRegistry()
        result = await runner._handle_help_command(event)
        assert "/title" in result

    def test_title_is_known_command(self):
        """The /title command is in the _known_commands set."""
        from gateway.run import GatewayRunner
        import inspect
        source = inspect.getsource(GatewayRunner._handle_message)
        assert '"title"' in source
tests/hermes_cli/test_commands.py (new file)
@@ -0,0 +1,145 @@
"""Tests for shared slash command definitions and autocomplete."""

from prompt_toolkit.completion import CompleteEvent
from prompt_toolkit.document import Document

from hermes_cli.commands import COMMANDS, SlashCommandCompleter


# All commands that must be present in the shared COMMANDS dict.
EXPECTED_COMMANDS = {
    "/help", "/tools", "/toolsets", "/model", "/provider", "/prompt",
    "/personality", "/clear", "/history", "/new", "/reset", "/retry",
    "/undo", "/save", "/config", "/cron", "/skills", "/platforms",
    "/verbose", "/compress", "/title", "/usage", "/insights", "/paste",
    "/reload-mcp", "/quit",
}


def _completions(completer: SlashCommandCompleter, text: str):
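    # completion_requested=True simulates the user explicitly pressing Tab.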
    return list(
        completer.get_completions(
            Document(text=text),
            CompleteEvent(completion_requested=True),
        )
    )


class TestCommands:
    def test_shared_commands_include_cli_specific_entries(self):
        """Entries that previously only existed in cli.py are now in the shared dict."""
        assert COMMANDS["/paste"] == "Check clipboard for an image and attach it"
        assert COMMANDS["/reload-mcp"] == "Reload MCP servers from config.yaml"

    def test_all_expected_commands_present(self):
        """Regression guard — every known command must appear in the shared dict."""
        assert set(COMMANDS.keys()) == EXPECTED_COMMANDS

    def test_every_command_has_nonempty_description(self):
        for cmd, desc in COMMANDS.items():
            assert isinstance(desc, str) and len(desc) > 0, f"{cmd} has empty description"


class TestSlashCommandCompleter:
    # -- basic prefix completion -----------------------------------------

    def test_builtin_prefix_completion_uses_shared_registry(self):
        completions = _completions(SlashCommandCompleter(), "/re")
        texts = {item.text for item in completions}

        assert "reset" in texts
        assert "retry" in texts
        assert "reload-mcp" in texts

    def test_builtin_completion_display_meta_shows_description(self):
        completions = _completions(SlashCommandCompleter(), "/help")
        assert len(completions) == 1
        assert completions[0].display_meta_text == "Show this help message"

    # -- exact-match trailing space --------------------------------------

    def test_exact_match_completion_adds_trailing_space(self):
        completions = _completions(SlashCommandCompleter(), "/help")

        assert [item.text for item in completions] == ["help "]

    def test_partial_match_does_not_add_trailing_space(self):
        completions = _completions(SlashCommandCompleter(), "/hel")

        assert [item.text for item in completions] == ["help"]

    # -- non-slash input returns nothing ---------------------------------

    def test_no_completions_for_non_slash_input(self):
        assert _completions(SlashCommandCompleter(), "help") == []

    def test_no_completions_for_empty_input(self):
        assert _completions(SlashCommandCompleter(), "") == []

    # -- skill commands via provider ------------------------------------

    def test_skill_commands_are_completed_from_provider(self):
        completer = SlashCommandCompleter(
            skill_commands_provider=lambda: {
                "/gif-search": {"description": "Search for GIFs across providers"},
            }
        )

        completions = _completions(completer, "/gif")

        assert len(completions) == 1
        assert completions[0].text == "gif-search"
        assert completions[0].display_text == "/gif-search"
        assert completions[0].display_meta_text == "⚡ Search for GIFs across providers"

    def test_skill_exact_match_adds_trailing_space(self):
        completer = SlashCommandCompleter(
            skill_commands_provider=lambda: {
                "/gif-search": {"description": "Search for GIFs"},
            }
        )

        completions = _completions(completer, "/gif-search")

        assert len(completions) == 1
        assert completions[0].text == "gif-search "

    def test_no_skill_provider_means_no_skill_completions(self):
        """Default (None) provider should not blow up or add completions."""
        completer = SlashCommandCompleter()
        completions = _completions(completer, "/gif")
        # /gif doesn't match any builtin command
        assert completions == []

    def test_skill_provider_exception_is_swallowed(self):
        """A broken provider should not crash autocomplete."""
        completer = SlashCommandCompleter(
            skill_commands_provider=lambda: (_ for _ in ()).throw(RuntimeError("boom")),
        )
        # Should return builtin matches only, no crash
        completions = _completions(completer, "/he")
        texts = {item.text for item in completions}
        assert "help" in texts

    def test_skill_description_truncated_at_50_chars(self):
        long_desc = "A" * 80
        completer = SlashCommandCompleter(
            skill_commands_provider=lambda: {
                "/long-skill": {"description": long_desc},
            }
        )
        completions = _completions(completer, "/long")
        assert len(completions) == 1
        meta = completions[0].display_meta_text
        # "⚡ " prefix + 50 chars + "..."
        assert meta == f"⚡ {'A' * 50}..."

    def test_skill_missing_description_uses_fallback(self):
        completer = SlashCommandCompleter(
            skill_commands_provider=lambda: {
                "/no-desc": {},
            }
        )
        completions = _completions(completer, "/no-desc")
        assert len(completions) == 1
        assert "Skill command" in completions[0].display_meta_text
tests/hermes_cli/test_doctor.py (new file)
@@ -0,0 +1,17 @@
"""Tests for hermes doctor helpers."""

from hermes_cli.doctor import _has_provider_env_config


class TestProviderEnvDetection:
    def test_detects_openai_api_key(self):
        content = "OPENAI_BASE_URL=http://localhost:1234/v1\nOPENAI_API_KEY=sk-test-key\n"
        assert _has_provider_env_config(content)

    def test_detects_custom_endpoint_without_openrouter_key(self):
        content = "OPENAI_BASE_URL=http://localhost:8080/v1\n"
        assert _has_provider_env_config(content)

    def test_returns_false_when_no_provider_settings(self):
        content = "TERMINAL_ENV=local\n"
        assert not _has_provider_env_config(content)
220
tests/hermes_cli/test_model_validation.py
Normal file
220
tests/hermes_cli/test_model_validation.py
Normal file
@@ -0,0 +1,220 @@
"""Tests for provider-aware `/model` validation in hermes_cli.models."""

from unittest.mock import patch

from hermes_cli.models import (
    curated_models_for_provider,
    fetch_api_models,
    normalize_provider,
    parse_model_input,
    provider_model_ids,
    validate_requested_model,
)


# -- helpers -----------------------------------------------------------------

FAKE_API_MODELS = [
    "anthropic/claude-opus-4.6",
    "anthropic/claude-sonnet-4.5",
    "openai/gpt-5.4-pro",
    "openai/gpt-5.4",
    "google/gemini-3-pro-preview",
]


def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw):
    """Shortcut: call validate_requested_model with mocked API."""
    with patch("hermes_cli.models.fetch_api_models", return_value=api_models):
        return validate_requested_model(model, provider, **kw)


# -- parse_model_input -------------------------------------------------------

class TestParseModelInput:
    def test_plain_model_keeps_current_provider(self):
        provider, model = parse_model_input("anthropic/claude-sonnet-4.5", "openrouter")
        assert provider == "openrouter"
        assert model == "anthropic/claude-sonnet-4.5"

    def test_provider_colon_model_switches_provider(self):
        provider, model = parse_model_input("openrouter:anthropic/claude-sonnet-4.5", "nous")
        assert provider == "openrouter"
        assert model == "anthropic/claude-sonnet-4.5"

    def test_provider_alias_resolved(self):
        provider, model = parse_model_input("glm:glm-5", "openrouter")
        assert provider == "zai"
        assert model == "glm-5"

    def test_no_slash_no_colon_keeps_provider(self):
        provider, model = parse_model_input("gpt-5.4", "openrouter")
        assert provider == "openrouter"
        assert model == "gpt-5.4"

    def test_nous_provider_switch(self):
        provider, model = parse_model_input("nous:hermes-3", "openrouter")
        assert provider == "nous"
        assert model == "hermes-3"

    def test_empty_model_after_colon_keeps_current(self):
        provider, model = parse_model_input("openrouter:", "nous")
        assert provider == "nous"
        assert model == "openrouter:"

    def test_colon_at_start_keeps_current(self):
        provider, model = parse_model_input(":something", "openrouter")
        assert provider == "openrouter"
        assert model == ":something"

    def test_unknown_prefix_colon_not_treated_as_provider(self):
        """Colons are only provider delimiters if the left side is a known provider."""
        provider, model = parse_model_input("anthropic/claude-3.5-sonnet:beta", "openrouter")
        assert provider == "openrouter"
        assert model == "anthropic/claude-3.5-sonnet:beta"

    def test_http_url_not_treated_as_provider(self):
        provider, model = parse_model_input("http://localhost:8080/model", "openrouter")
        assert provider == "openrouter"
        assert model == "http://localhost:8080/model"
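
# A sketch of the parsing rule these tests pin down (assumed implementation;
# the real one is hermes_cli.models.parse_model_input): a colon only acts as a
# provider delimiter when the text before it is a known provider alias and a
# non-empty model name follows. The alias map below is illustrative.
KNOWN_PROVIDER_ALIASES = {"openrouter": "openrouter", "nous": "nous", "glm": "zai", "zai": "zai"}

def parse_model_input_sketch(raw: str, current_provider: str) -> tuple[str, str]:
    prefix, sep, rest = raw.partition(":")
    if sep and rest and prefix.lower() in KNOWN_PROVIDER_ALIASES:
        return KNOWN_PROVIDER_ALIASES[prefix.lower()], rest
    # Unknown prefix, empty model, leading colon, or URL: keep current provider.
    return current_provider, raw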
# -- curated_models_for_provider ---------------------------------------------

class TestCuratedModelsForProvider:
    def test_openrouter_returns_curated_list(self):
        models = curated_models_for_provider("openrouter")
        assert len(models) > 0
        assert any("claude" in m[0] for m in models)

    def test_zai_returns_glm_models(self):
        models = curated_models_for_provider("zai")
        assert any("glm" in m[0] for m in models)

    def test_unknown_provider_returns_empty(self):
        assert curated_models_for_provider("totally-unknown") == []


# -- normalize_provider ------------------------------------------------------

class TestNormalizeProvider:
    def test_defaults_to_openrouter(self):
        assert normalize_provider(None) == "openrouter"
        assert normalize_provider("") == "openrouter"

    def test_known_aliases(self):
        assert normalize_provider("glm") == "zai"
        assert normalize_provider("kimi") == "kimi-coding"
        assert normalize_provider("moonshot") == "kimi-coding"

    def test_case_insensitive(self):
        assert normalize_provider("OpenRouter") == "openrouter"


# -- provider_model_ids ------------------------------------------------------

class TestProviderModelIds:
    def test_openrouter_returns_curated_list(self):
        ids = provider_model_ids("openrouter")
        assert len(ids) > 0
        assert all("/" in mid for mid in ids)

    def test_unknown_provider_returns_empty(self):
        assert provider_model_ids("some-unknown-provider") == []

    def test_zai_returns_glm_models(self):
        assert "glm-5" in provider_model_ids("zai")


# -- fetch_api_models --------------------------------------------------------

class TestFetchApiModels:
    def test_returns_none_when_no_base_url(self):
        assert fetch_api_models("key", None) is None

    def test_returns_none_on_network_error(self):
        with patch("hermes_cli.models.urllib.request.urlopen", side_effect=Exception("timeout")):
            assert fetch_api_models("key", "https://example.com/v1") is None
# -- validate — format checks -----------------------------------------------

class TestValidateFormatChecks:
    def test_empty_model_rejected(self):
        result = _validate("")
        assert result["accepted"] is False
        assert "empty" in result["message"]

    def test_whitespace_only_rejected(self):
        result = _validate(" ")
        assert result["accepted"] is False

    def test_model_with_spaces_rejected(self):
        result = _validate("anthropic/ claude-opus")
        assert result["accepted"] is False

    def test_no_slash_model_still_probes_api(self):
        result = _validate("gpt-5.4", api_models=["gpt-5.4", "gpt-5.4-pro"])
        assert result["accepted"] is True
        assert result["persist"] is True

    def test_no_slash_model_rejected_if_not_in_api(self):
        result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"])
        assert result["accepted"] is False


# -- validate — API found ----------------------------------------------------

class TestValidateApiFound:
    def test_model_found_in_api(self):
        result = _validate("anthropic/claude-opus-4.6")
        assert result["accepted"] is True
        assert result["persist"] is True
        assert result["recognized"] is True

    def test_model_found_for_custom_endpoint(self):
        result = _validate(
            "my-model", provider="openrouter",
            api_models=["my-model"], base_url="http://localhost:11434/v1",
        )
        assert result["accepted"] is True
        assert result["persist"] is True


# -- validate — API not found ------------------------------------------------

class TestValidateApiNotFound:
    def test_model_not_in_api_rejected(self):
        result = _validate("anthropic/claude-nonexistent")
        assert result["accepted"] is False
        assert "not a valid model" in result["message"]

    def test_rejection_includes_suggestions(self):
        result = _validate("anthropic/claude-opus-4.5")
        assert result["accepted"] is False
        assert "Did you mean" in result["message"]


# -- validate — API unreachable (fallback) -----------------------------------

class TestValidateApiFallback:
    def test_known_catalog_model_accepted_when_api_down(self):
        result = _validate("anthropic/claude-opus-4.6", api_models=None)
        assert result["accepted"] is True
        assert result["persist"] is True

    def test_unknown_model_session_only_when_api_down(self):
        result = _validate("anthropic/claude-next-gen", api_models=None)
        assert result["accepted"] is True
        assert result["persist"] is False
        assert "session only" in result["message"].lower()

    def test_zai_known_model_accepted_when_api_down(self):
        result = _validate("glm-5", provider="zai", api_models=None)
        assert result["accepted"] is True
        assert result["persist"] is True

    def test_unknown_provider_session_only_when_api_down(self):
        result = _validate("some-model", provider="totally-unknown", api_models=None)
        assert result["accepted"] is True
        assert result["persist"] is False
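
# Sketch of the decision table these validate() tests pin down (assumed shape;
# the real logic lives in hermes_cli.models.validate_requested_model):
#   * API reachable, model listed          -> accept and persist to config
#   * API reachable, model not listed      -> reject, offer suggestions
#   * API unreachable, model in catalog    -> accept and persist
#   * API unreachable, unknown model       -> accept for this session only
def validate_sketch(model: str, provider: str, api_models, catalog) -> dict:
    if not model.strip() or " " in model:
        return {"accepted": False, "persist": False, "message": "empty or malformed model id"}
    if api_models is not None:
        if model in api_models:
            return {"accepted": True, "persist": True, "recognized": True, "message": "ok"}
        return {"accepted": False, "persist": False, "message": f"{model} is not a valid model"}
    if model in catalog:
        return {"accepted": True, "persist": True, "message": "API unreachable; known model"}
    return {"accepted": True, "persist": False, "message": "API unreachable; session only"}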
19
tests/hermes_cli/test_tools_config.py
Normal file
@@ -0,0 +1,19 @@
"""Tests for hermes_cli.tools_config platform tool persistence."""

from hermes_cli.tools_config import _get_platform_tools


def test_get_platform_tools_uses_default_when_platform_not_configured():
    config = {}

    enabled = _get_platform_tools(config, "cli")

    assert enabled


def test_get_platform_tools_preserves_explicit_empty_selection():
    config = {"platform_toolsets": {"cli": []}}

    enabled = _get_platform_tools(config, "cli")

    assert enabled == set()
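
# Sketch of the distinction the two tests above rely on (assumed implementation;
# the real helper is hermes_cli.tools_config._get_platform_tools): a missing
# platform entry falls back to defaults, while an explicit empty list is
# preserved as "no tools enabled". DEFAULT_TOOLS is a placeholder name here.
DEFAULT_TOOLS = {"terminal", "files"}

def _get_platform_tools_sketch(config: dict, platform: str) -> set:
    selection = config.get("platform_toolsets", {}).get(platform)
    if selection is None:   # not configured -> use the defaults
        return set(DEFAULT_TOOLS)
    return set(selection)   # configured (possibly empty) -> honor it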
428
tests/test_api_key_providers.py
Normal file
@@ -0,0 +1,428 @@
"""Tests for API-key provider support (z.ai/GLM, Kimi, MiniMax)."""

import os
import sys
import types

import pytest

# Ensure dotenv doesn't interfere
if "dotenv" not in sys.modules:
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    sys.modules["dotenv"] = fake_dotenv

from hermes_cli.auth import (
    PROVIDER_REGISTRY,
    ProviderConfig,
    resolve_provider,
    get_api_key_provider_status,
    resolve_api_key_provider_credentials,
    get_auth_status,
    AuthError,
    KIMI_CODE_BASE_URL,
    _resolve_kimi_base_url,
)


# =============================================================================
# Provider Registry tests
# =============================================================================

class TestProviderRegistry:
    """Test that new providers are correctly registered."""

    @pytest.mark.parametrize("provider_id,name,auth_type", [
        ("zai", "Z.AI / GLM", "api_key"),
        ("kimi-coding", "Kimi / Moonshot", "api_key"),
        ("minimax", "MiniMax", "api_key"),
        ("minimax-cn", "MiniMax (China)", "api_key"),
    ])
    def test_provider_registered(self, provider_id, name, auth_type):
        assert provider_id in PROVIDER_REGISTRY
        pconfig = PROVIDER_REGISTRY[provider_id]
        assert pconfig.name == name
        assert pconfig.auth_type == auth_type
        assert pconfig.inference_base_url  # must have a default base URL

    def test_zai_env_vars(self):
        pconfig = PROVIDER_REGISTRY["zai"]
        assert pconfig.api_key_env_vars == ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY")
        assert pconfig.base_url_env_var == "GLM_BASE_URL"

    def test_kimi_env_vars(self):
        pconfig = PROVIDER_REGISTRY["kimi-coding"]
        assert pconfig.api_key_env_vars == ("KIMI_API_KEY",)
        assert pconfig.base_url_env_var == "KIMI_BASE_URL"

    def test_minimax_env_vars(self):
        pconfig = PROVIDER_REGISTRY["minimax"]
        assert pconfig.api_key_env_vars == ("MINIMAX_API_KEY",)
        assert pconfig.base_url_env_var == "MINIMAX_BASE_URL"

    def test_minimax_cn_env_vars(self):
        pconfig = PROVIDER_REGISTRY["minimax-cn"]
        assert pconfig.api_key_env_vars == ("MINIMAX_CN_API_KEY",)
        assert pconfig.base_url_env_var == "MINIMAX_CN_BASE_URL"

    def test_base_urls(self):
        assert PROVIDER_REGISTRY["zai"].inference_base_url == "https://api.z.ai/api/paas/v4"
        assert PROVIDER_REGISTRY["kimi-coding"].inference_base_url == "https://api.moonshot.ai/v1"
        assert PROVIDER_REGISTRY["minimax"].inference_base_url == "https://api.minimax.io/v1"
        assert PROVIDER_REGISTRY["minimax-cn"].inference_base_url == "https://api.minimaxi.com/v1"

    def test_oauth_providers_unchanged(self):
        """Ensure we didn't break the existing OAuth providers."""
        assert "nous" in PROVIDER_REGISTRY
        assert PROVIDER_REGISTRY["nous"].auth_type == "oauth_device_code"
        assert "openai-codex" in PROVIDER_REGISTRY
        assert PROVIDER_REGISTRY["openai-codex"].auth_type == "oauth_external"


# =============================================================================
# Provider Resolution tests
# =============================================================================

PROVIDER_ENV_VARS = (
    "OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
    "GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
    "KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
    "OPENAI_BASE_URL",
)


@pytest.fixture(autouse=True)
def _clear_provider_env(monkeypatch):
    for key in PROVIDER_ENV_VARS:
        monkeypatch.delenv(key, raising=False)


class TestResolveProvider:
    """Test resolve_provider() with new providers."""

    def test_explicit_zai(self):
        assert resolve_provider("zai") == "zai"

    def test_explicit_kimi_coding(self):
        assert resolve_provider("kimi-coding") == "kimi-coding"

    def test_explicit_minimax(self):
        assert resolve_provider("minimax") == "minimax"

    def test_explicit_minimax_cn(self):
        assert resolve_provider("minimax-cn") == "minimax-cn"

    def test_alias_glm(self):
        assert resolve_provider("glm") == "zai"

    def test_alias_z_ai(self):
        assert resolve_provider("z-ai") == "zai"

    def test_alias_zhipu(self):
        assert resolve_provider("zhipu") == "zai"

    def test_alias_kimi(self):
        assert resolve_provider("kimi") == "kimi-coding"

    def test_alias_moonshot(self):
        assert resolve_provider("moonshot") == "kimi-coding"

    def test_alias_minimax_underscore(self):
        assert resolve_provider("minimax_cn") == "minimax-cn"

    def test_alias_case_insensitive(self):
        assert resolve_provider("GLM") == "zai"
        assert resolve_provider("Z-AI") == "zai"
        assert resolve_provider("Kimi") == "kimi-coding"

    def test_unknown_provider_raises(self):
        with pytest.raises(AuthError):
            resolve_provider("nonexistent-provider-xyz")

    def test_auto_detects_glm_key(self, monkeypatch):
        monkeypatch.setenv("GLM_API_KEY", "test-glm-key")
        assert resolve_provider("auto") == "zai"

    def test_auto_detects_zai_key(self, monkeypatch):
        monkeypatch.setenv("ZAI_API_KEY", "test-zai-key")
        assert resolve_provider("auto") == "zai"

    def test_auto_detects_z_ai_key(self, monkeypatch):
        monkeypatch.setenv("Z_AI_API_KEY", "test-z-ai-key")
        assert resolve_provider("auto") == "zai"

    def test_auto_detects_kimi_key(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "test-kimi-key")
        assert resolve_provider("auto") == "kimi-coding"

    def test_auto_detects_minimax_key(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_API_KEY", "test-mm-key")
        assert resolve_provider("auto") == "minimax"

    def test_auto_detects_minimax_cn_key(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_CN_API_KEY", "test-mm-cn-key")
        assert resolve_provider("auto") == "minimax-cn"

    def test_openrouter_takes_priority_over_glm(self, monkeypatch):
        """OpenRouter API key should win over GLM in auto-detection."""
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        assert resolve_provider("auto") == "openrouter"
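
# Sketch of the auto-detection order implied by these tests (an assumption:
# only "OpenRouter wins over GLM" is pinned down above; the rest of the
# ordering and the fallback are illustrative, not the real resolve_provider):
AUTO_DETECT_ORDER = [
    ("OPENROUTER_API_KEY", "openrouter"),
    ("GLM_API_KEY", "zai"), ("ZAI_API_KEY", "zai"), ("Z_AI_API_KEY", "zai"),
    ("KIMI_API_KEY", "kimi-coding"),
    ("MINIMAX_API_KEY", "minimax"), ("MINIMAX_CN_API_KEY", "minimax-cn"),
]

def resolve_auto_sketch() -> str:
    for env_var, provider in AUTO_DETECT_ORDER:
        if os.environ.get(env_var):
            return provider
    return "openrouter"  # placeholder fallback for the sketch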
# =============================================================================
# API Key Provider Status tests
# =============================================================================

class TestApiKeyProviderStatus:

    def test_unconfigured_provider(self):
        status = get_api_key_provider_status("zai")
        assert status["configured"] is False
        assert status["logged_in"] is False

    def test_configured_provider(self, monkeypatch):
        monkeypatch.setenv("GLM_API_KEY", "test-key-123")
        status = get_api_key_provider_status("zai")
        assert status["configured"] is True
        assert status["logged_in"] is True
        assert status["key_source"] == "GLM_API_KEY"
        assert "z.ai" in status["base_url"].lower() or "api.z.ai" in status["base_url"]

    def test_fallback_env_var(self, monkeypatch):
        """ZAI_API_KEY should work when GLM_API_KEY is not set."""
        monkeypatch.setenv("ZAI_API_KEY", "zai-fallback-key")
        status = get_api_key_provider_status("zai")
        assert status["configured"] is True
        assert status["key_source"] == "ZAI_API_KEY"

    def test_custom_base_url(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "kimi-key")
        monkeypatch.setenv("KIMI_BASE_URL", "https://custom.kimi.example/v1")
        status = get_api_key_provider_status("kimi-coding")
        assert status["base_url"] == "https://custom.kimi.example/v1"

    def test_get_auth_status_dispatches_to_api_key(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_API_KEY", "mm-key")
        status = get_auth_status("minimax")
        assert status["configured"] is True
        assert status["provider"] == "minimax"

    def test_non_api_key_provider(self):
        status = get_api_key_provider_status("nous")
        assert status["configured"] is False


# =============================================================================
# Credential Resolution tests
# =============================================================================

class TestResolveApiKeyProviderCredentials:

    def test_resolve_zai_with_key(self, monkeypatch):
        monkeypatch.setenv("GLM_API_KEY", "glm-secret-key")
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["provider"] == "zai"
        assert creds["api_key"] == "glm-secret-key"
        assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
        assert creds["source"] == "GLM_API_KEY"

    def test_resolve_kimi_with_key(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "kimi-secret-key")
        creds = resolve_api_key_provider_credentials("kimi-coding")
        assert creds["provider"] == "kimi-coding"
        assert creds["api_key"] == "kimi-secret-key"
        assert creds["base_url"] == "https://api.moonshot.ai/v1"

    def test_resolve_minimax_with_key(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_API_KEY", "mm-secret-key")
        creds = resolve_api_key_provider_credentials("minimax")
        assert creds["provider"] == "minimax"
        assert creds["api_key"] == "mm-secret-key"
        assert creds["base_url"] == "https://api.minimax.io/v1"

    def test_resolve_minimax_cn_with_key(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_CN_API_KEY", "mmcn-secret-key")
        creds = resolve_api_key_provider_credentials("minimax-cn")
        assert creds["provider"] == "minimax-cn"
        assert creds["api_key"] == "mmcn-secret-key"
        assert creds["base_url"] == "https://api.minimaxi.com/v1"

    def test_resolve_with_custom_base_url(self, monkeypatch):
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        monkeypatch.setenv("GLM_BASE_URL", "https://custom.glm.example/v4")
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["base_url"] == "https://custom.glm.example/v4"

    def test_resolve_without_key_returns_empty(self):
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["api_key"] == ""
        assert creds["source"] == "default"

    def test_resolve_invalid_provider_raises(self):
        with pytest.raises(AuthError):
            resolve_api_key_provider_credentials("nous")

    def test_glm_key_priority(self, monkeypatch):
        """GLM_API_KEY takes priority over ZAI_API_KEY."""
        monkeypatch.setenv("GLM_API_KEY", "primary")
        monkeypatch.setenv("ZAI_API_KEY", "secondary")
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["api_key"] == "primary"
        assert creds["source"] == "GLM_API_KEY"

    def test_zai_key_fallback(self, monkeypatch):
        """ZAI_API_KEY used when GLM_API_KEY not set."""
        monkeypatch.setenv("ZAI_API_KEY", "secondary")
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["api_key"] == "secondary"
        assert creds["source"] == "ZAI_API_KEY"
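
# Sketch of the env-var fallback these credential tests exercise (assumed
# implementation; the real one is in hermes_cli.auth): walk the provider's
# api_key_env_vars in registry order and let the first non-empty value win.
def resolve_credentials_sketch(provider_id: str) -> dict:
    pconfig = PROVIDER_REGISTRY[provider_id]
    api_key, source = "", "default"
    for env_var in pconfig.api_key_env_vars:
        value = os.environ.get(env_var, "")
        if value:
            api_key, source = value, env_var
            break
    base_url = os.environ.get(pconfig.base_url_env_var, "") or pconfig.inference_base_url
    return {"provider": provider_id, "api_key": api_key, "base_url": base_url, "source": source}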
# =============================================================================
# Runtime Provider Resolution tests
# =============================================================================

class TestRuntimeProviderResolution:

    def test_runtime_zai(self, monkeypatch):
        monkeypatch.setenv("GLM_API_KEY", "glm-key")
        from hermes_cli.runtime_provider import resolve_runtime_provider
        result = resolve_runtime_provider(requested="zai")
        assert result["provider"] == "zai"
        assert result["api_mode"] == "chat_completions"
        assert result["api_key"] == "glm-key"
        assert "z.ai" in result["base_url"] or "api.z.ai" in result["base_url"]

    def test_runtime_kimi(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "kimi-key")
        from hermes_cli.runtime_provider import resolve_runtime_provider
        result = resolve_runtime_provider(requested="kimi-coding")
        assert result["provider"] == "kimi-coding"
        assert result["api_mode"] == "chat_completions"
        assert result["api_key"] == "kimi-key"

    def test_runtime_minimax(self, monkeypatch):
        monkeypatch.setenv("MINIMAX_API_KEY", "mm-key")
        from hermes_cli.runtime_provider import resolve_runtime_provider
        result = resolve_runtime_provider(requested="minimax")
        assert result["provider"] == "minimax"
        assert result["api_key"] == "mm-key"

    def test_runtime_auto_detects_api_key_provider(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "auto-kimi-key")
        from hermes_cli.runtime_provider import resolve_runtime_provider
        result = resolve_runtime_provider(requested="auto")
        assert result["provider"] == "kimi-coding"
        assert result["api_key"] == "auto-kimi-key"


# =============================================================================
# _has_any_provider_configured tests
# =============================================================================

class TestHasAnyProviderConfigured:

    def test_glm_key_counts(self, monkeypatch, tmp_path):
        from hermes_cli import config as config_module
        monkeypatch.setenv("GLM_API_KEY", "test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
        monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
        from hermes_cli.main import _has_any_provider_configured
        assert _has_any_provider_configured() is True

    def test_minimax_key_counts(self, monkeypatch, tmp_path):
        from hermes_cli import config as config_module
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
        monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
        from hermes_cli.main import _has_any_provider_configured
        assert _has_any_provider_configured() is True


# =============================================================================
# Kimi Code auto-detection tests
# =============================================================================

MOONSHOT_DEFAULT_URL = "https://api.moonshot.ai/v1"


class TestResolveKimiBaseUrl:
    """Test _resolve_kimi_base_url() helper for key-prefix auto-detection."""

    def test_sk_kimi_prefix_routes_to_kimi_code(self):
        url = _resolve_kimi_base_url("sk-kimi-abc123", MOONSHOT_DEFAULT_URL, "")
        assert url == KIMI_CODE_BASE_URL

    def test_legacy_key_uses_default(self):
        url = _resolve_kimi_base_url("sk-abc123", MOONSHOT_DEFAULT_URL, "")
        assert url == MOONSHOT_DEFAULT_URL

    def test_empty_key_uses_default(self):
        url = _resolve_kimi_base_url("", MOONSHOT_DEFAULT_URL, "")
        assert url == MOONSHOT_DEFAULT_URL

    def test_env_override_wins_over_sk_kimi(self):
        """KIMI_BASE_URL env var should always take priority."""
        custom = "https://custom.example.com/v1"
        url = _resolve_kimi_base_url("sk-kimi-abc123", MOONSHOT_DEFAULT_URL, custom)
        assert url == custom

    def test_env_override_wins_over_legacy(self):
        custom = "https://custom.example.com/v1"
        url = _resolve_kimi_base_url("sk-abc123", MOONSHOT_DEFAULT_URL, custom)
        assert url == custom
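
# Sketch of the routing rule these tests describe (assumed body; the real
# helper is hermes_cli.auth._resolve_kimi_base_url): an explicit KIMI_BASE_URL
# override always wins; otherwise an "sk-kimi-" key prefix routes to the Kimi
# Code endpoint, and anything else keeps the Moonshot default.
def _resolve_kimi_base_url_sketch(api_key: str, default_url: str, env_override: str) -> str:
    if env_override:
        return env_override
    if api_key.startswith("sk-kimi-"):
        return KIMI_CODE_BASE_URL
    return default_url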
class TestKimiCodeStatusAutoDetect:
    """Test that get_api_key_provider_status auto-detects sk-kimi- keys."""

    def test_sk_kimi_key_gets_kimi_code_url(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-test-key-123")
        status = get_api_key_provider_status("kimi-coding")
        assert status["configured"] is True
        assert status["base_url"] == KIMI_CODE_BASE_URL

    def test_legacy_key_gets_moonshot_url(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-legacy-test-key")
        status = get_api_key_provider_status("kimi-coding")
        assert status["configured"] is True
        assert status["base_url"] == MOONSHOT_DEFAULT_URL

    def test_env_override_wins(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-test-key")
        monkeypatch.setenv("KIMI_BASE_URL", "https://override.example/v1")
        status = get_api_key_provider_status("kimi-coding")
        assert status["base_url"] == "https://override.example/v1"


class TestKimiCodeCredentialAutoDetect:
    """Test that resolve_api_key_provider_credentials auto-detects sk-kimi- keys."""

    def test_sk_kimi_key_gets_kimi_code_url(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-secret-key")
        creds = resolve_api_key_provider_credentials("kimi-coding")
        assert creds["api_key"] == "sk-kimi-secret-key"
        assert creds["base_url"] == KIMI_CODE_BASE_URL

    def test_legacy_key_gets_moonshot_url(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-legacy-secret-key")
        creds = resolve_api_key_provider_credentials("kimi-coding")
        assert creds["api_key"] == "sk-legacy-secret-key"
        assert creds["base_url"] == MOONSHOT_DEFAULT_URL

    def test_env_override_wins(self, monkeypatch):
        monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-secret-key")
        monkeypatch.setenv("KIMI_BASE_URL", "https://override.example/v1")
        creds = resolve_api_key_provider_credentials("kimi-coding")
        assert creds["base_url"] == "https://override.example/v1"

    def test_non_kimi_providers_unaffected(self, monkeypatch):
        """Ensure the auto-detect logic doesn't leak to other providers."""
        monkeypatch.setenv("GLM_API_KEY", "sk-kimi-looks-like-kimi-but-isnt")
        creds = resolve_api_key_provider_credentials("zai")
        assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
@@ -3,17 +3,31 @@ that only manifest at runtime (not in mocked unit tests)."""

 import os
 import sys
-from unittest.mock import patch, MagicMock

 import pytest
+from unittest.mock import patch

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))


-def _make_cli(**kwargs):
+def _make_cli(env_overrides=None, **kwargs):
     """Create a HermesCLI instance with minimal mocking."""
+    import cli as _cli_mod
     from cli import HermesCLI
-    with patch("cli.get_tool_definitions", return_value=[]):
+    _clean_config = {
+        "model": {
+            "default": "anthropic/claude-opus-4.6",
+            "base_url": "https://openrouter.ai/api/v1",
+            "provider": "auto",
+        },
+        "display": {"compact": False, "tool_progress": "all"},
+        "agent": {},
+        "terminal": {"env_type": "local"},
+    }
+    clean_env = {"LLM_MODEL": "", "HERMES_MAX_ITERATIONS": ""}
+    if env_overrides:
+        clean_env.update(env_overrides)
+    with patch("cli.get_tool_definitions", return_value=[]), \
+            patch.dict("os.environ", clean_env, clear=False), \
+            patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}):
         return HermesCLI(**kwargs)


@@ -23,7 +37,7 @@ class TestMaxTurnsResolution:
     def test_default_max_turns_is_integer(self):
         cli = _make_cli()
         assert isinstance(cli.max_turns, int)
-        assert cli.max_turns == 60
+        assert cli.max_turns == 90

     def test_explicit_max_turns_honored(self):
         cli = _make_cli(max_turns=25)
@@ -32,29 +46,17 @@ class TestMaxTurnsResolution:
     def test_none_max_turns_gets_default(self):
         cli = _make_cli(max_turns=None)
         assert isinstance(cli.max_turns, int)
-        assert cli.max_turns == 60
+        assert cli.max_turns == 90

-    def test_env_var_max_turns(self, monkeypatch):
+    def test_env_var_max_turns(self):
         """Env var is used when config file doesn't set max_turns."""
-        monkeypatch.setenv("HERMES_MAX_ITERATIONS", "42")
-        import cli as cli_module
-        original_agent = cli_module.CLI_CONFIG["agent"].get("max_turns")
-        original_root = cli_module.CLI_CONFIG.get("max_turns")
-        cli_module.CLI_CONFIG["agent"]["max_turns"] = None
-        cli_module.CLI_CONFIG.pop("max_turns", None)
-        try:
-            cli_obj = _make_cli()
-            assert cli_obj.max_turns == 42
-        finally:
-            if original_agent is not None:
-                cli_module.CLI_CONFIG["agent"]["max_turns"] = original_agent
-            if original_root is not None:
-                cli_module.CLI_CONFIG["max_turns"] = original_root
+        cli_obj = _make_cli(env_overrides={"HERMES_MAX_ITERATIONS": "42"})
+        assert cli_obj.max_turns == 42

     def test_max_turns_never_none_for_agent(self):
         """The value passed to AIAgent must never be None (causes TypeError in run_conversation)."""
         cli = _make_cli()
-        assert isinstance(cli.max_turns, int) and cli.max_turns == 60
+        assert isinstance(cli.max_turns, int) and cli.max_turns == 90


 class TestVerboseAndToolProgress:
@@ -68,6 +70,38 @@
         assert cli.tool_progress_mode in ("off", "new", "all", "verbose")


+class TestHistoryDisplay:
+    def test_history_numbers_only_visible_messages_and_summarizes_tools(self, capsys):
+        cli = _make_cli()
+        cli.conversation_history = [
+            {"role": "system", "content": "system prompt"},
+            {"role": "user", "content": "Hello"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
+            },
+            {"role": "tool", "content": "tool output 1"},
+            {"role": "tool", "content": "tool output 2"},
+            {"role": "assistant", "content": "All set."},
+            {"role": "user", "content": "A" * 250},
+        ]
+
+        cli.show_history()
+        output = capsys.readouterr().out
+
+        assert "[You #1]" in output
+        assert "[Hermes #2]" in output
+        assert "(requested 2 tool calls)" in output
+        assert "[Tools]" in output
+        assert "(2 tool messages hidden)" in output
+        assert "[Hermes #3]" in output
+        assert "[You #4]" in output
+        assert "[You #5]" not in output
+        assert "A" * 250 in output
+        assert "A" * 250 + "..." not in output
+
+
 class TestProviderResolution:
     def test_api_key_is_string_or_none(self):
         cli = _make_cli()
133
tests/test_cli_model_command.py
Normal file
@@ -0,0 +1,133 @@
"""Regression tests for the `/model` slash command in the interactive CLI."""

from unittest.mock import patch, MagicMock

from cli import HermesCLI


class TestModelCommand:
    def _make_cli(self):
        cli_obj = HermesCLI.__new__(HermesCLI)
        cli_obj.model = "anthropic/claude-opus-4.6"
        cli_obj.agent = object()
        cli_obj.provider = "openrouter"
        cli_obj.requested_provider = "openrouter"
        cli_obj.base_url = "https://openrouter.ai/api/v1"
        cli_obj.api_key = "test-key"
        cli_obj._explicit_api_key = None
        cli_obj._explicit_base_url = None
        return cli_obj

    def test_valid_model_from_api_saved_to_config(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.models.fetch_api_models",
                   return_value=["anthropic/claude-sonnet-4.5", "openai/gpt-5.4"]), \
             patch("cli.save_config_value", return_value=True) as save_mock:
            cli_obj.process_command("/model anthropic/claude-sonnet-4.5")

        output = capsys.readouterr().out
        assert "saved to config" in output
        assert cli_obj.model == "anthropic/claude-sonnet-4.5"
        save_mock.assert_called_once_with("model.default", "anthropic/claude-sonnet-4.5")

    def test_invalid_model_from_api_is_rejected(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.models.fetch_api_models",
                   return_value=["anthropic/claude-opus-4.6"]), \
             patch("cli.save_config_value") as save_mock:
            cli_obj.process_command("/model anthropic/fake-model")

        output = capsys.readouterr().out
        assert "not a valid model" in output
        assert "Model unchanged" in output
        assert cli_obj.model == "anthropic/claude-opus-4.6"
        save_mock.assert_not_called()

    def test_api_unreachable_falls_back_session_only(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.models.fetch_api_models", return_value=None), \
             patch("cli.save_config_value") as save_mock:
            cli_obj.process_command("/model anthropic/claude-sonnet-next")

        output = capsys.readouterr().out
        assert "session only" in output
        assert "will revert on restart" in output
        assert cli_obj.model == "anthropic/claude-sonnet-next"
        save_mock.assert_not_called()

    def test_no_slash_model_probes_api_and_rejects(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.models.fetch_api_models",
                   return_value=["openai/gpt-5.4"]) as fetch_mock, \
             patch("cli.save_config_value") as save_mock:
            cli_obj.process_command("/model gpt-5.4")

        output = capsys.readouterr().out
        assert "not a valid model" in output
        assert "Model unchanged" in output
        assert cli_obj.model == "anthropic/claude-opus-4.6"  # unchanged
        assert cli_obj.agent is not None  # not reset
        save_mock.assert_not_called()

    def test_validation_crash_falls_back_to_save(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.models.validate_requested_model",
                   side_effect=RuntimeError("boom")), \
             patch("cli.save_config_value", return_value=True) as save_mock:
            cli_obj.process_command("/model anthropic/claude-sonnet-4.5")

        output = capsys.readouterr().out
        assert "saved to config" in output
        assert cli_obj.model == "anthropic/claude-sonnet-4.5"
        save_mock.assert_called_once()

    def test_show_model_when_no_argument(self, capsys):
        cli_obj = self._make_cli()
        cli_obj.process_command("/model")

        output = capsys.readouterr().out
        assert "anthropic/claude-opus-4.6" in output
        assert "OpenRouter" in output
        assert "Available models" in output
        assert "provider:model-name" in output

    # -- provider switching tests -------------------------------------------

    def test_provider_colon_model_switches_provider(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.runtime_provider.resolve_runtime_provider", return_value={
                "provider": "zai",
                "api_key": "zai-key",
                "base_url": "https://api.z.ai/api/paas/v4",
            }), \
             patch("hermes_cli.models.fetch_api_models",
                   return_value=["glm-5", "glm-4.7"]), \
             patch("cli.save_config_value", return_value=True) as save_mock:
            cli_obj.process_command("/model zai:glm-5")

        output = capsys.readouterr().out
        assert "glm-5" in output
        assert "provider:" in output.lower() or "Z.AI" in output
        assert cli_obj.model == "glm-5"
        assert cli_obj.provider == "zai"
        assert cli_obj.base_url == "https://api.z.ai/api/paas/v4"
        # Both model and provider should be saved
        assert save_mock.call_count == 2

    def test_provider_switch_fails_on_bad_credentials(self, capsys):
        cli_obj = self._make_cli()

        with patch("hermes_cli.runtime_provider.resolve_runtime_provider",
                   side_effect=Exception("No API key found")):
            cli_obj.process_command("/model nous:hermes-3")

        output = capsys.readouterr().out
        assert "Could not resolve credentials" in output
        assert cli_obj.model == "anthropic/claude-opus-4.6"  # unchanged
        assert cli_obj.provider == "openrouter"  # unchanged
@@ -162,6 +162,128 @@ def test_runtime_resolution_rebuilds_agent_on_routing_change(monkeypatch):
    assert shell.api_mode == "codex_responses"


def test_codex_provider_replaces_incompatible_default_model(monkeypatch):
    """When provider resolves to openai-codex and no model was explicitly
    chosen, the global config default (e.g. anthropic/claude-opus-4.6) must
    be replaced with a Codex-compatible model. Fixes #651."""
    cli = _import_cli()

    monkeypatch.delenv("LLM_MODEL", raising=False)
    monkeypatch.delenv("OPENAI_MODEL", raising=False)

    def _runtime_resolve(**kwargs):
        return {
            "provider": "openai-codex",
            "api_mode": "codex_responses",
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "test-key",
            "source": "env/config",
        }

    monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
    monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
    monkeypatch.setattr(
        "hermes_cli.codex_models.get_codex_model_ids",
        lambda access_token=None: ["gpt-5.2-codex", "gpt-5.1-codex-mini"],
    )

    shell = cli.HermesCLI(compact=True, max_turns=1)

    assert shell._model_is_default is True
    assert shell._ensure_runtime_credentials() is True
    assert shell.provider == "openai-codex"
    assert "anthropic" not in shell.model
    assert "claude" not in shell.model
    assert shell.model == "gpt-5.2-codex"


def test_codex_provider_replaces_incompatible_envvar_model(monkeypatch):
    """Exact scenario from #651: LLM_MODEL is set to a non-Codex model and
    provider resolves to openai-codex. The model must be replaced and a
    warning printed since the user explicitly chose it."""
    cli = _import_cli()

    monkeypatch.setenv("LLM_MODEL", "claude-opus-4-6")
    monkeypatch.delenv("OPENAI_MODEL", raising=False)

    def _runtime_resolve(**kwargs):
        return {
            "provider": "openai-codex",
            "api_mode": "codex_responses",
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "test-key",
            "source": "env/config",
        }

    monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
    monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
    monkeypatch.setattr(
        "hermes_cli.codex_models.get_codex_model_ids",
        lambda access_token=None: ["gpt-5.2-codex", "gpt-5.1-codex-mini"],
    )

    shell = cli.HermesCLI(compact=True, max_turns=1)

    assert shell._model_is_default is False
    assert shell._ensure_runtime_credentials() is True
    assert shell.provider == "openai-codex"
    assert "claude" not in shell.model
    assert shell.model == "gpt-5.2-codex"


def test_codex_provider_preserves_explicit_codex_model(monkeypatch):
    """If the user explicitly passes a Codex-compatible model, it must be
    preserved even when the provider resolves to openai-codex."""
    cli = _import_cli()

    monkeypatch.delenv("LLM_MODEL", raising=False)
    monkeypatch.delenv("OPENAI_MODEL", raising=False)

    def _runtime_resolve(**kwargs):
        return {
            "provider": "openai-codex",
            "api_mode": "codex_responses",
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "test-key",
            "source": "env/config",
        }

    monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
    monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))

    shell = cli.HermesCLI(model="gpt-5.1-codex-mini", compact=True, max_turns=1)

    assert shell._model_is_default is False
    assert shell._ensure_runtime_credentials() is True
    assert shell.model == "gpt-5.1-codex-mini"


def test_codex_provider_strips_provider_prefix_from_model(monkeypatch):
    """openai/gpt-5.3-codex should become gpt-5.3-codex — the Codex
    Responses API does not accept provider-prefixed model slugs."""
    cli = _import_cli()

    monkeypatch.delenv("LLM_MODEL", raising=False)
    monkeypatch.delenv("OPENAI_MODEL", raising=False)

    def _runtime_resolve(**kwargs):
        return {
            "provider": "openai-codex",
            "api_mode": "codex_responses",
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "test-key",
            "source": "env/config",
        }

    monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
    monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))

    shell = cli.HermesCLI(model="openai/gpt-5.3-codex", compact=True, max_turns=1)

    assert shell._ensure_runtime_credentials() is True
    assert shell.model == "gpt-5.3-codex"
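
# Sketch of the model-compatibility rule these four tests pin down (an
# assumption; the real check in HermesCLI._ensure_runtime_credentials is more
# involved, and the "codex" substring heuristic here is purely illustrative):
def pick_codex_model_sketch(model: str, codex_ids: list[str]) -> str:
    if model.startswith("openai/"):   # Codex Responses API wants bare slugs
        model = model.split("/", 1)[1]
    if "codex" in model:              # an explicit Codex model is preserved
        return model
    return codex_ids[0]               # anything else is replaced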
def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
    monkeypatch.setattr(
        "hermes_cli.config.load_config",
@@ -351,6 +351,173 @@ class TestPruneSessions:
    # Schema and WAL mode
    # =========================================================================

# =========================================================================
# Session title
# =========================================================================

class TestSessionTitle:
    def test_set_and_get_title(self, db):
        db.create_session(session_id="s1", source="cli")
        assert db.set_session_title("s1", "My Session") is True

        session = db.get_session("s1")
        assert session["title"] == "My Session"

    def test_set_title_nonexistent_session(self, db):
        assert db.set_session_title("nonexistent", "Title") is False

    def test_title_initially_none(self, db):
        db.create_session(session_id="s1", source="cli")
        session = db.get_session("s1")
        assert session["title"] is None

    def test_update_title(self, db):
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "First Title")
        db.set_session_title("s1", "Updated Title")

        session = db.get_session("s1")
        assert session["title"] == "Updated Title"

    def test_title_in_search_sessions(self, db):
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Debugging Auth")
        db.create_session(session_id="s2", source="cli")

        sessions = db.search_sessions()
        titled = [s for s in sessions if s.get("title") == "Debugging Auth"]
        assert len(titled) == 1
        assert titled[0]["id"] == "s1"

    def test_title_in_export(self, db):
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Export Test")
        db.append_message("s1", role="user", content="Hello")

        export = db.export_session("s1")
        assert export["title"] == "Export Test"

    def test_title_with_special_characters(self, db):
        db.create_session(session_id="s1", source="cli")
        title = "PR #438 — fixing the 'auth' middleware"
        db.set_session_title("s1", title)

        session = db.get_session("s1")
        assert session["title"] == title

    def test_title_empty_string_normalized_to_none(self, db):
        """Empty strings are normalized to None (clearing the title)."""
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "My Title")
        # Setting to empty string should clear the title (normalize to None)
        db.set_session_title("s1", "")

        session = db.get_session("s1")
        assert session["title"] is None

    def test_multiple_empty_titles_no_conflict(self, db):
        """Multiple sessions can have empty-string (normalized to NULL) titles."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.set_session_title("s1", "")
        db.set_session_title("s2", "")
        # Both should be None, no uniqueness conflict
        assert db.get_session("s1")["title"] is None
        assert db.get_session("s2")["title"] is None

    def test_title_survives_end_session(self, db):
        db.create_session(session_id="s1", source="cli")
        db.set_session_title("s1", "Before End")
        db.end_session("s1", end_reason="user_exit")

        session = db.get_session("s1")
        assert session["title"] == "Before End"
        assert session["ended_at"] is not None


class TestSanitizeTitle:
    """Tests for SessionDB.sanitize_title() validation and cleaning."""

    def test_normal_title_unchanged(self):
        assert SessionDB.sanitize_title("My Project") == "My Project"

    def test_strips_whitespace(self):
        assert SessionDB.sanitize_title("  hello world  ") == "hello world"

    def test_collapses_internal_whitespace(self):
        assert SessionDB.sanitize_title("hello   world") == "hello world"

    def test_tabs_and_newlines_collapsed(self):
        assert SessionDB.sanitize_title("hello\t\nworld") == "hello world"

    def test_none_returns_none(self):
        assert SessionDB.sanitize_title(None) is None

    def test_empty_string_returns_none(self):
        assert SessionDB.sanitize_title("") is None

    def test_whitespace_only_returns_none(self):
        assert SessionDB.sanitize_title(" \t\n ") is None

    def test_control_chars_stripped(self):
        # Null byte, bell, backspace, etc.
        assert SessionDB.sanitize_title("hello\x00world") == "helloworld"
        assert SessionDB.sanitize_title("\x07\x08test\x1b") == "test"

    def test_del_char_stripped(self):
        assert SessionDB.sanitize_title("hello\x7fworld") == "helloworld"

    def test_zero_width_chars_stripped(self):
        # Zero-width space (U+200B), zero-width joiner (U+200D)
        assert SessionDB.sanitize_title("hello\u200bworld") == "helloworld"
        assert SessionDB.sanitize_title("hello\u200dworld") == "helloworld"

    def test_rtl_override_stripped(self):
        # Right-to-left override (U+202E) — used in filename spoofing attacks
        assert SessionDB.sanitize_title("hello\u202eworld") == "helloworld"

    def test_bom_stripped(self):
        # Byte order mark (U+FEFF)
        assert SessionDB.sanitize_title("\ufeffhello") == "hello"

    def test_only_control_chars_returns_none(self):
        assert SessionDB.sanitize_title("\x00\x01\x02\u200b\ufeff") is None

    def test_max_length_allowed(self):
        title = "A" * 100
        assert SessionDB.sanitize_title(title) == title

    def test_exceeds_max_length_raises(self):
        title = "A" * 101
        with pytest.raises(ValueError, match="too long"):
            SessionDB.sanitize_title(title)

    def test_unicode_emoji_allowed(self):
        assert SessionDB.sanitize_title("🚀 My Project 🎉") == "🚀 My Project 🎉"

    def test_cjk_characters_allowed(self):
        assert SessionDB.sanitize_title("我的项目") == "我的项目"

    def test_accented_characters_allowed(self):
        assert SessionDB.sanitize_title("Résumé éditing") == "Résumé éditing"

    def test_special_punctuation_allowed(self):
        title = "PR #438 — fixing the 'auth' middleware"
        assert SessionDB.sanitize_title(title) == title

    def test_sanitize_applied_in_set_session_title(self, db):
        """set_session_title applies sanitize_title internally."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "  hello\x00 world  ")
        assert db.get_session("s1")["title"] == "hello world"

    def test_too_long_title_rejected_by_set(self, db):
        """set_session_title raises ValueError for overly long titles."""
        db.create_session("s1", "cli")
        with pytest.raises(ValueError, match="too long"):
            db.set_session_title("s1", "X" * 150)
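
# Sketch of sanitize_title consistent with the tests above (an assumption;
# the real one is a SessionDB static method, and the 100-char limit matches
# the boundary the tests probe):
import re
import unicodedata

def sanitize_title_sketch(title):
    if title is None:
        return None
    # Collapse runs of whitespace (tabs, newlines) into single spaces first.
    cleaned = re.sub(r"\s+", " ", title)
    # Drop control chars (Cc) and invisible format chars (Cf: zero-width, BOM, RTL override).
    cleaned = "".join(ch for ch in cleaned if unicodedata.category(ch) not in ("Cc", "Cf")).strip()
    if not cleaned:
        return None
    if len(cleaned) > 100:
        raise ValueError("title too long (max 100 characters)")
    return cleaned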
class TestSchemaInit:
    def test_wal_mode(self, db):
        cursor = db._conn.execute("PRAGMA journal_mode")
@@ -373,4 +540,297 @@ class TestSchemaInit:
    def test_schema_version(self, db):
        cursor = db._conn.execute("SELECT version FROM schema_version")
        version = cursor.fetchone()[0]
-        assert version == 2
+        assert version == 4

    def test_title_column_exists(self, db):
        """Verify the title column was created in the sessions table."""
        cursor = db._conn.execute("PRAGMA table_info(sessions)")
        columns = {row[1] for row in cursor.fetchall()}
        assert "title" in columns

    def test_migration_from_v2(self, tmp_path):
        """Simulate a v2 database and verify migration adds the title column."""
        import sqlite3

        db_path = tmp_path / "migrate_test.db"
        conn = sqlite3.connect(str(db_path))
        # Create v2 schema (without title column)
        conn.executescript("""
            CREATE TABLE schema_version (version INTEGER NOT NULL);
            INSERT INTO schema_version (version) VALUES (2);

            CREATE TABLE sessions (
                id TEXT PRIMARY KEY,
                source TEXT NOT NULL,
                user_id TEXT,
                model TEXT,
                model_config TEXT,
                system_prompt TEXT,
                parent_session_id TEXT,
                started_at REAL NOT NULL,
                ended_at REAL,
                end_reason TEXT,
                message_count INTEGER DEFAULT 0,
                tool_call_count INTEGER DEFAULT 0,
                input_tokens INTEGER DEFAULT 0,
                output_tokens INTEGER DEFAULT 0
            );

            CREATE TABLE messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT,
                tool_call_id TEXT,
                tool_calls TEXT,
                tool_name TEXT,
                timestamp REAL NOT NULL,
                token_count INTEGER,
                finish_reason TEXT
            );
        """)
        conn.execute(
            "INSERT INTO sessions (id, source, started_at) VALUES (?, ?, ?)",
            ("existing", "cli", 1000.0),
        )
        conn.commit()
        conn.close()

        # Open with SessionDB — should migrate to v4
        migrated_db = SessionDB(db_path=db_path)

        # Verify migration
        cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
        assert cursor.fetchone()[0] == 4

        # Verify title column exists and is NULL for existing sessions
        session = migrated_db.get_session("existing")
        assert session is not None
        assert session["title"] is None

        # Verify we can set a title on the migrated session
        assert migrated_db.set_session_title("existing", "Migrated Title") is True
        session = migrated_db.get_session("existing")
        assert session["title"] == "Migrated Title"

        migrated_db.close()
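
# The migration step this test exercises, sketched (assumed SQL; the real
# migration lives in SessionDB's schema-upgrade path): adding a nullable
# column is backwards-compatible, so existing rows simply read back as NULL,
# and a unique index still allows multiple NULL titles in SQLite.
MIGRATION_V2_TO_V4_SKETCH = [
    "ALTER TABLE sessions ADD COLUMN title TEXT",
    "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title ON sessions(title)",
    "UPDATE schema_version SET version = 4",
]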
class TestTitleUniqueness:
    """Tests for unique title enforcement and title-based lookups."""

    def test_duplicate_title_raises(self, db):
        """Setting a title already used by another session raises ValueError."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        db.set_session_title("s1", "my project")
        with pytest.raises(ValueError, match="already in use"):
            db.set_session_title("s2", "my project")

    def test_same_session_can_keep_title(self, db):
        """A session can re-set its own title without error."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        # Should not raise — it's the same session
        assert db.set_session_title("s1", "my project") is True

    def test_null_titles_not_unique(self, db):
        """Multiple sessions can have NULL titles (no constraint violation)."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        # Both have NULL titles — no error
        assert db.get_session("s1")["title"] is None
        assert db.get_session("s2")["title"] is None

    def test_get_session_by_title(self, db):
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        result = db.get_session_by_title("refactoring auth")
        assert result is not None
        assert result["id"] == "s1"

    def test_get_session_by_title_not_found(self, db):
        assert db.get_session_by_title("nonexistent") is None

    def test_get_session_title(self, db):
        db.create_session("s1", "cli")
        assert db.get_session_title("s1") is None
        db.set_session_title("s1", "my title")
        assert db.get_session_title("s1") == "my title"

    def test_get_session_title_nonexistent(self, db):
        assert db.get_session_title("nonexistent") is None


class TestTitleLineage:
    """Tests for title lineage resolution and auto-numbering."""

    def test_resolve_exact_title(self, db):
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.resolve_session_by_title("my project") == "s1"

    def test_resolve_returns_latest_numbered(self, db):
        """When numbered variants exist, return the most recent one."""
        import time
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        time.sleep(0.01)
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        time.sleep(0.01)
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        # Resolving "my project" should return s3 (latest numbered variant)
        assert db.resolve_session_by_title("my project") == "s3"

    def test_resolve_exact_numbered(self, db):
        """Resolving an exact numbered title returns that specific session."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Resolving "my project #2" exactly should return s2
        assert db.resolve_session_by_title("my project #2") == "s2"

    def test_resolve_nonexistent_title(self, db):
        assert db.resolve_session_by_title("nonexistent") is None

    def test_next_title_no_existing(self, db):
        """With no existing sessions, the base title is returned as-is."""
        assert db.get_next_title_in_lineage("my project") == "my project"

    def test_next_title_first_continuation(self, db):
        """The first continuation after the original gets #2."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.get_next_title_in_lineage("my project") == "my project #2"

    def test_next_title_increments(self, db):
        """Each continuation increments the number."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        assert db.get_next_title_in_lineage("my project") == "my project #4"

    def test_next_title_strips_existing_number(self, db):
        """Passing a numbered title strips the number and finds the base."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Even when called with "my project #2", it should return #3
        assert db.get_next_title_in_lineage("my project #2") == "my project #3"
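
# Sketch of the auto-numbering rule TestTitleLineage describes (an assumption;
# the real logic in SessionDB is SQL-backed, but the behavior is the same):
import re

def next_title_in_lineage_sketch(title: str, existing_titles: list[str]) -> str:
    base = re.sub(r" #\d+$", "", title)          # "my project #2" -> "my project"
    numbers = [1] if base in existing_titles else []
    for t in existing_titles:
        m = re.fullmatch(re.escape(base) + r" #(\d+)", t)
        if m:
            numbers.append(int(m.group(1)))
    return base if not numbers else f"{base} #{max(numbers) + 1}"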
class TestTitleSqlWildcards:
    """Titles containing SQL LIKE wildcards (%, _) must not cause false matches."""

    def test_resolve_title_with_underscore(self, db):
        """A title like 'test_project' should not match 'testXproject #2'."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "testXproject #2")
        # Resolving "test_project" should return s1 (exact), not s2
        assert db.resolve_session_by_title("test_project") == "s1"

    def test_resolve_title_with_percent(self, db):
        """A title with '%' should not wildcard-match unrelated sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "100% done")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "100X done #2")
        # Should resolve to s1 (exact), not s2
        assert db.resolve_session_by_title("100% done") == "s1"

    def test_next_lineage_with_underscore(self, db):
        """get_next_title_in_lineage with underscores doesn't match wrong sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "testXproject #2")
        # Only "test_project" exists, so next should be "test_project #2"
        assert db.get_next_title_in_lineage("test_project") == "test_project #2"

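The wildcard tests target the classic SQLite LIKE pitfall: % and _ in a stored title act as wildcards unless escaped, so "test_project" would otherwise match "testXproject #2". A sketch of the escaping these tests imply, assuming a LIKE query with an ESCAPE clause (table and column names are illustrative):

def _escape_like(text: str) -> str:
    """Escape LIKE metacharacters so a title matches only literally."""
    return (text.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_"))

# Finding numbered variants of a base title without false matches:
# cursor.execute(
#     "SELECT id, title FROM sessions WHERE title LIKE ? ESCAPE '\\'",
#     (_escape_like(base_title) + " #%",),
# )
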
class TestListSessionsRich:
    """Tests for enhanced session listing with preview and last_active."""

    def test_preview_from_first_user_message(self, db):
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "You are a helpful assistant.")
        db.append_message("s1", "user", "Help me refactor the auth module please")
        db.append_message("s1", "assistant", "Sure, let me look at it.")
        sessions = db.list_sessions_rich()
        assert len(sessions) == 1
        assert "Help me refactor the auth module" in sessions[0]["preview"]

    def test_preview_truncated_at_60(self, db):
        db.create_session("s1", "cli")
        long_msg = "A" * 100
        db.append_message("s1", "user", long_msg)
        sessions = db.list_sessions_rich()
        assert len(sessions[0]["preview"]) == 63  # 60 chars + "..."
        assert sessions[0]["preview"].endswith("...")

    def test_preview_empty_when_no_user_messages(self, db):
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "System prompt")
        sessions = db.list_sessions_rich()
        assert sessions[0]["preview"] == ""

    def test_last_active_from_latest_message(self, db):
        import time
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Hello")
        time.sleep(0.01)
        db.append_message("s1", "assistant", "Hi there!")
        sessions = db.list_sessions_rich()
        # last_active should be close to now (the assistant message)
        assert sessions[0]["last_active"] > sessions[0]["started_at"]

    def test_last_active_fallback_to_started_at(self, db):
        db.create_session("s1", "cli")
        sessions = db.list_sessions_rich()
        # No messages, so last_active falls back to started_at
        assert sessions[0]["last_active"] == sessions[0]["started_at"]

    def test_rich_list_includes_title(self, db):
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        sessions = db.list_sessions_rich()
        assert sessions[0]["title"] == "refactoring auth"

    def test_rich_list_source_filter(self, db):
        db.create_session("s1", "cli")
        db.create_session("s2", "telegram")
        sessions = db.list_sessions_rich(source="cli")
        assert len(sessions) == 1
        assert sessions[0]["id"] == "s1"

    def test_preview_newlines_collapsed(self, db):
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Line one\nLine two\nLine three")
        sessions = db.list_sessions_rich()
        assert "\n" not in sessions[0]["preview"]
        assert "Line one Line two" in sessions[0]["preview"]

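Together the preview tests specify a small normalizer: take the first user message, collapse all whitespace runs (including newlines) to single spaces, and truncate to 60 characters plus "..." (hence the asserted length of 63). A sketch of that logic (the function name is illustrative):

def build_preview(first_user_message, limit=60):
    """Collapse whitespace and truncate with an ellipsis, per the tests above."""
    if not first_user_message:
        return ""
    text = " ".join(first_user_message.split())  # also collapses newlines
    if len(text) > limit:
        return text[:limit] + "..."  # limit chars + "..." -> 63 for limit=60
    return text
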
class TestResolveSessionByNameOrId:
    """Tests for the main.py helper that resolves names or IDs."""

    def test_resolve_by_id(self, db):
        db.create_session("test-id-123", "cli")
        session = db.get_session("test-id-123")
        assert session is not None
        assert session["id"] == "test-id-123"

    def test_resolve_by_title_falls_back(self, db):
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        result = db.resolve_session_by_title("my project")
        assert result == "s1"

@@ -145,7 +145,7 @@ class TestBuildApiKwargsCodex:
        messages = [{"role": "user", "content": "hi"}]
        kwargs = agent._build_api_kwargs(messages)
        assert "reasoning" in kwargs
-        assert kwargs["reasoning"]["effort"] == "xhigh"
+        assert kwargs["reasoning"]["effort"] == "medium"

    def test_includes_encrypted_content_in_include(self, monkeypatch):
        agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
@@ -596,19 +596,19 @@ class TestCodexReasoningPreflight:
# ── Reasoning effort consistency tests ───────────────────────────────────────

class TestReasoningEffortDefaults:
-    """Verify reasoning effort defaults to xhigh across all provider paths."""
+    """Verify reasoning effort defaults to medium across all provider paths."""

-    def test_openrouter_default_xhigh(self, monkeypatch):
+    def test_openrouter_default_medium(self, monkeypatch):
        agent = _make_agent(monkeypatch, "openrouter")
        kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])
        reasoning = kwargs["extra_body"]["reasoning"]
-        assert reasoning["effort"] == "xhigh"
+        assert reasoning["effort"] == "medium"

-    def test_codex_default_xhigh(self, monkeypatch):
+    def test_codex_default_medium(self, monkeypatch):
        agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
                            base_url="https://chatgpt.com/backend-api/codex")
        kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])
-        assert kwargs["reasoning"]["effort"] == "xhigh"
+        assert kwargs["reasoning"]["effort"] == "medium"

    def test_codex_reasoning_disabled(self, monkeypatch):
        agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",

@@ -280,22 +280,21 @@ class TestMaskApiKey:


class TestInit:
-    def test_anthropic_base_url_fails_fast(self):
-        """Anthropic native endpoints should error before building an OpenAI client."""
+    def test_anthropic_base_url_accepted(self):
+        """Anthropic base URLs should be accepted (OpenAI-compatible endpoint)."""
        with (
            patch("run_agent.get_tool_definitions", return_value=[]),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI") as mock_openai,
        ):
-            with pytest.raises(ValueError, match="not supported yet"):
-                AIAgent(
-                    api_key="test-key-1234567890",
-                    base_url="https://api.anthropic.com/v1/messages",
-                    quiet_mode=True,
-                    skip_context_files=True,
-                    skip_memory=True,
-                )
-            mock_openai.assert_not_called()
+            AIAgent(
+                api_key="test-key-1234567890",
+                base_url="https://api.anthropic.com/v1/",
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+            )
+            mock_openai.assert_called_once()

    def test_prompt_caching_claude_openrouter(self):
        """Claude model via OpenRouter should enable prompt caching."""
@@ -498,12 +497,12 @@ class TestBuildApiKwargs:
        assert kwargs["extra_body"]["provider"]["only"] == ["Anthropic"]

    def test_reasoning_config_default_openrouter(self, agent):
-        """Default reasoning config for OpenRouter should be xhigh."""
+        """Default reasoning config for OpenRouter should be medium."""
        messages = [{"role": "user", "content": "hi"}]
        kwargs = agent._build_api_kwargs(messages)
        reasoning = kwargs["extra_body"]["reasoning"]
        assert reasoning["enabled"] is True
-        assert reasoning["effort"] == "xhigh"
+        assert reasoning["effort"] == "medium"

    def test_reasoning_config_custom(self, agent):
        agent.reasoning_config = {"enabled": False}
@@ -765,6 +764,43 @@ class TestRunConversation:
        assert result["completed"] is False
        assert result.get("partial") is True

+    def test_nous_401_refreshes_after_remint_and_retries(self, agent):
+        self._setup_agent(agent)
+        agent.provider = "nous"
+        agent.api_mode = "chat_completions"
+
+        calls = {"api": 0, "refresh": 0}
+
+        class _UnauthorizedError(RuntimeError):
+            def __init__(self):
+                super().__init__("Error code: 401 - unauthorized")
+                self.status_code = 401
+
+        def _fake_api_call(api_kwargs):
+            calls["api"] += 1
+            if calls["api"] == 1:
+                raise _UnauthorizedError()
+            return _mock_response(content="Recovered after remint", finish_reason="stop")
+
+        def _fake_refresh(*, force=True):
+            calls["refresh"] += 1
+            assert force is True
+            return True
+
+        with (
+            patch.object(agent, "_persist_session"),
+            patch.object(agent, "_save_trajectory"),
+            patch.object(agent, "_cleanup_task_resources"),
+            patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
+            patch.object(agent, "_try_refresh_nous_client_credentials", side_effect=_fake_refresh),
+        ):
+            result = agent.run_conversation("hello")
+
+        assert calls["api"] == 2
+        assert calls["refresh"] == 1
+        assert result["completed"] is True
+        assert result["final_response"] == "Recovered after remint"
+
    def test_context_compression_triggered(self, agent):
        """When compressor says should_compress, compression runs."""
        self._setup_agent(agent)
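The 401 test encodes a one-shot retry contract: on the first unauthorized error from the Nous endpoint, force-refresh the minted credentials exactly once and replay the same request; any other failure propagates. A minimal sketch of that control flow, reusing the method names the test patches (the wrapper itself is illustrative):

def call_with_remint(agent, api_kwargs):
    """Retry once after re-minting Nous credentials on a 401."""
    try:
        return agent._interruptible_api_call(api_kwargs)
    except Exception as exc:
        if getattr(exc, "status_code", None) != 401 or agent.provider != "nous":
            raise
        # Force a re-mint; only replay if the refresh actually succeeded.
        if not agent._try_refresh_nous_client_credentials(force=True):
            raise
        return agent._interruptible_api_call(api_kwargs)
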
@@ -938,6 +974,50 @@ class TestConversationHistoryNotMutated:
# _max_tokens_param consistency
# ---------------------------------------------------------------------------

+class TestNousCredentialRefresh:
+    """Verify Nous credential refresh rebuilds the runtime client."""
+
+    def test_try_refresh_nous_client_credentials_rebuilds_client(self, agent, monkeypatch):
+        agent.provider = "nous"
+        agent.api_mode = "chat_completions"
+
+        closed = {"value": False}
+        rebuilt = {"kwargs": None}
+        captured = {}
+
+        class _ExistingClient:
+            def close(self):
+                closed["value"] = True
+
+        class _RebuiltClient:
+            pass
+
+        def _fake_resolve(**kwargs):
+            captured.update(kwargs)
+            return {
+                "api_key": "new-nous-key",
+                "base_url": "https://inference-api.nousresearch.com/v1",
+            }
+
+        def _fake_openai(**kwargs):
+            rebuilt["kwargs"] = kwargs
+            return _RebuiltClient()
+
+        monkeypatch.setattr("hermes_cli.auth.resolve_nous_runtime_credentials", _fake_resolve)
+
+        agent.client = _ExistingClient()
+        with patch("run_agent.OpenAI", side_effect=_fake_openai):
+            ok = agent._try_refresh_nous_client_credentials(force=True)
+
+        assert ok is True
+        assert closed["value"] is True
+        assert captured["force_mint"] is True
+        assert rebuilt["kwargs"]["api_key"] == "new-nous-key"
+        assert rebuilt["kwargs"]["base_url"] == "https://inference-api.nousresearch.com/v1"
+        assert "default_headers" not in rebuilt["kwargs"]
+        assert isinstance(agent.client, _RebuiltClient)
+
+
class TestMaxTokensParam:
    """Verify _max_tokens_param returns the correct key for each provider."""

@@ -121,6 +121,43 @@ def test_openai_key_used_when_no_openrouter_key(monkeypatch):
    assert resolved["api_key"] == "sk-openai-fallback"


+def test_custom_endpoint_prefers_openai_key(monkeypatch):
+    """Custom endpoint should use OPENAI_API_KEY, not OPENROUTER_API_KEY.
+
+    Regression test for #560: when base_url is a non-OpenRouter endpoint,
+    OPENROUTER_API_KEY was being sent as the auth header instead of OPENAI_API_KEY.
+    """
+    monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
+    monkeypatch.setattr(rp, "_get_model_config", lambda: {})
+    monkeypatch.setenv("OPENAI_BASE_URL", "https://api.z.ai/api/coding/paas/v4")
+    monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-zai-correct-key")
+    monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-wrong-key-for-zai")
+
+    resolved = rp.resolve_runtime_provider(requested="custom")
+
+    assert resolved["base_url"] == "https://api.z.ai/api/coding/paas/v4"
+    assert resolved["api_key"] == "sk-zai-correct-key"
+
+
+def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
+    """Auto provider with non-OpenRouter base_url should prefer OPENAI_API_KEY.
+
+    Same as #560 but via 'hermes model' flow which sets provider to 'auto'.
+    """
+    monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
+    monkeypatch.setattr(rp, "_get_model_config", lambda: {})
+    monkeypatch.setenv("OPENAI_BASE_URL", "https://my-vllm-server.example.com/v1")
+    monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-vllm-key")
+    monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-should-not-leak")
+
+    resolved = rp.resolve_runtime_provider(requested="auto")
+
+    assert resolved["base_url"] == "https://my-vllm-server.example.com/v1"
+    assert resolved["api_key"] == "sk-vllm-key"
+
+
def test_resolve_requested_provider_precedence(monkeypatch):
    monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous")
    monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "openai-codex"})

tests/test_timezone.py (new file, 269 lines)
@@ -0,0 +1,269 @@
"""
|
||||
Tests for timezone support (hermes_time module + integration points).
|
||||
|
||||
Covers:
|
||||
- Valid timezone applies correctly
|
||||
- Invalid timezone falls back safely (no crash, warning logged)
|
||||
- execute_code child env receives TZ
|
||||
- Cron uses timezone-aware now()
|
||||
- Backward compatibility with naive timestamps
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import sys
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch, MagicMock
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import hermes_time
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# hermes_time.now() — core helper
|
||||
# =========================================================================
|
||||
|
||||
class TestHermesTimeNow:
|
||||
"""Test the timezone-aware now() helper."""
|
||||
|
||||
def setup_method(self):
|
||||
hermes_time.reset_cache()
|
||||
|
||||
def teardown_method(self):
|
||||
hermes_time.reset_cache()
|
||||
os.environ.pop("HERMES_TIMEZONE", None)
|
||||
|
||||
def test_valid_timezone_applies(self):
|
||||
"""With a valid IANA timezone, now() returns time in that zone."""
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
|
||||
result = hermes_time.now()
|
||||
assert result.tzinfo is not None
|
||||
# IST is UTC+5:30
|
||||
offset = result.utcoffset()
|
||||
assert offset == timedelta(hours=5, minutes=30)
|
||||
|
||||
def test_utc_timezone(self):
|
||||
"""UTC timezone works."""
|
||||
os.environ["HERMES_TIMEZONE"] = "UTC"
|
||||
result = hermes_time.now()
|
||||
assert result.utcoffset() == timedelta(0)
|
||||
|
||||
def test_us_eastern(self):
|
||||
"""US/Eastern timezone works (DST-aware zone)."""
|
||||
os.environ["HERMES_TIMEZONE"] = "America/New_York"
|
||||
result = hermes_time.now()
|
||||
assert result.tzinfo is not None
|
||||
# Offset is -5h or -4h depending on DST
|
||||
offset_hours = result.utcoffset().total_seconds() / 3600
|
||||
assert offset_hours in (-5, -4)
|
||||
|
||||
def test_invalid_timezone_falls_back(self, caplog):
|
||||
"""Invalid timezone logs warning and falls back to server-local."""
|
||||
os.environ["HERMES_TIMEZONE"] = "Mars/Olympus_Mons"
|
||||
with caplog.at_level(logging.WARNING, logger="hermes_time"):
|
||||
result = hermes_time.now()
|
||||
assert result.tzinfo is not None # Still tz-aware (server-local)
|
||||
assert "Invalid timezone" in caplog.text
|
||||
assert "Mars/Olympus_Mons" in caplog.text
|
||||
|
||||
def test_empty_timezone_uses_local(self):
|
||||
"""No timezone configured → server-local time (still tz-aware)."""
|
||||
os.environ.pop("HERMES_TIMEZONE", None)
|
||||
result = hermes_time.now()
|
||||
assert result.tzinfo is not None
|
||||
|
||||
def test_format_unchanged(self):
|
||||
"""Timestamp formatting matches original strftime pattern."""
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
|
||||
result = hermes_time.now()
|
||||
formatted = result.strftime("%A, %B %d, %Y %I:%M %p")
|
||||
# Should produce something like "Monday, March 03, 2026 05:30 PM"
|
||||
assert len(formatted) > 10
|
||||
# No timezone abbreviation in the format (matching original behavior)
|
||||
assert "+" not in formatted
|
||||
|
||||
def test_cache_invalidation(self):
|
||||
"""Changing env var + reset_cache picks up new timezone."""
|
||||
os.environ["HERMES_TIMEZONE"] = "UTC"
|
||||
hermes_time.reset_cache()
|
||||
r1 = hermes_time.now()
|
||||
assert r1.utcoffset() == timedelta(0)
|
||||
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
|
||||
hermes_time.reset_cache()
|
||||
r2 = hermes_time.now()
|
||||
assert r2.utcoffset() == timedelta(hours=5, minutes=30)
|
||||
|
||||
|
||||
class TestGetTimezone:
|
||||
"""Test get_timezone() and get_timezone_name()."""
|
||||
|
||||
def setup_method(self):
|
||||
hermes_time.reset_cache()
|
||||
|
||||
def teardown_method(self):
|
||||
hermes_time.reset_cache()
|
||||
os.environ.pop("HERMES_TIMEZONE", None)
|
||||
|
||||
def test_returns_zoneinfo_for_valid(self):
|
||||
os.environ["HERMES_TIMEZONE"] = "Europe/London"
|
||||
tz = hermes_time.get_timezone()
|
||||
assert isinstance(tz, ZoneInfo)
|
||||
assert str(tz) == "Europe/London"
|
||||
|
||||
def test_returns_none_for_empty(self):
|
||||
os.environ.pop("HERMES_TIMEZONE", None)
|
||||
tz = hermes_time.get_timezone()
|
||||
assert tz is None
|
||||
|
||||
def test_returns_none_for_invalid(self):
|
||||
os.environ["HERMES_TIMEZONE"] = "Not/A/Timezone"
|
||||
tz = hermes_time.get_timezone()
|
||||
assert tz is None
|
||||
|
||||
def test_get_timezone_name(self):
|
||||
os.environ["HERMES_TIMEZONE"] = "Asia/Tokyo"
|
||||
assert hermes_time.get_timezone_name() == "Asia/Tokyo"
|
||||
|
||||
|
||||
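The contract fixed by these two classes is compact: read HERMES_TIMEZONE, return a ZoneInfo for a valid IANA name, None plus a logged warning otherwise, and have now() stay timezone-aware even in the fallback case. A sketch of a hermes_time core that satisfies the assertions above (the single-slot cache is an assumption):

import logging
import os
from datetime import datetime
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

logger = logging.getLogger("hermes_time")
_cache = {}

def reset_cache():
    _cache.clear()

def get_timezone_name():
    return os.environ.get("HERMES_TIMEZONE", "").strip()

def get_timezone():
    if "tz" not in _cache:
        name = get_timezone_name()
        tz = None
        if name:
            try:
                tz = ZoneInfo(name)
            except (ZoneInfoNotFoundError, ValueError):
                logger.warning("Invalid timezone %r; falling back to server-local", name)
        _cache["tz"] = tz
    return _cache["tz"]

def now():
    tz = get_timezone()
    if tz is not None:
        return datetime.now(tz)
    return datetime.now().astimezone()  # tz-aware server-local fallback
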
# =========================================================================
# execute_code child env — TZ injection
# =========================================================================

@pytest.mark.skipif(sys.platform == "win32", reason="UDS not available on Windows")
class TestCodeExecutionTZ:
    """Verify TZ env var is passed to sandboxed child process via real execute_code."""

    @pytest.fixture(autouse=True)
    def _import_execute_code(self):
        """Lazy-import execute_code to avoid pulling in firecrawl at collection time."""
        try:
            from tools.code_execution_tool import execute_code
            self._execute_code = execute_code
        except ImportError:
            pytest.skip("tools.code_execution_tool not importable (missing deps)")

    def teardown_method(self):
        os.environ.pop("HERMES_TIMEZONE", None)

    def _mock_handle(self, function_name, function_args, task_id=None, user_task=None):
        import json as _json
        return _json.dumps({"error": f"unexpected tool call: {function_name}"})

    def test_tz_injected_when_configured(self):
        """When HERMES_TIMEZONE is set, child process sees TZ env var."""
        import json as _json
        os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"

        with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
            result = _json.loads(self._execute_code(
                code='import os; print(os.environ.get("TZ", "NOT_SET"))',
                task_id="tz-test",
                enabled_tools=[],
            ))
        assert result["status"] == "success"
        assert "Asia/Kolkata" in result["output"]

    def test_tz_not_injected_when_empty(self):
        """When HERMES_TIMEZONE is not set, child process has no TZ."""
        import json as _json
        os.environ.pop("HERMES_TIMEZONE", None)

        with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
            result = _json.loads(self._execute_code(
                code='import os; print(os.environ.get("TZ", "NOT_SET"))',
                task_id="tz-test-empty",
                enabled_tools=[],
            ))
        assert result["status"] == "success"
        assert "NOT_SET" in result["output"]

    def test_hermes_timezone_not_leaked_to_child(self):
        """HERMES_TIMEZONE itself must NOT appear in child env (only TZ)."""
        import json as _json
        os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"

        with patch("model_tools.handle_function_call", side_effect=self._mock_handle):
            result = _json.loads(self._execute_code(
                code='import os; print(os.environ.get("HERMES_TIMEZONE", "NOT_SET"))',
                task_id="tz-leak-test",
                enabled_tools=[],
            ))
        assert result["status"] == "success"
        assert "NOT_SET" in result["output"]

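These three tests pin down a narrow mapping: the parent-side HERMES_TIMEZONE becomes a standard TZ variable in the sandboxed child, and the Hermes-specific name itself never leaks. A sketch of the child-env construction (a hypothetical helper; the real tool assembles its child environment elsewhere):

import os

def build_child_env(base_env=None):
    """Child env for sandboxed code: HERMES_TIMEZONE -> TZ, original dropped."""
    env = dict(base_env or {})
    env.pop("HERMES_TIMEZONE", None)  # never expose the parent-side setting
    tz = os.environ.get("HERMES_TIMEZONE", "").strip()
    if tz:
        env["TZ"] = tz  # the child sees only the standard TZ variable
    return env
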
# =========================================================================
# Cron timezone-aware scheduling
# =========================================================================

class TestCronTimezone:
    """Verify cron paths use timezone-aware now()."""

    def setup_method(self):
        hermes_time.reset_cache()

    def teardown_method(self):
        hermes_time.reset_cache()
        os.environ.pop("HERMES_TIMEZONE", None)

    def test_parse_schedule_duration_uses_tz_aware_now(self):
        """parse_schedule('30m') should produce a tz-aware run_at."""
        os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
        from cron.jobs import parse_schedule
        result = parse_schedule("30m")
        run_at = datetime.fromisoformat(result["run_at"])
        # The stored timestamp should be tz-aware
        assert run_at.tzinfo is not None

    def test_compute_next_run_tz_aware(self):
        """compute_next_run returns tz-aware timestamps."""
        os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
        from cron.jobs import compute_next_run
        schedule = {"kind": "interval", "minutes": 60}
        result = compute_next_run(schedule)
        next_dt = datetime.fromisoformat(result)
        assert next_dt.tzinfo is not None

    def test_get_due_jobs_handles_naive_timestamps(self, tmp_path, monkeypatch):
        """Backward compat: naive timestamps from before tz support don't crash."""
        import cron.jobs as jobs_module
        monkeypatch.setattr(jobs_module, "CRON_DIR", tmp_path / "cron")
        monkeypatch.setattr(jobs_module, "JOBS_FILE", tmp_path / "cron" / "jobs.json")
        monkeypatch.setattr(jobs_module, "OUTPUT_DIR", tmp_path / "cron" / "output")

        os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"
        hermes_time.reset_cache()

        # Create a job with a NAIVE past timestamp (simulating pre-tz data)
        from cron.jobs import create_job, load_jobs, save_jobs, get_due_jobs
        job = create_job(prompt="Test job", schedule="every 1h")
        jobs = load_jobs()
        # Force a naive (no timezone) past timestamp
        naive_past = (datetime.now() - timedelta(minutes=5)).isoformat()
        jobs[0]["next_run_at"] = naive_past
        save_jobs(jobs)

        # Should not crash — _ensure_aware handles the naive timestamp
        due = get_due_jobs()
        assert len(due) == 1

    def test_create_job_stores_tz_aware_timestamps(self, tmp_path, monkeypatch):
        """New jobs store timezone-aware created_at and next_run_at."""
        import cron.jobs as jobs_module
        monkeypatch.setattr(jobs_module, "CRON_DIR", tmp_path / "cron")
        monkeypatch.setattr(jobs_module, "JOBS_FILE", tmp_path / "cron" / "jobs.json")
        monkeypatch.setattr(jobs_module, "OUTPUT_DIR", tmp_path / "cron" / "output")

        os.environ["HERMES_TIMEZONE"] = "US/Eastern"
        hermes_time.reset_cache()

        from cron.jobs import create_job
        job = create_job(prompt="TZ test", schedule="every 2h")

        created = datetime.fromisoformat(job["created_at"])
        assert created.tzinfo is not None

        next_run = datetime.fromisoformat(job["next_run_at"])
        assert next_run.tzinfo is not None
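The backward-compat test depends on the scheduler tolerating naive pre-tz timestamps when comparing against the new aware now(). A minimal sketch of the _ensure_aware normalizer the comment refers to (treating naive values as server-local is an assumption):

from datetime import datetime

def _ensure_aware(dt: datetime) -> datetime:
    """Attach the local zone to naive timestamps; pass aware ones through."""
    if dt.tzinfo is None:
        return dt.astimezone()  # interpret legacy naive values as server-local
    return dt

# Due-job checks then always compare aware to aware, e.g.:
# due = [j for j in jobs
#        if _ensure_aware(datetime.fromisoformat(j["next_run_at"])) <= hermes_time.now()]
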
tests/test_worktree.py (new file, 635 lines)
@@ -0,0 +1,635 @@
"""Tests for git worktree isolation (CLI --worktree / -w flag).
|
||||
|
||||
Verifies worktree creation, cleanup, .worktreeinclude handling,
|
||||
.gitignore management, and integration with the CLI. (#652)
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_repo(tmp_path):
|
||||
"""Create a temporary git repo for testing."""
|
||||
repo = tmp_path / "test-repo"
|
||||
repo.mkdir()
|
||||
subprocess.run(["git", "init"], cwd=repo, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "config", "user.email", "test@test.com"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "config", "user.name", "Test"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
# Create initial commit (worktrees need at least one commit)
|
||||
(repo / "README.md").write_text("# Test Repo\n")
|
||||
subprocess.run(["git", "add", "."], cwd=repo, capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Initial commit"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Lightweight reimplementations for testing (avoid importing cli.py)
# ---------------------------------------------------------------------------

def _git_repo_root(cwd=None):
    """Test version of _git_repo_root."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True, text=True, timeout=5,
            cwd=cwd,
        )
        if result.returncode == 0:
            return result.stdout.strip()
    except Exception:
        pass
    return None


def _setup_worktree(repo_root):
    """Test version of _setup_worktree — creates a worktree."""
    import uuid
    short_id = uuid.uuid4().hex[:8]
    wt_name = f"hermes-{short_id}"
    branch_name = f"hermes/{wt_name}"

    worktrees_dir = Path(repo_root) / ".worktrees"
    worktrees_dir.mkdir(parents=True, exist_ok=True)
    wt_path = worktrees_dir / wt_name

    result = subprocess.run(
        ["git", "worktree", "add", str(wt_path), "-b", branch_name, "HEAD"],
        capture_output=True, text=True, timeout=30, cwd=repo_root,
    )
    if result.returncode != 0:
        return None

    return {
        "path": str(wt_path),
        "branch": branch_name,
        "repo_root": repo_root,
    }


def _cleanup_worktree(info):
    """Test version of _cleanup_worktree."""
    wt_path = info["path"]
    branch = info["branch"]
    repo_root = info["repo_root"]

    if not Path(wt_path).exists():
        return

    # Check for uncommitted changes
    status = subprocess.run(
        ["git", "status", "--porcelain"],
        capture_output=True, text=True, timeout=10, cwd=wt_path,
    )
    has_changes = bool(status.stdout.strip())

    if has_changes:
        return False  # Did not clean up

    subprocess.run(
        ["git", "worktree", "remove", wt_path, "--force"],
        capture_output=True, text=True, timeout=15, cwd=repo_root,
    )
    subprocess.run(
        ["git", "branch", "-D", branch],
        capture_output=True, text=True, timeout=10, cwd=repo_root,
    )
    return True  # Cleaned up


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestGitRepoDetection:
    """Test git repo root detection."""

    def test_detects_git_repo(self, git_repo):
        root = _git_repo_root(cwd=str(git_repo))
        assert root is not None
        assert Path(root).resolve() == git_repo.resolve()

    def test_detects_subdirectory(self, git_repo):
        subdir = git_repo / "src" / "lib"
        subdir.mkdir(parents=True)
        root = _git_repo_root(cwd=str(subdir))
        assert root is not None
        assert Path(root).resolve() == git_repo.resolve()

    def test_returns_none_outside_repo(self, tmp_path):
        # tmp_path itself is not a git repo
        bare_dir = tmp_path / "not-a-repo"
        bare_dir.mkdir()
        root = _git_repo_root(cwd=str(bare_dir))
        assert root is None


class TestWorktreeCreation:
    """Test worktree setup."""

    def test_creates_worktree(self, git_repo):
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()
        assert info["branch"].startswith("hermes/hermes-")
        assert info["repo_root"] == str(git_repo)

        # Verify it's a valid git worktree
        result = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == "true"

    def test_worktree_has_own_branch(self, git_repo):
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Check branch name in worktree
        result = subprocess.run(
            ["git", "branch", "--show-current"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == info["branch"]

    def test_worktree_is_independent(self, git_repo):
        """Two worktrees from the same repo are independent."""
        info1 = _setup_worktree(str(git_repo))
        info2 = _setup_worktree(str(git_repo))
        assert info1 is not None
        assert info2 is not None
        assert info1["path"] != info2["path"]
        assert info1["branch"] != info2["branch"]

        # Create a file in worktree 1
        (Path(info1["path"]) / "only-in-wt1.txt").write_text("hello")

        # It should NOT appear in worktree 2
        assert not (Path(info2["path"]) / "only-in-wt1.txt").exists()

    def test_worktrees_dir_created(self, git_repo):
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert (git_repo / ".worktrees").is_dir()

    def test_worktree_has_repo_files(self, git_repo):
        """Worktree should contain the repo's tracked files."""
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert (Path(info["path"]) / "README.md").exists()


class TestWorktreeCleanup:
    """Test worktree cleanup on exit."""

    def test_clean_worktree_removed(self, git_repo):
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()

        result = _cleanup_worktree(info)
        assert result is True
        assert not Path(info["path"]).exists()

    def test_dirty_worktree_kept(self, git_repo):
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Make uncommitted changes
        (Path(info["path"]) / "new-file.txt").write_text("uncommitted")
        subprocess.run(
            ["git", "add", "new-file.txt"],
            cwd=info["path"], capture_output=True,
        )

        result = _cleanup_worktree(info)
        assert result is False
        assert Path(info["path"]).exists()  # Still there

    def test_branch_deleted_on_cleanup(self, git_repo):
        info = _setup_worktree(str(git_repo))
        branch = info["branch"]

        _cleanup_worktree(info)

        # Branch should be gone
        result = subprocess.run(
            ["git", "branch", "--list", branch],
            capture_output=True, text=True, cwd=str(git_repo),
        )
        assert branch not in result.stdout

    def test_cleanup_nonexistent_worktree(self, git_repo):
        """Cleanup should handle already-removed worktrees gracefully."""
        info = {
            "path": str(git_repo / ".worktrees" / "nonexistent"),
            "branch": "hermes/nonexistent",
            "repo_root": str(git_repo),
        }
        # Should not raise
        _cleanup_worktree(info)


class TestWorktreeInclude:
    """Test .worktreeinclude file handling."""

    def test_copies_included_files(self, git_repo):
        """Files listed in .worktreeinclude should be copied to the worktree."""
        # Create a .env file (gitignored)
        (git_repo / ".env").write_text("SECRET=abc123")
        (git_repo / ".gitignore").write_text(".env\n.worktrees/\n")
        subprocess.run(
            ["git", "add", ".gitignore"],
            cwd=str(git_repo), capture_output=True,
        )
        subprocess.run(
            ["git", "commit", "-m", "Add gitignore"],
            cwd=str(git_repo), capture_output=True,
        )

        # Create .worktreeinclude
        (git_repo / ".worktreeinclude").write_text(".env\n")

        # Import and use the real _setup_worktree logic for include handling
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Manually copy .worktreeinclude entries (mirrors cli.py logic)
        import shutil
        include_file = git_repo / ".worktreeinclude"
        wt_path = Path(info["path"])
        for line in include_file.read_text().splitlines():
            entry = line.strip()
            if not entry or entry.startswith("#"):
                continue
            src = git_repo / entry
            dst = wt_path / entry
            if src.is_file():
                dst.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(str(src), str(dst))

        # Verify .env was copied
        assert (wt_path / ".env").exists()
        assert (wt_path / ".env").read_text() == "SECRET=abc123"

    def test_ignores_comments_and_blanks(self, git_repo):
        """Comments and blank lines in .worktreeinclude should be skipped."""
        (git_repo / ".worktreeinclude").write_text(
            "# This is a comment\n"
            "\n"
            "  # Another comment\n"
        )
        info = _setup_worktree(str(git_repo))
        assert info is not None
        # Should not crash — just skip all lines

class TestGitignoreManagement:
    """Test that .worktrees/ is added to .gitignore."""

    def test_adds_to_gitignore(self, git_repo):
        """Creating a worktree should add .worktrees/ to .gitignore."""
        # Remove any existing .gitignore
        gitignore = git_repo / ".gitignore"
        if gitignore.exists():
            gitignore.unlink()

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Now manually add .worktrees/ to .gitignore (mirrors cli.py logic)
        _ignore_entry = ".worktrees/"
        existing = gitignore.read_text() if gitignore.exists() else ""
        if _ignore_entry not in existing.splitlines():
            with open(gitignore, "a") as f:
                if existing and not existing.endswith("\n"):
                    f.write("\n")
                f.write(f"{_ignore_entry}\n")

        content = gitignore.read_text()
        assert ".worktrees/" in content

    def test_does_not_duplicate_gitignore_entry(self, git_repo):
        """If .worktrees/ is already in .gitignore, don't add again."""
        gitignore = git_repo / ".gitignore"
        gitignore.write_text(".worktrees/\n")

        # The check should see it's already there
        existing = gitignore.read_text()
        assert ".worktrees/" in existing.splitlines()

class TestMultipleWorktrees:
    """Test running multiple worktrees concurrently (the core use case)."""

    def test_ten_concurrent_worktrees(self, git_repo):
        """Create 10 worktrees — simulating 10 parallel agents."""
        worktrees = []
        for _ in range(10):
            info = _setup_worktree(str(git_repo))
            assert info is not None
            worktrees.append(info)

        # All should exist and be independent
        paths = [info["path"] for info in worktrees]
        assert len(set(paths)) == 10  # All unique

        # Each should have the repo files
        for info in worktrees:
            assert (Path(info["path"]) / "README.md").exists()

        # Edit a file in one worktree
        (Path(worktrees[0]["path"]) / "README.md").write_text("Modified in wt0")

        # Others should be unaffected
        for info in worktrees[1:]:
            assert (Path(info["path"]) / "README.md").read_text() == "# Test Repo\n"

        # List worktrees via git
        result = subprocess.run(
            ["git", "worktree", "list"],
            capture_output=True, text=True, cwd=str(git_repo),
        )
        # Should have 11 entries: main + 10 worktrees
        lines = [l for l in result.stdout.strip().splitlines() if l.strip()]
        assert len(lines) == 11

        # Cleanup all
        for info in worktrees:
            # Discard changes first so cleanup works
            subprocess.run(
                ["git", "checkout", "--", "."],
                cwd=info["path"], capture_output=True,
            )
            _cleanup_worktree(info)

        # All should be removed
        for info in worktrees:
            assert not Path(info["path"]).exists()

class TestWorktreeDirectorySymlink:
    """Test .worktreeinclude with directories (symlinked)."""

    def test_symlinks_directory(self, git_repo):
        """Directories in .worktreeinclude should be symlinked."""
        # Create a .venv directory
        venv_dir = git_repo / ".venv" / "lib"
        venv_dir.mkdir(parents=True)
        (venv_dir / "marker.txt").write_text("venv marker")
        (git_repo / ".gitignore").write_text(".venv/\n.worktrees/\n")
        subprocess.run(
            ["git", "add", ".gitignore"], cwd=str(git_repo), capture_output=True
        )
        subprocess.run(
            ["git", "commit", "-m", "gitignore"], cwd=str(git_repo), capture_output=True
        )

        (git_repo / ".worktreeinclude").write_text(".venv/\n")

        info = _setup_worktree(str(git_repo))
        assert info is not None

        wt_path = Path(info["path"])
        src = git_repo / ".venv"
        dst = wt_path / ".venv"

        # Manually symlink (mirrors cli.py logic)
        if not dst.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            os.symlink(str(src.resolve()), str(dst))

        assert dst.is_symlink()
        assert (dst / "lib" / "marker.txt").read_text() == "venv marker"

class TestStaleWorktreePruning:
    """Test _prune_stale_worktrees garbage collection."""

    def test_prunes_old_clean_worktree(self, git_repo):
        """Old clean worktrees should be removed on prune."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()

        # Make the worktree look old (set mtime to 25h ago)
        old_time = time.time() - (25 * 3600)
        os.utime(info["path"], (old_time, old_time))

        # Reimplementation of prune logic (matches cli.py)
        worktrees_dir = git_repo / ".worktrees"
        cutoff = time.time() - (24 * 3600)

        for entry in worktrees_dir.iterdir():
            if not entry.is_dir() or not entry.name.startswith("hermes-"):
                continue
            try:
                mtime = entry.stat().st_mtime
                if mtime > cutoff:
                    continue
            except Exception:
                continue

            status = subprocess.run(
                ["git", "status", "--porcelain"],
                capture_output=True, text=True, timeout=5, cwd=str(entry),
            )
            if status.stdout.strip():
                continue

            branch_result = subprocess.run(
                ["git", "branch", "--show-current"],
                capture_output=True, text=True, timeout=5, cwd=str(entry),
            )
            branch = branch_result.stdout.strip()
            subprocess.run(
                ["git", "worktree", "remove", str(entry), "--force"],
                capture_output=True, text=True, timeout=15, cwd=str(git_repo),
            )
            if branch:
                subprocess.run(
                    ["git", "branch", "-D", branch],
                    capture_output=True, text=True, timeout=10, cwd=str(git_repo),
                )

        assert not Path(info["path"]).exists()

    def test_keeps_recent_worktree(self, git_repo):
        """Recent worktrees should NOT be pruned."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Don't modify mtime — it's recent
        worktrees_dir = git_repo / ".worktrees"
        cutoff = time.time() - (24 * 3600)

        pruned = False
        for entry in worktrees_dir.iterdir():
            if not entry.is_dir() or not entry.name.startswith("hermes-"):
                continue
            mtime = entry.stat().st_mtime
            if mtime > cutoff:
                continue  # Too recent
            pruned = True

        assert not pruned
        assert Path(info["path"]).exists()

    def test_keeps_dirty_old_worktree(self, git_repo):
        """Old worktrees with uncommitted changes should NOT be pruned."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Make it dirty
        (Path(info["path"]) / "dirty.txt").write_text("uncommitted")
        subprocess.run(
            ["git", "add", "dirty.txt"],
            cwd=info["path"], capture_output=True,
        )

        # Make it old
        old_time = time.time() - (25 * 3600)
        os.utime(info["path"], (old_time, old_time))

        # Check if it would be pruned
        status = subprocess.run(
            ["git", "status", "--porcelain"],
            capture_output=True, text=True, cwd=info["path"],
        )
        has_changes = bool(status.stdout.strip())
        assert has_changes  # Should be dirty → not pruned
        assert Path(info["path"]).exists()

class TestEdgeCases:
    """Test edge cases for robustness."""

    def test_no_commits_repo(self, tmp_path):
        """Worktree creation should fail gracefully on a repo with no commits."""
        repo = tmp_path / "empty-repo"
        repo.mkdir()
        subprocess.run(["git", "init"], cwd=str(repo), capture_output=True)

        info = _setup_worktree(str(repo))
        assert info is None  # Should fail gracefully

    def test_not_a_git_repo(self, tmp_path):
        """Repo detection should return None for non-git directories."""
        bare = tmp_path / "not-git"
        bare.mkdir()
        root = _git_repo_root(cwd=str(bare))
        assert root is None

    def test_worktrees_dir_already_exists(self, git_repo):
        """Should work fine if .worktrees/ already exists."""
        (git_repo / ".worktrees").mkdir(exist_ok=True)
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()


class TestCLIFlagLogic:
    """Test the flag/config OR logic from main()."""

    def test_worktree_flag_triggers(self):
        """--worktree flag should trigger worktree creation."""
        worktree = True
        w = False
        config_worktree = False
        use_worktree = worktree or w or config_worktree
        assert use_worktree

    def test_w_flag_triggers(self):
        """-w flag should trigger worktree creation."""
        worktree = False
        w = True
        config_worktree = False
        use_worktree = worktree or w or config_worktree
        assert use_worktree

    def test_config_triggers(self):
        """worktree: true in config should trigger worktree creation."""
        worktree = False
        w = False
        config_worktree = True
        use_worktree = worktree or w or config_worktree
        assert use_worktree

    def test_none_set_no_trigger(self):
        """No flags and no config should not trigger."""
        worktree = False
        w = False
        config_worktree = False
        use_worktree = worktree or w or config_worktree
        assert not use_worktree


class TestTerminalCWDIntegration:
    """Test that TERMINAL_CWD is correctly set to the worktree path."""

    def test_terminal_cwd_set(self, git_repo):
        """After worktree setup, TERMINAL_CWD should point to the worktree."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # This is what main() does:
        os.environ["TERMINAL_CWD"] = info["path"]
        assert os.environ["TERMINAL_CWD"] == info["path"]
        assert Path(os.environ["TERMINAL_CWD"]).exists()

        # Clean up env
        del os.environ["TERMINAL_CWD"]

    def test_terminal_cwd_is_valid_git_repo(self, git_repo):
        """The TERMINAL_CWD worktree should be a valid git working tree."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        result = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == "true"


class TestSystemPromptInjection:
    """Test that the agent gets worktree context in its system prompt."""

    def test_prompt_note_format(self, git_repo):
        """Verify the system prompt note contains all required info."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # This is what main() does:
        wt_note = (
            f"\n\n[System note: You are working in an isolated git worktree at "
            f"{info['path']}. Your branch is `{info['branch']}`. "
            f"Changes here do not affect the main working tree or other agents. "
            f"Remember to commit and push your changes, and create a PR if appropriate. "
            f"The original repo is at {info['repo_root']}.]"
        )

        assert info["path"] in wt_note
        assert info["branch"] in wt_note
        assert info["repo_root"] in wt_note
        assert "isolated git worktree" in wt_note
        assert "commit and push" in wt_note
@@ -602,11 +602,11 @@ class TestHasClipboardImage:


# ═════════════════════════════════════════════════════════════════════════
-# Level 2: _build_multimodal_content — image → OpenAI vision format
+# Level 2: _preprocess_images_with_vision — image → text via vision tool
# ═════════════════════════════════════════════════════════════════════════

-class TestBuildMultimodalContent:
-    """Test the extracted _build_multimodal_content method directly."""
+class TestPreprocessImagesWithVision:
+    """Test vision-based image pre-processing for the CLI."""

    @pytest.fixture
    def cli(self):
@@ -637,55 +637,81 @@ class TestBuildMultimodalContent:
        img.write_bytes(content)
        return img

+    def _mock_vision_success(self, description="A test image with colored pixels."):
+        """Return an async mock that simulates a successful vision_analyze_tool call."""
+        import json
+        async def _fake_vision(**kwargs):
+            return json.dumps({"success": True, "analysis": description})
+        return _fake_vision
+
+    def _mock_vision_failure(self):
+        """Return an async mock that simulates a failed vision_analyze_tool call."""
+        import json
+        async def _fake_vision(**kwargs):
+            return json.dumps({"success": False, "analysis": "Error"})
+        return _fake_vision
+
    def test_single_image_with_text(self, cli, tmp_path):
        img = self._make_image(tmp_path)
-        result = cli._build_multimodal_content("Describe this", [img])
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
+            result = cli._preprocess_images_with_vision("Describe this", [img])

-        assert len(result) == 2
-        assert result[0] == {"type": "text", "text": "Describe this"}
-        assert result[1]["type"] == "image_url"
-        url = result[1]["image_url"]["url"]
-        assert url.startswith("data:image/png;base64,")
-        # Verify the base64 actually decodes to our image
-        b64_data = url.split(",", 1)[1]
-        assert base64.b64decode(b64_data) == FAKE_PNG
+        assert isinstance(result, str)
+        assert "A test image with colored pixels." in result
+        assert "Describe this" in result
+        assert str(img) in result
+        assert "base64," not in result  # no raw base64 image content

    def test_multiple_images(self, cli, tmp_path):
        imgs = [self._make_image(tmp_path, f"img{i}.png") for i in range(3)]
-        result = cli._build_multimodal_content("Compare", imgs)
-        assert len(result) == 4  # 1 text + 3 images
-        assert all(r["type"] == "image_url" for r in result[1:])
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
+            result = cli._preprocess_images_with_vision("Compare", imgs)
+
+        assert isinstance(result, str)
+        assert "Compare" in result
+        # Each image path should be referenced
+        for img in imgs:
+            assert str(img) in result

    def test_empty_text_gets_default_question(self, cli, tmp_path):
        img = self._make_image(tmp_path)
-        result = cli._build_multimodal_content("", [img])
-        assert result[0]["text"] == "What do you see in this image?"
-
-    def test_jpeg_mime_type(self, cli, tmp_path):
-        img = self._make_image(tmp_path, "photo.jpg", b"\xff\xd8\xff\x00" * 20)
-        result = cli._build_multimodal_content("test", [img])
-        assert "image/jpeg" in result[1]["image_url"]["url"]
-
-    def test_webp_mime_type(self, cli, tmp_path):
-        img = self._make_image(tmp_path, "img.webp", b"RIFF\x00\x00" * 10)
-        result = cli._build_multimodal_content("test", [img])
-        assert "image/webp" in result[1]["image_url"]["url"]
-
-    def test_unknown_extension_defaults_to_png(self, cli, tmp_path):
-        img = self._make_image(tmp_path, "data.bmp", b"\x00" * 50)
-        result = cli._build_multimodal_content("test", [img])
-        assert "image/png" in result[1]["image_url"]["url"]
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
+            result = cli._preprocess_images_with_vision("", [img])
+        assert isinstance(result, str)
+        assert "A test image with colored pixels." in result

    def test_missing_image_skipped(self, cli, tmp_path):
        missing = tmp_path / "gone.png"
-        result = cli._build_multimodal_content("test", [missing])
-        assert len(result) == 1  # only text
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
+            result = cli._preprocess_images_with_vision("test", [missing])
+        # No images analyzed, falls back to default
+        assert result == "test"

    def test_mix_of_existing_and_missing(self, cli, tmp_path):
        real = self._make_image(tmp_path, "real.png")
        missing = tmp_path / "gone.png"
-        result = cli._build_multimodal_content("test", [real, missing])
-        assert len(result) == 2  # text + 1 real image
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
+            result = cli._preprocess_images_with_vision("test", [real, missing])
+        assert str(real) in result
+        assert str(missing) not in result
+        assert "test" in result
+
+    def test_vision_failure_includes_path(self, cli, tmp_path):
+        img = self._make_image(tmp_path)
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_failure()):
+            result = cli._preprocess_images_with_vision("check this", [img])
+        assert isinstance(result, str)
+        assert str(img) in result  # path still included for retry
+        assert "check this" in result
+
+    def test_vision_exception_includes_path(self, cli, tmp_path):
+        img = self._make_image(tmp_path)
+        async def _explode(**kwargs):
+            raise RuntimeError("API down")
+        with patch("tools.vision_tools.vision_analyze_tool", side_effect=_explode):
+            result = cli._preprocess_images_with_vision("check this", [img])
+        assert isinstance(result, str)
+        assert str(img) in result  # path still included for retry

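The rewritten tests describe the new CLI flow: rather than embedding base64 image parts in the message, each attached image is run through vision_analyze_tool and the text prompt is augmented with the resulting description, or with just the path when analysis fails so the agent can retry. A condensed sketch of that loop; the asyncio.run call, the image_path keyword, and the output framing are assumptions, not the CLI's actual code:

import asyncio
import json
from pathlib import Path

def preprocess_images_with_vision(text, image_paths):
    """Replace image attachments with vision descriptions in the prompt."""
    from tools.vision_tools import vision_analyze_tool

    sections = []
    for path in map(Path, image_paths):
        if not path.exists():
            continue  # missing files are skipped entirely
        try:
            raw = asyncio.run(vision_analyze_tool(image_path=str(path)))
            payload = json.loads(raw)
        except Exception:
            payload = {"success": False}
        if payload.get("success"):
            sections.append(f"[Image {path}]: {payload['analysis']}")
        else:
            sections.append(f"[Image {path}]: analysis unavailable")
    if not sections:
        return text  # nothing analyzed: prompt passes through unchanged
    return "\n".join(sections + [text])
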
# ═════════════════════════════════════════════════════════════════════════

@@ -56,7 +56,6 @@ class TestDelegateRequirements(unittest.TestCase):
        self.assertIn("tasks", props)
        self.assertIn("context", props)
        self.assertIn("toolsets", props)
        self.assertIn("model", props)
        self.assertIn("max_iterations", props)
        self.assertEqual(props["tasks"]["maxItems"], 3)

@@ -259,6 +259,70 @@ class TestShellFileOpsHelpers:
|
||||
assert ops.cwd == "/"
|
||||
|
||||
|
||||
class TestSearchPathValidation:
    """Test that search() returns an error for non-existent paths."""

    def test_search_nonexistent_path_returns_error(self, mock_env):
        """search() should return an error when the path doesn't exist."""
        def side_effect(command, **kwargs):
            if "test -e" in command:
                return {"output": "not_found", "returncode": 1}
            if "command -v" in command:
                return {"output": "yes", "returncode": 0}
            return {"output": "", "returncode": 0}
        mock_env.execute.side_effect = side_effect
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/nonexistent/path")
        assert result.error is not None
        assert "not found" in result.error.lower() or "Path not found" in result.error

    def test_search_nonexistent_path_files_mode(self, mock_env):
        """search(target='files') should also return error for bad paths."""
        def side_effect(command, **kwargs):
            if "test -e" in command:
                return {"output": "not_found", "returncode": 1}
            if "command -v" in command:
                return {"output": "yes", "returncode": 0}
            return {"output": "", "returncode": 0}
        mock_env.execute.side_effect = side_effect
        ops = ShellFileOperations(mock_env)
        result = ops.search("*.py", path="/nonexistent/path", target="files")
        assert result.error is not None
        assert "not found" in result.error.lower() or "Path not found" in result.error

    def test_search_existing_path_proceeds(self, mock_env):
        """search() should proceed normally when the path exists."""
        def side_effect(command, **kwargs):
            if "test -e" in command:
                return {"output": "exists", "returncode": 0}
            if "command -v" in command:
                return {"output": "yes", "returncode": 0}
            # rg returns exit 1 (no matches) with empty output
            return {"output": "", "returncode": 1}
        mock_env.execute.side_effect = side_effect
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/existing/path")
        assert result.error is None
        assert result.total_count == 0  # No matches but no error

    def test_search_rg_error_exit_code(self, mock_env):
        """search() should report error when rg returns exit code 2."""
        call_count = {"n": 0}
        def side_effect(command, **kwargs):
            call_count["n"] += 1
            if "test -e" in command:
                return {"output": "exists", "returncode": 0}
            if "command -v" in command:
                return {"output": "yes", "returncode": 0}
            # rg returns exit 2 (error) with empty output
            return {"output": "", "returncode": 2}
        mock_env.execute.side_effect = side_effect
        ops = ShellFileOperations(mock_env)
        result = ops.search("pattern", path="/some/path")
        assert result.error is not None
        assert "search failed" in result.error.lower() or "Search error" in result.error

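
# A minimal sketch (assumed implementation, not the repo's code) of the
# contract the tests above pin down: `test -e` guards the path, rg exit 1
# means "no matches", and rg exit >= 2 is a real failure. Names such as
# SearchResult and env.execute() are hypothetical stand-ins.
import shlex
from dataclasses import dataclass, field

@dataclass
class SearchResult:
    matches: list = field(default_factory=list)
    total_count: int = 0
    error: str | None = None

def search_sketch(env, pattern: str, path: str = ".") -> SearchResult:
    if env.execute(f"test -e {shlex.quote(path)}")["returncode"] != 0:
        return SearchResult(error=f"Path not found: {path}")
    rg = env.execute(f"rg -n {shlex.quote(pattern)} {shlex.quote(path)}")
    if rg["returncode"] >= 2:
        return SearchResult(error=f"Search failed: rg exited {rg['returncode']}")
    matches = rg["output"].splitlines()  # exit 0 or 1: output may be empty
    return SearchResult(matches=matches, total_count=len(matches))
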
class TestShellFileOpsWriteDenied:
    def test_write_file_denied_path(self, file_ops):
        result = file_ops.write_file("~/.ssh/authorized_keys", "evil key")

@@ -17,42 +17,62 @@ from tools.skills_sync import (

class TestReadWriteManifest:
    def test_read_missing_manifest(self, tmp_path):
        with patch.object(
            __import__("tools.skills_sync", fromlist=["MANIFEST_FILE"]),
            "MANIFEST_FILE",
        with patch(
            "tools.skills_sync.MANIFEST_FILE",
            tmp_path / "nonexistent",
        ):
            result = _read_manifest()
        assert result == set()
        assert result == {}

    def test_write_and_read_roundtrip(self, tmp_path):
    def test_write_and_read_roundtrip_v2(self, tmp_path):
        manifest_file = tmp_path / ".bundled_manifest"
        names = {"skill-a", "skill-b", "skill-c"}
        entries = {"skill-a": "abc123", "skill-b": "def456", "skill-c": "789012"}

        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            _write_manifest(names)
            _write_manifest(entries)
            result = _read_manifest()

        assert result == names
        assert result == entries

    def test_write_manifest_sorted(self, tmp_path):
        manifest_file = tmp_path / ".bundled_manifest"
        names = {"zebra", "alpha", "middle"}
        entries = {"zebra": "hash1", "alpha": "hash2", "middle": "hash3"}

        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            _write_manifest(names)
            _write_manifest(entries)

        lines = manifest_file.read_text().strip().splitlines()
        assert lines == ["alpha", "middle", "zebra"]
        names = [line.split(":")[0] for line in lines]
        assert names == ["alpha", "middle", "zebra"]

    def test_read_manifest_ignores_blank_lines(self, tmp_path):
    def test_read_v1_manifest_migration(self, tmp_path):
        """v1 format (plain names, no hashes) should be read with empty hashes."""
        manifest_file = tmp_path / ".bundled_manifest"
        manifest_file.write_text("skill-a\n\n \nskill-b\n")
        manifest_file.write_text("skill-a\nskill-b\n")

        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            result = _read_manifest()

        assert result == {"skill-a", "skill-b"}
        assert result == {"skill-a": "", "skill-b": ""}

    def test_read_manifest_ignores_blank_lines(self, tmp_path):
        manifest_file = tmp_path / ".bundled_manifest"
        manifest_file.write_text("skill-a:hash1\n\n \nskill-b:hash2\n")

        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            result = _read_manifest()

        assert result == {"skill-a": "hash1", "skill-b": "hash2"}

    def test_read_manifest_mixed_v1_v2(self, tmp_path):
        """Manifest with both v1 and v2 lines (shouldn't happen but handle gracefully)."""
        manifest_file = tmp_path / ".bundled_manifest"
        manifest_file.write_text("old-skill\nnew-skill:abc123\n")

        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            result = _read_manifest()

        assert result == {"old-skill": "", "new-skill": "abc123"}

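
# A minimal sketch (assumed implementation) of the v2 manifest format these
# tests describe: one "name:hash" line per skill, written sorted by name,
# with bare v1 lines (no colon) read back as an empty-hash baseline.
from pathlib import Path

def read_manifest_sketch(path: Path) -> dict:
    if not path.exists():
        return {}
    entries = {}
    for line in path.read_text().splitlines():
        line = line.strip()
        if not line:
            continue  # blank/whitespace-only lines are ignored
        name, _, digest = line.partition(":")  # v1 lines have no ":" -> digest ""
        entries[name] = digest
    return entries

def write_manifest_sketch(path: Path, entries: dict) -> None:
    path.write_text("".join(f"{name}:{entries[name]}\n" for name in sorted(entries)))
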
class TestDirHash:
@@ -87,13 +107,10 @@ class TestDirHash:

class TestDiscoverBundledSkills:
    def test_finds_skills_with_skill_md(self, tmp_path):
        # Create two skills
        (tmp_path / "category" / "skill-a").mkdir(parents=True)
        (tmp_path / "category" / "skill-a" / "SKILL.md").write_text("# Skill A")
        (tmp_path / "skill-b").mkdir()
        (tmp_path / "skill-b" / "SKILL.md").write_text("# Skill B")

        # A directory without SKILL.md — should NOT be found
        (tmp_path / "not-a-skill").mkdir()
        (tmp_path / "not-a-skill" / "README.md").write_text("Not a skill")

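
# Illustrative only: a directory hash with the properties the TestSyncSkills
# cases below rely on, i.e. deterministic and 32 hex chars (MD5). The real
# _dir_hash may differ in which files or metadata it folds in.
import hashlib
from pathlib import Path

def dir_hash_sketch(root: Path) -> str:
    md5 = hashlib.md5()
    for f in sorted(p for p in root.rglob("*") if p.is_file()):
        md5.update(str(f.relative_to(root)).encode())  # structure contributes
        md5.update(f.read_bytes())                     # and so does content
    return md5.hexdigest()  # 32 hex chars, matching `len(manifest[...]) == 32`
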
@@ -140,105 +157,198 @@ class TestSyncSkills:
        (bundled / "old-skill" / "SKILL.md").write_text("# Old")
        return bundled

    def _patches(self, bundled, skills_dir, manifest_file):
        """Return context manager stack for patching sync globals."""
        from contextlib import ExitStack
        stack = ExitStack()
        stack.enter_context(patch("tools.skills_sync._get_bundled_dir", return_value=bundled))
        stack.enter_context(patch("tools.skills_sync.SKILLS_DIR", skills_dir))
        stack.enter_context(patch("tools.skills_sync.MANIFEST_FILE", manifest_file))
        return stack

    def test_fresh_install_copies_all(self, tmp_path):
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        assert len(result["copied"]) == 2
        assert result["total_bundled"] == 2
        assert result["updated"] == []
        assert result["user_modified"] == []
        assert result["cleaned"] == []
        assert (skills_dir / "category" / "new-skill" / "SKILL.md").exists()
        assert (skills_dir / "old-skill" / "SKILL.md").exists()
        # DESCRIPTION.md should also be copied
        assert (skills_dir / "category" / "DESCRIPTION.md").exists()

    def test_fresh_install_records_origin_hashes(self, tmp_path):
        """After fresh install, manifest should have v2 format with hashes."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        with self._patches(bundled, skills_dir, manifest_file):
            sync_skills(quiet=True)
            manifest = _read_manifest()

        assert "new-skill" in manifest
        assert "old-skill" in manifest
        # Hashes should be non-empty MD5 strings
        assert len(manifest["new-skill"]) == 32
        assert len(manifest["old-skill"]) == 32

    def test_user_deleted_skill_not_re_added(self, tmp_path):
        """Skill in manifest but not on disk = user deleted it. Don't re-add."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"
        skills_dir.mkdir(parents=True)
        # old-skill is in manifest but NOT on disk (user deleted it)
        manifest_file.write_text("old-skill\n")
        # old-skill is in manifest (v2 format) but NOT on disk
        old_hash = _dir_hash(bundled / "old-skill")
        manifest_file.write_text(f"old-skill:{old_hash}\n")

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        # new-skill should be copied, old-skill should be skipped
        assert "new-skill" in result["copied"]
        assert "old-skill" not in result["copied"]
        assert "old-skill" not in result.get("updated", [])
        assert not (skills_dir / "old-skill").exists()

    def test_existing_skill_gets_updated(self, tmp_path):
        """Skill in manifest AND on disk with changed content = updated."""
    def test_unmodified_skill_gets_updated(self, tmp_path):
        """Skill in manifest + on disk + user hasn't modified = update from bundled."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Pre-create old-skill on disk with DIFFERENT content
        # Simulate: user has old version that was synced from an older bundled
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old version from last sync")
        # Mark it in the manifest
        manifest_file.write_text("old-skill\n")
        (user_skill / "SKILL.md").write_text("# Old v1")
        old_origin_hash = _dir_hash(user_skill)

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        # Record origin hash = hash of what was synced (the old version)
        manifest_file.write_text(f"old-skill:{old_origin_hash}\n")

        # Now bundled has a newer version ("# Old" != "# Old v1")
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        # old-skill should be updated
        # Should be updated because user copy matches origin (unmodified)
        assert "old-skill" in result["updated"]
        assert (user_skill / "SKILL.md").read_text() == "# Old"

    def test_unchanged_skill_not_updated(self, tmp_path):
        """Skill in manifest AND on disk with same content = skipped."""
    def test_user_modified_skill_not_overwritten(self, tmp_path):
        """Skill modified by user should NOT be overwritten even if bundled changed."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Pre-create old-skill on disk with SAME content
        # Simulate: user had the old version synced, then modified it
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old v1")
        old_origin_hash = _dir_hash(user_skill)

        # Record origin hash from what was originally synced
        manifest_file.write_text(f"old-skill:{old_origin_hash}\n")

        # User modifies their copy
        (user_skill / "SKILL.md").write_text("# My custom version")

        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        # Should NOT update — user modified it
        assert "old-skill" in result["user_modified"]
        assert "old-skill" not in result.get("updated", [])
        assert (user_skill / "SKILL.md").read_text() == "# My custom version"

    def test_unchanged_skill_not_updated(self, tmp_path):
        """Skill in sync (user == bundled == origin) = no action needed."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Copy bundled to user dir (simulating perfect sync state)
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old")
        # Mark it in the manifest
        manifest_file.write_text("old-skill\n")
        origin_hash = _dir_hash(user_skill)
        manifest_file.write_text(f"old-skill:{origin_hash}\n")

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        # Should be skipped, not updated
        assert "old-skill" not in result.get("updated", [])
        assert "old-skill" not in result.get("user_modified", [])
        assert result["skipped"] >= 1

    def test_v1_manifest_migration_sets_baseline(self, tmp_path):
        """v1 manifest entries (no hash) should set baseline from user's current copy."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Pre-create skill on disk
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old modified by user")

        # v1 manifest (no hashes)
        manifest_file.write_text("old-skill\n")

        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)
            # Should skip (migration baseline set), NOT update
            assert "old-skill" not in result.get("updated", [])
            assert "old-skill" not in result.get("user_modified", [])

            # Now check manifest was upgraded to v2 with user's hash as baseline
            manifest = _read_manifest()
            assert len(manifest["old-skill"]) == 32  # MD5 hash

    def test_v1_migration_then_bundled_update_detected(self, tmp_path):
        """After v1 migration, a subsequent sync should detect bundled updates."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # User has the SAME content as bundled (in sync)
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old")

        # v1 manifest
        manifest_file.write_text("old-skill\n")

        with self._patches(bundled, skills_dir, manifest_file):
            # First sync: migration — sets baseline
            sync_skills(quiet=True)

            # Now change bundled content
            (bundled / "old-skill" / "SKILL.md").write_text("# Old v2 — improved")

            # Second sync: should detect bundled changed + user unmodified → update
            result = sync_skills(quiet=True)

        assert "old-skill" in result["updated"]
        assert (user_skill / "SKILL.md").read_text() == "# Old v2 — improved"

    def test_stale_manifest_entries_cleaned(self, tmp_path):
        """Skills in manifest that no longer exist in bundled dir get cleaned."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"
        skills_dir.mkdir(parents=True)
        # Add a stale entry that doesn't exist in bundled
        manifest_file.write_text("old-skill\nremoved-skill\n")
        manifest_file.write_text("old-skill:abc123\nremoved-skill:def456\n")

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        assert "removed-skill" in result["cleaned"]
        # Verify manifest no longer has removed-skill
        with patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
            manifest = _read_manifest()
            assert "removed-skill" not in manifest
@@ -249,20 +359,111 @@ class TestSyncSkills:
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Pre-create the skill dir with user content (not in manifest)
        user_skill = skills_dir / "category" / "new-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# User modified")

        with patch("tools.skills_sync._get_bundled_dir", return_value=bundled), \
             patch("tools.skills_sync.SKILLS_DIR", skills_dir), \
             patch("tools.skills_sync.MANIFEST_FILE", manifest_file):
        with self._patches(bundled, skills_dir, manifest_file):
            result = sync_skills(quiet=True)

        # Should not overwrite user's version
        assert (user_skill / "SKILL.md").read_text() == "# User modified"

    def test_nonexistent_bundled_dir(self, tmp_path):
        with patch("tools.skills_sync._get_bundled_dir", return_value=tmp_path / "nope"):
            result = sync_skills(quiet=True)
        assert result == {"copied": [], "updated": [], "skipped": 0, "cleaned": [], "total_bundled": 0}
        assert result == {
            "copied": [], "updated": [], "skipped": 0,
            "user_modified": [], "cleaned": [], "total_bundled": 0,
        }

    def test_failed_copy_does_not_poison_manifest(self, tmp_path):
        """If copytree fails, the skill must NOT be added to the manifest.

        Otherwise the next sync treats it as 'user deleted' and never retries.
        """
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        with self._patches(bundled, skills_dir, manifest_file):
            # Patch copytree to fail for new-skill
            original_copytree = __import__("shutil").copytree

            def failing_copytree(src, dst, *a, **kw):
                if "new-skill" in str(src):
                    raise OSError("Simulated disk full")
                return original_copytree(src, dst, *a, **kw)

            with patch("shutil.copytree", side_effect=failing_copytree):
                result = sync_skills(quiet=True)

            # new-skill should NOT be in copied (it failed)
            assert "new-skill" not in result["copied"]

            # Critical: new-skill must NOT be in the manifest
            manifest = _read_manifest()
            assert "new-skill" not in manifest, (
                "Failed copy was recorded in manifest — next sync will "
                "treat it as 'user deleted' and never retry"
            )

            # Now run sync again (copytree works this time) — it should retry
            result2 = sync_skills(quiet=True)
            assert "new-skill" in result2["copied"]
            assert (skills_dir / "category" / "new-skill" / "SKILL.md").exists()

    def test_failed_update_does_not_destroy_user_copy(self, tmp_path):
        """If copytree fails during update, the user's existing copy must survive."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Start with old synced version
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old v1")
        old_hash = _dir_hash(user_skill)
        manifest_file.write_text(f"old-skill:{old_hash}\n")

        with self._patches(bundled, skills_dir, manifest_file):
            # Patch copytree to fail (rmtree succeeds, copytree fails)
            original_copytree = __import__("shutil").copytree

            def failing_copytree(src, dst, *a, **kw):
                if "old-skill" in str(src):
                    raise OSError("Simulated write failure")
                return original_copytree(src, dst, *a, **kw)

            with patch("shutil.copytree", side_effect=failing_copytree):
                result = sync_skills(quiet=True)

            # old-skill should NOT be in updated (it failed)
            assert "old-skill" not in result.get("updated", [])

            # The skill directory must still exist: if rmtree removed it and
            # copytree then failed to replace it, the user's data would be lost
            assert user_skill.exists(), (
                "Update failure destroyed user's skill copy without replacing it"
            )

    def test_update_records_new_origin_hash(self, tmp_path):
        """After updating a skill, the manifest should record the new bundled hash."""
        bundled = self._setup_bundled(tmp_path)
        skills_dir = tmp_path / "user_skills"
        manifest_file = skills_dir / ".bundled_manifest"

        # Start with old synced version
        user_skill = skills_dir / "old-skill"
        user_skill.mkdir(parents=True)
        (user_skill / "SKILL.md").write_text("# Old v1")
        old_hash = _dir_hash(user_skill)
        manifest_file.write_text(f"old-skill:{old_hash}\n")

        with self._patches(bundled, skills_dir, manifest_file):
            sync_skills(quiet=True)  # updates to "# Old"
            manifest = _read_manifest()

        # New origin hash should match the bundled version
        new_bundled_hash = _dir_hash(bundled / "old-skill")
        assert manifest["old-skill"] == new_bundled_hash
        assert manifest["old-skill"] != old_hash

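
# The decision table the TestSyncSkills cases above encode, as a sketch.
# Inputs are three hashes: bundled (shipped copy), origin (recorded in the
# manifest at last sync, "" after a v1 migration, None if unknown) and user
# (current on-disk copy, None if deleted). Bucket names match the result
# dict keys; the control flow is an assumed reading of the tests, not the
# repo's actual implementation.
def classify_skill_sketch(bundled_hash, origin_hash, user_hash) -> str:
    if origin_hash is None:
        return "skipped" if user_hash else "copied"  # unknown skill: copy only if absent
    if user_hash is None:
        return "skipped"        # in manifest but gone from disk: user deleted it
    if origin_hash == "":
        return "skipped"        # v1 migration: adopt the user copy as the new baseline
    if user_hash != origin_hash:
        return "user_modified"  # user edited their copy: never overwrite
    if bundled_hash != origin_hash:
        return "updated"        # unmodified copy plus newer bundled version
    return "skipped"            # already in sync
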
@@ -11,6 +11,7 @@ from tools.skills_tool import (
    _estimate_tokens,
    _find_all_skills,
    _load_category_description,
    skill_matches_platform,
    skills_list,
    skills_categories,
    skill_view,
@@ -332,3 +333,134 @@ class TestSkillsCategories:
        result = json.loads(raw)
        assert result["success"] is True
        assert result["categories"] == []


# ---------------------------------------------------------------------------
# skill_matches_platform
# ---------------------------------------------------------------------------


class TestSkillMatchesPlatform:
    """Tests for the platforms frontmatter field filtering."""

    def test_no_platforms_field_matches_everything(self):
        """Skills without a platforms field should load on any OS."""
        assert skill_matches_platform({}) is True
        assert skill_matches_platform({"name": "foo"}) is True

    def test_empty_platforms_matches_everything(self):
        """Empty platforms list should load on any OS."""
        assert skill_matches_platform({"platforms": []}) is True
        assert skill_matches_platform({"platforms": None}) is True

    def test_macos_on_darwin(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            assert skill_matches_platform({"platforms": ["macos"]}) is True

    def test_macos_on_linux(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": ["macos"]}) is False

    def test_linux_on_linux(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": ["linux"]}) is True

    def test_linux_on_darwin(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            assert skill_matches_platform({"platforms": ["linux"]}) is False

    def test_windows_on_win32(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "win32"
            assert skill_matches_platform({"platforms": ["windows"]}) is True

    def test_windows_on_linux(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": ["windows"]}) is False

    def test_multi_platform_match(self):
        """Skills listing multiple platforms should match any of them."""
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            assert skill_matches_platform({"platforms": ["macos", "linux"]}) is True
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": ["macos", "linux"]}) is True
            mock_sys.platform = "win32"
            assert skill_matches_platform({"platforms": ["macos", "linux"]}) is False

    def test_string_instead_of_list(self):
        """A single string value should be treated as a one-element list."""
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            assert skill_matches_platform({"platforms": "macos"}) is True
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": "macos"}) is False

    def test_case_insensitive(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            assert skill_matches_platform({"platforms": ["MacOS"]}) is True
            assert skill_matches_platform({"platforms": ["MACOS"]}) is True

    def test_unknown_platform_no_match(self):
        with patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            assert skill_matches_platform({"platforms": ["freebsd"]}) is False

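
# A sketch compatible with every assertion above: missing or empty platforms
# match everything, a bare string acts as a one-item list, matching is
# case-insensitive, and sys.platform maps onto the macos/linux/windows
# aliases. (The alias table is inferred from the test matrix.)
import sys

_PLATFORM_ALIASES = {"darwin": "macos", "linux": "linux", "win32": "windows"}

def skill_matches_platform_sketch(frontmatter: dict) -> bool:
    platforms = frontmatter.get("platforms")
    if not platforms:
        return True  # no restriction -> load on any OS
    if isinstance(platforms, str):
        platforms = [platforms]
    current = _PLATFORM_ALIASES.get(sys.platform)
    return current in {p.lower() for p in platforms}
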
# ---------------------------------------------------------------------------
# _find_all_skills — platform filtering integration
# ---------------------------------------------------------------------------


class TestFindAllSkillsPlatformFiltering:
    """Test that _find_all_skills respects the platforms field."""

    def test_excludes_incompatible_platform(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "linux"
            _make_skill(tmp_path, "universal-skill")
            _make_skill(tmp_path, "mac-only", frontmatter_extra="platforms: [macos]\n")
            skills = _find_all_skills()
            names = {s["name"] for s in skills}
            assert "universal-skill" in names
            assert "mac-only" not in names

    def test_includes_matching_platform(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "darwin"
            _make_skill(tmp_path, "mac-only", frontmatter_extra="platforms: [macos]\n")
            skills = _find_all_skills()
            names = {s["name"] for s in skills}
            assert "mac-only" in names

    def test_no_platforms_always_included(self, tmp_path):
        """Skills without platforms field should appear on any platform."""
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            mock_sys.platform = "win32"
            _make_skill(tmp_path, "generic-skill")
            skills = _find_all_skills()
            assert len(skills) == 1
            assert skills[0]["name"] == "generic-skill"

    def test_multi_platform_skill(self, tmp_path):
        with patch("tools.skills_tool.SKILLS_DIR", tmp_path), \
             patch("tools.skills_tool.sys") as mock_sys:
            _make_skill(tmp_path, "cross-plat", frontmatter_extra="platforms: [macos, linux]\n")
            mock_sys.platform = "darwin"
            skills_darwin = _find_all_skills()
            mock_sys.platform = "linux"
            skills_linux = _find_all_skills()
            mock_sys.platform = "win32"
            skills_win = _find_all_skills()
            assert len(skills_darwin) == 1
            assert len(skills_linux) == 1
            assert len(skills_win) == 0

@@ -2,17 +2,23 @@
"""
Browser Tool Module

This module provides browser automation tools using agent-browser CLI with
Browserbase cloud execution. It enables AI agents to navigate websites,
interact with page elements, and extract information in a text-based format.
This module provides browser automation tools using agent-browser CLI. It
supports two backends — **Browserbase** (cloud) and **local Chromium** — with
identical agent-facing behaviour. The backend is auto-detected: if
``BROWSERBASE_API_KEY`` is set the cloud service is used; otherwise a local
headless Chromium instance is launched automatically.

The tool uses agent-browser's accessibility tree (ariaSnapshot) for text-based
page representation, making it ideal for LLM agents without vision capabilities.

Features:
- Cloud browser execution via Browserbase (no local browser needed)
- Basic Stealth Mode always active (random fingerprints, CAPTCHA solving)
- Proxies enabled by default for better CAPTCHA solving and anti-bot avoidance
- **Local mode** (default): zero-cost headless Chromium via agent-browser.
  Works on Linux servers without a display. One-time setup:
  ``agent-browser install`` (downloads Chromium) or
  ``agent-browser install --with-deps`` (also installs system libraries for
  Debian/Ubuntu/Docker).
- **Cloud mode**: Browserbase cloud execution with stealth features, proxies,
  and CAPTCHA solving. Activated when BROWSERBASE_API_KEY is set.
- Session isolation per task ID
- Text-based page snapshots using accessibility tree
- Element interaction via ref selectors (@e1, @e2, etc.)
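
One-time setup for the local mode described above, using the install hint the
module itself prints in its __main__ check further down (the npm package name
comes from that hint):

    npm install -g agent-browser
    agent-browser install --with-deps   # Chromium plus system libraries (Debian/Ubuntu/Docker)
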
@@ -20,8 +26,8 @@ Features:
- Automatic cleanup of browser sessions

Environment Variables:
- BROWSERBASE_API_KEY: API key for Browserbase (required)
- BROWSERBASE_PROJECT_ID: Project ID for Browserbase (required)
- BROWSERBASE_API_KEY: API key for Browserbase (enables cloud mode)
- BROWSERBASE_PROJECT_ID: Project ID for Browserbase (required for cloud mode)
- BROWSERBASE_PROXIES: Enable/disable residential proxies (default: "true")
- BROWSERBASE_ADVANCED_STEALTH: Enable advanced stealth mode with custom Chromium,
  requires Scale Plan (default: "false")
@@ -77,9 +83,20 @@ SNAPSHOT_SUMMARIZE_THRESHOLD = 8000
# Resolve vision auxiliary client for extraction/vision tasks
_aux_vision_client, EXTRACTION_MODEL = get_vision_auxiliary_client()


def _is_local_mode() -> bool:
    """Return True when no Browserbase credentials are configured.

    In local mode the browser tools launch a headless Chromium instance via
    ``agent-browser --session`` instead of connecting to a remote Browserbase
    session via ``--cdp``.
    """
    return not (os.environ.get("BROWSERBASE_API_KEY") and os.environ.get("BROWSERBASE_PROJECT_ID"))

# Track active sessions per task
# Now stores tuple of (session_name, browserbase_session_id, cdp_url)
_active_sessions: Dict[str, Dict[str, str]] = {}  # task_id -> {session_name, bb_session_id, cdp_url}
# Stores: session_name (always), bb_session_id + cdp_url (cloud mode only)
_active_sessions: Dict[str, Dict[str, str]] = {}  # task_id -> {session_name, ...}

# Flag to track if cleanup has been done
_cleanup_done = False
@@ -120,35 +137,56 @@ def _emergency_cleanup_all_sessions():
    logger.info("Emergency cleanup: closing %s active session(s)...", len(_active_sessions))

    try:
        api_key = os.environ.get("BROWSERBASE_API_KEY")
        project_id = os.environ.get("BROWSERBASE_PROJECT_ID")

        if not api_key or not project_id:
            logger.warning("Cannot cleanup - missing BROWSERBASE credentials")
            return

        for task_id, session_info in list(_active_sessions.items()):
            bb_session_id = session_info.get("bb_session_id")
            if bb_session_id:
                try:
                    response = requests.post(
                        f"https://api.browserbase.com/v1/sessions/{bb_session_id}",
                        headers={
                            "X-BB-API-Key": api_key,
                            "Content-Type": "application/json"
                        },
                        json={
                            "projectId": project_id,
                            "status": "REQUEST_RELEASE"
                        },
                        timeout=5  # Short timeout for cleanup
                    )
                    if response.status_code in (200, 201, 204):
                        logger.info("Closed session %s", bb_session_id)
                    else:
                        logger.warning("Failed to close session %s: HTTP %s", bb_session_id, response.status_code)
                except Exception as e:
                    logger.error("Error closing session %s: %s", bb_session_id, e)
        if _is_local_mode():
            # Local mode: just close agent-browser sessions via CLI
            for task_id, session_info in list(_active_sessions.items()):
                session_name = session_info.get("session_name")
                if session_name:
                    try:
                        browser_cmd = _find_agent_browser()
                        task_socket_dir = os.path.join(
                            tempfile.gettempdir(),
                            f"agent-browser-{session_name}"
                        )
                        env = {**os.environ, "AGENT_BROWSER_SOCKET_DIR": task_socket_dir}
                        subprocess.run(
                            browser_cmd.split() + ["--session", session_name, "--json", "close"],
                            capture_output=True, timeout=5, env=env,
                        )
                        logger.info("Closed local session %s", session_name)
                    except Exception as e:
                        logger.debug("Error closing local session %s: %s", session_name, e)
        else:
            # Cloud mode: release Browserbase sessions via API
            api_key = os.environ.get("BROWSERBASE_API_KEY")
            project_id = os.environ.get("BROWSERBASE_PROJECT_ID")

            if not api_key or not project_id:
                logger.warning("Cannot cleanup - missing BROWSERBASE credentials")
                return

            for task_id, session_info in list(_active_sessions.items()):
                bb_session_id = session_info.get("bb_session_id")
                if bb_session_id:
                    try:
                        response = requests.post(
                            f"https://api.browserbase.com/v1/sessions/{bb_session_id}",
                            headers={
                                "X-BB-API-Key": api_key,
                                "Content-Type": "application/json"
                            },
                            json={
                                "projectId": project_id,
                                "status": "REQUEST_RELEASE"
                            },
                            timeout=5  # Short timeout for cleanup
                        )
                        if response.status_code in (200, 201, 204):
                            logger.info("Closed session %s", bb_session_id)
                        else:
                            logger.warning("Failed to close session %s: HTTP %s", bb_session_id, response.status_code)
                    except Exception as e:
                        logger.error("Error closing session %s: %s", bb_session_id, e)

        _active_sessions.clear()
    except Exception as e:
@@ -184,7 +222,7 @@ def _cleanup_inactive_browser_sessions():

    This function is called periodically by the background cleanup thread to
    automatically close sessions that haven't been used recently, preventing
    orphaned Browserbase sessions from accumulating.
    orphaned sessions (local or Browserbase) from accumulating.
    """
    current_time = time.time()
    sessions_to_cleanup = []
@@ -386,7 +424,7 @@ BROWSER_TOOL_SCHEMAS = [
    },
    {
        "name": "browser_vision",
        "description": "Take a screenshot of the current page and analyze it with vision AI. Use this when you need to visually understand what's on the page - especially useful for CAPTCHAs, visual verification challenges, complex layouts, or when the text snapshot doesn't capture important visual information. Requires browser_navigate to be called first.",
        "description": "Take a screenshot of the current page and analyze it with vision AI. Use this when you need to visually understand what's on the page - especially useful for CAPTCHAs, visual verification challenges, complex layouts, or when the text snapshot doesn't capture important visual information. Returns both the AI analysis and a screenshot_path that you can share with the user by including MEDIA:<screenshot_path> in your response. Requires browser_navigate to be called first.",
        "parameters": {
            "type": "object",
            "properties": {
@@ -560,11 +598,29 @@ def _create_browserbase_session(task_id: str) -> Dict[str, str]:
    }


def _create_local_session(task_id: str) -> Dict[str, str]:
    """Create a lightweight local browser session (no cloud API call).

    Returns the same dict shape as ``_create_browserbase_session`` so the rest
    of the code can treat both modes uniformly.
    """
    import uuid
    session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}"
    logger.info("Created local browser session %s", session_name)
    return {
        "session_name": session_name,
        "bb_session_id": None,  # Not applicable in local mode
        "cdp_url": None,  # Not applicable in local mode
        "features": {"local": True},
    }


def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
    """
    Get or create session info for the given task.

    Creates a Browserbase session with proxies enabled if one doesn't exist.
    In cloud mode, creates a Browserbase session with proxies enabled.
    In local mode, generates a session name for agent-browser --session.
    Also starts the inactivity cleanup thread and updates activity tracking.
    Thread-safe: multiple subagents can call this concurrently.

@@ -572,7 +628,7 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
        task_id: Unique identifier for the task

    Returns:
        Dict with session_name, bb_session_id, and cdp_url
        Dict with session_name (always), bb_session_id + cdp_url (cloud only)
    """
    if task_id is None:
        task_id = "default"
@@ -588,8 +644,11 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
        if task_id in _active_sessions:
            return _active_sessions[task_id]

    # Create session outside the lock (network call - don't hold lock during I/O)
    session_info = _create_browserbase_session(task_id)
    # Create session outside the lock (network call in cloud mode)
    if _is_local_mode():
        session_info = _create_local_session(task_id)
    else:
        session_info = _create_browserbase_session(task_id)

    with _cleanup_lock:
        _active_sessions[task_id] = session_info
@@ -708,12 +767,20 @@ def _run_browser_command(
    except Exception as e:
        return {"success": False, "error": f"Failed to create browser session: {str(e)}"}

    # Connect via CDP to our pre-created Browserbase session.
    # IMPORTANT: Do NOT use --session with --cdp. In agent-browser >=0.13,
    # --session creates a local browser instance and silently ignores --cdp.
    # Per-task isolation is handled by AGENT_BROWSER_SOCKET_DIR instead.
    cmd_parts = browser_cmd.split() + [
        "--cdp", session_info["cdp_url"],
    # Build the command with the appropriate backend flag.
    # Cloud mode: --cdp <websocket_url> connects to Browserbase.
    # Local mode: --session <name> launches a local headless Chromium.
    # The rest of the command (--json, command, args) is identical.
    if session_info.get("cdp_url"):
        # Cloud mode — connect to remote Browserbase browser via CDP
        # IMPORTANT: Do NOT use --session with --cdp. In agent-browser >=0.13,
        # --session creates a local browser instance and silently ignores --cdp.
        backend_args = ["--cdp", session_info["cdp_url"]]
    else:
        # Local mode — launch a headless Chromium instance
        backend_args = ["--session", session_info["session_name"]]

    cmd_parts = browser_cmd.split() + backend_args + [
        "--json",
        command
    ] + args
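
For illustration, the two assembled invocations differ only in the backend
flag (the CDP URL and session name here are invented examples):

    agent-browser --cdp wss://connect.browserbase.com/... --json snapshot   # cloud mode
    agent-browser --session hermes_default_1a2b3c4d --json snapshot        # local mode
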
@@ -728,10 +795,12 @@
    )
    os.makedirs(task_socket_dir, exist_ok=True)

    browser_env = {
        **os.environ,
        "AGENT_BROWSER_SOCKET_DIR": task_socket_dir,
    }
    browser_env = {**os.environ}
    # Ensure PATH includes standard dirs (systemd services may have minimal PATH)
    _SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
    if "/usr/bin" not in browser_env.get("PATH", "").split(":"):
        browser_env["PATH"] = f"{browser_env.get('PATH', '')}:{_SANE_PATH}"
    browser_env["AGENT_BROWSER_SOCKET_DIR"] = task_socket_dir

    result = subprocess.run(
        cmd_parts,
@@ -741,10 +810,18 @@
        env=browser_env,
    )

    # Log stderr for diagnostics (agent-browser may emit warnings there)
    # Log stderr for diagnostics — use warning level on failure so it's visible
    if result.stderr and result.stderr.strip():
        logger.debug("stderr from '%s': %s", command, result.stderr.strip()[:200])
        level = logging.WARNING if result.returncode != 0 else logging.DEBUG
        logger.log(level, "browser '%s' stderr: %s", command, result.stderr.strip()[:500])

    # Log empty output as warning — common sign of broken agent-browser
    if not result.stdout.strip() and result.returncode == 0:
        logger.warning("browser '%s' returned empty stdout with rc=0. "
                       "cmd=%s stderr=%s",
                       command, " ".join(cmd_parts[:4]) + "...",
                       (result.stderr or "")[:200])

    # Parse JSON output
    if result.stdout.strip():
        try:
@@ -1131,12 +1208,13 @@ def browser_close(task_id: Optional[str] = None) -> str:
    effective_task_id = task_id or "default"
    result = _run_browser_command(effective_task_id, "close", [])

    # Close the BrowserBase session via API
    # Close the backend session (Browserbase API in cloud mode, nothing extra in local mode)
    session_key = task_id if task_id and task_id in _active_sessions else "default"
    if session_key in _active_sessions:
        session_info = _active_sessions[session_key]
        bb_session_id = session_info.get("bb_session_id")
        if bb_session_id:
            # Cloud mode: release the Browserbase session via API
            try:
                config = _get_browserbase_config()
                _close_browserbase_session(bb_session_id, config["api_key"], config["project_id"])
@@ -1221,15 +1299,17 @@ def browser_vision(question: str, task_id: Optional[str] = None) -> str:
    text-based snapshot may not capture (CAPTCHAs, verification challenges,
    images, complex layouts, etc.).

    The screenshot is saved persistently and its file path is returned alongside
    the analysis, so it can be shared with users via MEDIA:<path> in the response.

    Args:
        question: What you want to know about the page visually
        task_id: Task identifier for session isolation

    Returns:
        JSON string with vision analysis results
        JSON string with vision analysis results and screenshot_path
    """
    import base64
    import tempfile
    import uuid as uuid_mod
    from pathlib import Path

@@ -1243,11 +1323,17 @@ def browser_vision(question: str, task_id: Optional[str] = None) -> str:
            "Set OPENROUTER_API_KEY or configure Nous Portal to enable browser vision."
        }, ensure_ascii=False)

    # Create a temporary file for the screenshot
    temp_dir = Path(tempfile.gettempdir())
    screenshot_path = temp_dir / f"browser_screenshot_{uuid_mod.uuid4().hex}.png"
    # Save screenshot to persistent location so it can be shared with users
    hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
    screenshots_dir = hermes_home / "browser_screenshots"
    screenshot_path = screenshots_dir / f"browser_screenshot_{uuid_mod.uuid4().hex}.png"

    try:
        screenshots_dir.mkdir(parents=True, exist_ok=True)

        # Prune old screenshots (older than 24 hours) to prevent unbounded disk growth
        _cleanup_old_screenshots(screenshots_dir, max_age_hours=24)

        # Take screenshot using agent-browser
        result = _run_browser_command(
            effective_task_id,
@@ -1304,21 +1390,35 @@ def browser_vision(question: str, task_id: Optional[str] = None) -> str:
        return json.dumps({
            "success": True,
            "analysis": analysis,
            "screenshot_path": str(screenshot_path),
        }, ensure_ascii=False)

    except Exception as e:
        return json.dumps({
            "success": False,
            "error": f"Error during vision analysis: {str(e)}"
        }, ensure_ascii=False)

    finally:
        # Clean up screenshot file
        # Clean up screenshot on failure
        if screenshot_path.exists():
            try:
                screenshot_path.unlink()
            except Exception:
                pass
        return json.dumps({
            "success": False,
            "error": f"Error during vision analysis: {str(e)}"
        }, ensure_ascii=False)


def _cleanup_old_screenshots(screenshots_dir, max_age_hours=24):
    """Remove browser screenshots older than max_age_hours to prevent disk bloat."""
    import time
    try:
        cutoff = time.time() - (max_age_hours * 3600)
        for f in screenshots_dir.glob("browser_screenshot_*.png"):
            try:
                if f.stat().st_mtime < cutoff:
                    f.unlink()
            except Exception:
                pass
    except Exception:
        pass  # Non-critical — don't fail the screenshot operation


# ============================================================================
@@ -1404,14 +1504,15 @@ def cleanup_browser(task_id: Optional[str] = None) -> None:
        _active_sessions.pop(task_id, None)
        _session_last_activity.pop(task_id, None)

    # Close the Browserbase session immediately via API
    try:
        config = _get_browserbase_config()
        success = _close_browserbase_session(bb_session_id, config["api_key"], config["project_id"])
        if not success:
            logger.warning("Could not close BrowserBase session %s", bb_session_id)
    except Exception as e:
        logger.error("Exception during BrowserBase session close: %s", e)
    # Cloud mode: close the Browserbase session via API
    if bb_session_id and not _is_local_mode():
        try:
            config = _get_browserbase_config()
            success = _close_browserbase_session(bb_session_id, config["api_key"], config["project_id"])
            if not success:
                logger.warning("Could not close BrowserBase session %s", bb_session_id)
        except Exception as e:
            logger.error("Exception during BrowserBase session close: %s", e)

    # Kill the daemon process and clean up socket directory
    session_name = session_info.get("session_name", "")
@@ -1464,24 +1565,31 @@ def get_active_browser_sessions() -> Dict[str, Dict[str, str]]:
def check_browser_requirements() -> bool:
    """
    Check if browser tool requirements are met.

    In **local mode** (no Browserbase credentials): only the ``agent-browser``
    CLI must be findable.

    In **cloud mode** (BROWSERBASE_API_KEY set): the CLI *and* both
    ``BROWSERBASE_API_KEY`` / ``BROWSERBASE_PROJECT_ID`` must be present.

    Returns:
        True if all requirements are met, False otherwise
    """
    # Check for Browserbase credentials
    api_key = os.environ.get("BROWSERBASE_API_KEY")
    project_id = os.environ.get("BROWSERBASE_PROJECT_ID")

    if not api_key or not project_id:
        return False

    # Check for agent-browser CLI
    # The agent-browser CLI is always required
    try:
        _find_agent_browser()
        return True
    except FileNotFoundError:
        return False

    # In cloud mode, also require Browserbase credentials
    if not _is_local_mode():
        api_key = os.environ.get("BROWSERBASE_API_KEY")
        project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
        if not api_key or not project_id:
            return False

    return True


# ============================================================================
# Module Test
@@ -1493,20 +1601,26 @@ if __name__ == "__main__":
    """
    print("🌐 Browser Tool Module")
    print("=" * 40)

    mode = "local" if _is_local_mode() else "cloud (Browserbase)"
    print(f" Mode: {mode}")

    # Check requirements
    if check_browser_requirements():
        print("✅ All requirements met")
    else:
        print("❌ Missing requirements:")
        if not os.environ.get("BROWSERBASE_API_KEY"):
            print(" - BROWSERBASE_API_KEY not set")
        if not os.environ.get("BROWSERBASE_PROJECT_ID"):
            print(" - BROWSERBASE_PROJECT_ID not set")
        try:
            _find_agent_browser()
        except FileNotFoundError:
            print(" - agent-browser CLI not found")
            print(" Install: npm install -g agent-browser && agent-browser install --with-deps")
        if not _is_local_mode():
            if not os.environ.get("BROWSERBASE_API_KEY"):
                print(" - BROWSERBASE_API_KEY not set (required for cloud mode)")
            if not os.environ.get("BROWSERBASE_PROJECT_ID"):
                print(" - BROWSERBASE_PROJECT_ID not set (required for cloud mode)")
            print(" Tip: unset BROWSERBASE_API_KEY to use free local mode instead")

    print("\n📋 Available Browser Tools:")
    for schema in BROWSER_TOOL_SCHEMAS:
@@ -1531,7 +1645,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_navigate"],
    handler=lambda args, **kw: browser_navigate(url=args.get("url", ""), task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_snapshot",
@@ -1540,7 +1653,6 @@ registry.register(
    handler=lambda args, **kw: browser_snapshot(
        full=args.get("full", False), task_id=kw.get("task_id"), user_task=kw.get("user_task")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_click",
@@ -1548,7 +1660,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_click"],
    handler=lambda args, **kw: browser_click(**args, task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_type",
@@ -1556,7 +1667,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_type"],
    handler=lambda args, **kw: browser_type(**args, task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_scroll",
@@ -1564,7 +1674,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_scroll"],
    handler=lambda args, **kw: browser_scroll(**args, task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_back",
@@ -1572,7 +1681,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_back"],
    handler=lambda args, **kw: browser_back(task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_press",
@@ -1580,7 +1688,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_press"],
    handler=lambda args, **kw: browser_press(key=args.get("key", ""), task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_close",
@@ -1588,7 +1695,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_close"],
    handler=lambda args, **kw: browser_close(task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_get_images",
@@ -1596,7 +1702,6 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_get_images"],
    handler=lambda args, **kw: browser_get_images(task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)
registry.register(
    name="browser_vision",
@@ -1604,5 +1709,4 @@ registry.register(
    schema=_BROWSER_SCHEMA_MAP["browser_vision"],
    handler=lambda args, **kw: browser_vision(question=args.get("question", ""), task_id=kw.get("task_id")),
    check_fn=check_browser_requirements,
    requires_env=["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"],
)

@@ -78,7 +78,7 @@ _TOOL_STUBS = {
    "web_extract": (
        "web_extract",
        "urls: list",
        '"""Extract content from URLs. Returns dict with results list of {url, content, error}."""',
        '"""Extract content from URLs. Returns dict with results list of {url, title, content, error}."""',
        '{"urls": urls}',
    ),
    "read_file": (
@@ -435,6 +435,11 @@ def execute_code(
            child_env[k] = v
    child_env["HERMES_RPC_SOCKET"] = sock_path
    child_env["PYTHONDONTWRITEBYTECODE"] = "1"
    # Inject user's configured timezone so datetime.now() in sandboxed
    # code reflects the correct wall-clock time.
    _tz_name = os.getenv("HERMES_TIMEZONE", "").strip()
    if _tz_name:
        child_env["TZ"] = _tz_name

    proc = subprocess.Popen(
        [sys.executable, "script.py"],
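
# A quick runnable check (illustrative, not repo code) that TZ in a child's
# environment shifts datetime.now() on glibc systems, which is what the
# injection above relies on:
import os
import subprocess
import sys

for tz in ("UTC", "America/New_York"):
    stamp = subprocess.run(
        [sys.executable, "-c", "import datetime; print(datetime.datetime.now())"],
        env={**os.environ, "TZ": tz}, capture_output=True, text=True,
    ).stdout.strip()
    print(tz, stamp)  # the two timestamps differ by the zone offset
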
@@ -592,9 +597,55 @@ def _load_config() -> dict:
# OpenAI Function-Calling Schema
# ---------------------------------------------------------------------------

EXECUTE_CODE_SCHEMA = {
    "name": "execute_code",
    "description": (
# Per-tool documentation lines for the execute_code description.
# Ordered to match the canonical display order.
_TOOL_DOC_LINES = [
    ("web_search",
     " web_search(query: str, limit: int = 5) -> dict\n"
     " Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
    ("web_extract",
     " web_extract(urls: list[str]) -> dict\n"
     " Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
    ("read_file",
     " read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
     " Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
    ("write_file",
     " write_file(path: str, content: str) -> dict\n"
     " Always overwrites the entire file."),
    ("search_files",
     " search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
     " target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
    ("patch",
     " patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
     " Replaces old_string with new_string in the file."),
    ("terminal",
     " terminal(command: str, timeout=None, workdir=None) -> dict\n"
     " Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
]


def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
    """Build the execute_code schema with description listing only enabled tools.

    When tools are disabled via ``hermes tools`` (e.g. web is turned off),
    the schema description should NOT mention web_search / web_extract —
    otherwise the model thinks they are available and keeps trying to use them.
    """
    if enabled_sandbox_tools is None:
        enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS

    # Build tool documentation lines for only the enabled tools
    tool_lines = "\n".join(
        doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
    )

    # Build example import list from enabled tools
    import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
    if not import_examples:
        import_examples = sorted(enabled_sandbox_tools)[:2]
    import_str = ", ".join(import_examples) + ", ..."

    description = (
        "Run a Python script that can call Hermes tools programmatically. "
        "Use this when you need 3+ tool calls with processing logic between them, "
        "need to filter/reduce large tool outputs before they enter your context, "
@@ -603,21 +654,8 @@ EXECUTE_CODE_SCHEMA = {
        "Use normal tool calls instead when: single tool call with no processing, "
        "you need to see the full result and apply complex reasoning, "
        "or the task requires interactive user input.\n\n"
        "Available via `from hermes_tools import ...`:\n\n"
        " web_search(query: str, limit: int = 5) -> dict\n"
        " Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}\n"
        " web_extract(urls: list[str]) -> dict\n"
        " Returns {\"results\": [{\"url\", \"content\", \"error\"}, ...]} where content is markdown\n"
        " read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
        " Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}\n"
        " write_file(path: str, content: str) -> dict\n"
        " Always overwrites the entire file.\n"
        " search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
        " target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}\n"
        " patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
        " Replaces old_string with new_string in the file.\n"
        " terminal(command: str, timeout=None, workdir=None) -> dict\n"
        " Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}\n\n"
        f"Available via `from hermes_tools import ...`:\n\n"
        f"{tool_lines}\n\n"
        "Limits: 5-minute timeout, 50KB stdout cap, max 50 tool calls per script. "
        "terminal() is foreground-only (no background or pty).\n\n"
        "Print your final result to stdout. Use Python stdlib (json, re, math, csv, "
@@ -626,22 +664,30 @@ EXECUTE_CODE_SCHEMA = {
        " json_parse(text: str) — json.loads with strict=False; use for terminal() output with control chars\n"
        " shell_quote(s: str) — shlex.quote(); use when interpolating dynamic strings into shell commands\n"
        " retry(fn, max_attempts=3, delay=2) — retry with exponential backoff for transient failures"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                "description": (
                    "Python code to execute. Import tools with "
                    "`from hermes_tools import web_search, terminal, ...` "
                    "and print your final result to stdout."
                ),
    )

    return {
        "name": "execute_code",
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": (
                        "Python code to execute. Import tools with "
                        f"`from hermes_tools import {import_str}` "
                        "and print your final result to stdout."
                    ),
                },
            },
            "required": ["code"],
        },
        "required": ["code"],
    },
}
    }


# Default schema used at registration time (all sandbox tools listed)
EXECUTE_CODE_SCHEMA = build_execute_code_schema()

# --- Registry ---
|
||||
|
||||
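For illustration, a hedged usage sketch of the builder above; the restricted tool set chosen here is hypothetical:

    # Hedged usage sketch; build_execute_code_schema and _TOOL_DOC_LINES are the
    # names introduced above, the restricted set is made up.
    schema = build_execute_code_schema({"read_file", "patch", "terminal"})
    assert "web_search(" not in schema["description"]  # disabled tools are omitted
    assert "terminal(" in schema["description"]        # enabled tools stay documented
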
@@ -174,7 +174,14 @@ def _run_single_child(

    child_start = time.monotonic()

-    child_toolsets = _strip_blocked_tools(toolsets or DEFAULT_TOOLSETS)
+    # When no explicit toolsets given, inherit from parent's enabled toolsets
+    # so disabled tools (e.g. web) don't leak to subagents.
+    if toolsets:
+        child_toolsets = _strip_blocked_tools(toolsets)
+    elif parent_agent and getattr(parent_agent, "enabled_toolsets", None):
+        child_toolsets = _strip_blocked_tools(parent_agent.enabled_toolsets)
+    else:
+        child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)

    child_prompt = _build_child_system_prompt(goal, context)

@@ -187,6 +194,10 @@ def _run_single_child(
    # Build progress callback to relay tool calls to parent display
    child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count)

+    # Share the parent's iteration budget so subagent tool calls
+    # count toward the session-wide limit.
+    shared_budget = getattr(parent_agent, "iteration_budget", None)
+
    child = AIAgent(
        base_url=parent_agent.base_url,
        api_key=parent_api_key,
@@ -194,6 +205,9 @@ def _run_single_child(
        provider=getattr(parent_agent, "provider", None),
        api_mode=getattr(parent_agent, "api_mode", None),
        max_iterations=max_iterations,
+        max_tokens=getattr(parent_agent, "max_tokens", None),
+        reasoning_config=getattr(parent_agent, "reasoning_config", None),
+        prefill_messages=getattr(parent_agent, "prefill_messages", None),
        enabled_toolsets=child_toolsets,
        quiet_mode=True,
        ephemeral_system_prompt=child_prompt,
@@ -208,6 +222,7 @@ def _run_single_child(
        providers_order=parent_agent.providers_order,
        provider_sort=parent_agent.provider_sort,
        tool_progress_callback=child_progress_cb,
+        iteration_budget=shared_budget,
    )

    # Set delegation depth so children can't spawn grandchildren
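For illustration, the fallback order above means a parent's restrictions propagate whenever the caller passes no explicit toolsets; attribute names follow the diff, the scenario is made up:

    # Hypothetical scenario for the inheritance fallback above:
    parent_agent.enabled_toolsets = ["terminal", "file"]   # web disabled on the parent
    # _run_single_child(goal=..., toolsets=None, parent_agent=parent_agent)
    # resolves child_toolsets to _strip_blocked_tools(["terminal", "file"]),
    # so the child cannot use web_search / web_extract either.
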
@@ -281,7 +296,6 @@ def delegate_task(
    context: Optional[str] = None,
    toolsets: Optional[List[str]] = None,
    tasks: Optional[List[Dict[str, Any]]] = None,
-    model: Optional[str] = None,
    max_iterations: Optional[int] = None,
    parent_agent=None,
) -> str:
@@ -343,7 +357,7 @@ def delegate_task(
            goal=t["goal"],
            context=t.get("context"),
            toolsets=t.get("toolsets") or toolsets,
-            model=model,
+            model=None,
            max_iterations=effective_max_iter,
            parent_agent=parent_agent,
            task_count=1,
@@ -368,7 +382,7 @@ def delegate_task(
            goal=t["goal"],
            context=t.get("context"),
            toolsets=t.get("toolsets") or toolsets,
-            model=model,
+            model=None,
            max_iterations=effective_max_iter,
            parent_agent=parent_agent,
            task_count=n_tasks,
@@ -493,7 +507,7 @@ DELEGATE_TASK_SCHEMA = {
            "items": {"type": "string"},
            "description": (
                "Toolsets to enable for this subagent. "
-                "Default: ['terminal', 'file', 'web']. "
+                "Default: inherits your enabled toolsets. "
                "Common patterns: ['terminal', 'file'] for code work, "
                "['web'] for research, ['terminal', 'file', 'web'] for "
                "full-stack tasks."
@@ -521,13 +535,6 @@ DELEGATE_TASK_SCHEMA = {
                "When provided, top-level goal/context/toolsets are ignored."
            ),
        },
-        "model": {
-            "type": "string",
-            "description": (
-                "Model override for the subagent(s). Omit to use your "
-                "same model. Use a cheaper/faster model for simple subtasks."
-            ),
-        },
        "max_iterations": {
            "type": "integer",
            "description": (
@@ -553,7 +560,6 @@ registry.register(
        context=args.get("context"),
        toolsets=args.get("toolsets"),
        tasks=args.get("tasks"),
-        model=args.get("model"),
        max_iterations=args.get("max_iterations"),
        parent_agent=kw.get("parent_agent")),
    check_fn=check_delegate_requirements,

@@ -17,15 +17,21 @@ from tools.environments.base import BaseEnvironment
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"


-def _find_shell() -> str:
-    """Find the best shell for command execution.
+def _find_bash() -> str:
+    """Find bash for command execution.

-    On Unix: uses $SHELL, falls back to bash.
+    The fence wrapper uses bash syntax (semicolons, $?, printf), so we
+    must use bash — not the user's $SHELL which could be fish/zsh/etc.
    On Windows: uses Git Bash (bundled with Git for Windows).
    Raises RuntimeError if no suitable shell is found on Windows.
    """
    if not _IS_WINDOWS:
-        return os.environ.get("SHELL") or shutil.which("bash") or "/bin/bash"
+        return (
+            shutil.which("bash")
+            or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
+            or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
+            or os.environ.get("SHELL")  # last resort: whatever they have
+            or "/bin/sh"
+        )

    # Windows: look for Git Bash (installed with Git for Windows).
    # Allow override via env var (same pattern as Claude Code).
@@ -53,6 +59,11 @@ def _find_shell() -> str:
        "Or set HERMES_GIT_BASH_PATH to your bash.exe location."
    )


+# Backward compat — process_registry.py imports this name
+_find_shell = _find_bash
+
+
# Noise lines emitted by interactive shells when stdin is not a terminal.
# Used as a fallback when output fence markers are missing.
_SHELL_NOISE_SUBSTRINGS = (
@@ -153,13 +164,11 @@ class LocalEnvironment(BaseEnvironment):
        exec_command = self._prepare_command(command)

        try:
-            # Use the user's shell as an interactive login shell (-lic) so
-            # that ALL rc files are sourced — including content after the
-            # interactive guard in .bashrc (case $- in *i*)..esac) where
-            # tools like nvm, pyenv, and cargo install their init scripts.
-            # -l alone isn't enough: .profile sources .bashrc, but the guard
-            # returns early because the shell isn't interactive.
-            user_shell = _find_shell()
+            # The fence wrapper uses bash syntax (semicolons, $?, printf).
+            # Always use bash for the wrapper — NOT $SHELL which could be
+            # fish, zsh, or another shell with incompatible syntax.
+            # The -lic flags source rc files so tools like nvm/pyenv work.
+            user_shell = _find_bash()
            # Wrap with output fences so we can later extract the real
            # command output and discard shell init/exit noise.
            fenced_cmd = (
@@ -169,11 +178,19 @@ class LocalEnvironment(BaseEnvironment):
                f" printf '{_OUTPUT_FENCE}';"
                f" exit $__hermes_rc"
            )
+            # Ensure PATH always includes standard dirs — systemd services
+            # and some terminal multiplexers inherit a minimal PATH.
+            _SANE_PATH = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
+            run_env = dict(os.environ | self.env)
+            existing_path = run_env.get("PATH", "")
+            if "/usr/bin" not in existing_path.split(":"):
+                run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH

            proc = subprocess.Popen(
                [user_shell, "-lic", fenced_cmd],
                text=True,
                cwd=work_dir,
-                env=os.environ | self.env,
+                env=run_env,
                encoding="utf-8",
                errors="replace",
                stdout=subprocess.PIPE,

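For illustration, everything between the two fence markers is the real command output; a hedged sketch of the extraction step, with a hypothetical helper name:

    # Hedged sketch of the fence idea above; the helper name is hypothetical.
    def _between_fences(raw: str, fence: str = "__HERMES_FENCE_a9f7b3__") -> str:
        # rc-file noise from the -lic login shell falls outside the two markers;
        # the real command output is everything between them.
        parts = raw.split(fence)
        return parts[1] if len(parts) >= 3 else raw
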
@@ -819,6 +819,14 @@ class ShellFileOperations(FileOperations):
        # Expand ~ and other shell paths
        path = self._expand_path(path)

+        # Validate that the path exists before searching
+        check = self._exec(f"test -e {self._escape_shell_arg(path)} && echo exists || echo not_found")
+        if "not_found" in check.stdout:
+            return SearchResult(
+                error=f"Path not found: {path}. Verify the path exists (use 'terminal' to check).",
+                total_count=0
+            )
+
        if target == "files":
            return self._search_files(pattern, path, limit, offset)
        else:
@@ -919,6 +927,11 @@ class ShellFileOperations(FileOperations):
        cmd = " ".join(cmd_parts)
        result = self._exec(cmd, timeout=60)

+        # rg exit codes: 0=matches found, 1=no matches, 2=error
+        if result.exit_code == 2 and not result.stdout.strip():
+            error_msg = result.stderr.strip() if hasattr(result, 'stderr') and result.stderr else "Search error"
+            return SearchResult(error=f"Search failed: {error_msg}", total_count=0)
+
        # Parse results based on output mode
        if output_mode == "files_only":
            all_files = [f for f in result.stdout.strip().split('\n') if f]
@@ -1013,6 +1026,11 @@ class ShellFileOperations(FileOperations):
        cmd = " ".join(cmd_parts)
        result = self._exec(cmd, timeout=60)

+        # grep exit codes: 0=matches found, 1=no matches, 2=error
+        if result.exit_code == 2 and not result.stdout.strip():
+            error_msg = result.stderr.strip() if hasattr(result, 'stderr') and result.stderr else "Search error"
+            return SearchResult(error=f"Search failed: {error_msg}", total_count=0)
+
        if output_mode == "files_only":
            all_files = [f for f in result.stdout.strip().split('\n') if f]
            total = len(all_files)

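For illustration, both guards encode the standard rg/grep exit-code contract; a hedged standalone restatement with a hypothetical helper name:

    # Hedged restatement of the exit-code handling above (helper is hypothetical):
    def _search_error(exit_code: int, stdout: str, stderr: str):
        # 0 = matches found, 1 = no matches (not an error), 2 = real error.
        if exit_code == 2 and not stdout.strip():
            return stderr.strip() or "Search error"
        return None
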
@@ -31,7 +31,6 @@ Usage:
import json
import logging
import os
-import asyncio
import datetime
from typing import Dict, Any, Optional, Union
import fal_client
@@ -153,10 +152,13 @@ def _validate_parameters(
    return validated


-async def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
+def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
    """
    Upscale an image using FAL.ai's Clarity Upscaler.

+    Uses the synchronous fal_client API to avoid event loop lifecycle issues
+    when called from threaded contexts (e.g. gateway thread pool).
+
    Args:
        image_url (str): URL of the image to upscale
        original_prompt (str): Original prompt used to generate the image
@@ -180,14 +182,17 @@ async def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]
        "enable_safety_checker": UPSCALER_SAFETY_CHECKER
    }

-    # Submit upscaler request
-    handler = await fal_client.submit_async(
+    # Use sync API — fal_client.submit() uses httpx.Client (no event loop).
+    # The async API (submit_async) caches a global httpx.AsyncClient via
+    # @cached_property, which breaks when asyncio.run() destroys the loop
+    # between calls (gateway thread-pool pattern).
+    handler = fal_client.submit(
        UPSCALER_MODEL,
        arguments=upscaler_arguments
    )

-    # Get the upscaled result
-    result = await handler.get()
+    # Get the upscaled result (sync — blocks until done)
+    result = handler.get()

    if result and "image" in result:
        upscaled_image = result["image"]
@@ -208,7 +213,7 @@ async def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]
    return None


-async def image_generate_tool(
+def image_generate_tool(
    prompt: str,
    aspect_ratio: str = DEFAULT_ASPECT_RATIO,
    num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
@@ -220,10 +225,10 @@ async def image_generate_tool(
    """
    Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.

-    This tool uses FAL.ai's FLUX 2 Pro model for high-quality text-to-image generation
-    with extensive customization options. Generated images are automatically upscaled 2x
-    using FAL.ai's Clarity Upscaler for enhanced quality. The final upscaled images are
-    returned as URLs that can be displayed using <img src="{URL}"></img> tags.
+    Uses the synchronous fal_client API to avoid event loop lifecycle issues.
+    The async API's global httpx.AsyncClient (cached via @cached_property) breaks
+    when asyncio.run() destroys and recreates event loops between calls, which
+    happens in the gateway's thread-pool pattern.

    Args:
        prompt (str): The text prompt describing the desired image
@@ -306,14 +311,14 @@ async def image_generate_tool(
    logger.info(" Steps: %s", validated_params['num_inference_steps'])
    logger.info(" Guidance: %s", validated_params['guidance_scale'])

-    # Submit request to FAL.ai
-    handler = await fal_client.submit_async(
+    # Submit request to FAL.ai using sync API (avoids cached event loop issues)
+    handler = fal_client.submit(
        DEFAULT_MODEL,
        arguments=arguments
    )

-    # Get the result
-    result = await handler.get()
+    # Get the result (sync — blocks until done)
+    result = handler.get()

    generation_time = (datetime.datetime.now() - start_time).total_seconds()

@@ -336,7 +341,7 @@ async def image_generate_tool(
    }

    # Attempt to upscale the image
-    upscaled_image = await _upscale_image(img["url"], prompt.strip())
+    upscaled_image = _upscale_image(img["url"], prompt.strip())

    if upscaled_image:
        # Use upscaled image if successful
@@ -552,5 +557,5 @@ registry.register(
    handler=_handle_image_generate,
    check_fn=check_image_generation_requirements,
    requires_env=["FAL_KEY"],
-    is_async=True,
+    is_async=False,  # Switched to sync fal_client API to fix "Event loop is closed" in gateway
)

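For illustration, the failure mode these hunks work around can be reproduced outside hermes; a hedged minimal sketch (not hermes or fal_client code):

    # Hedged repro of the "Event loop is closed" pattern described above.
    import asyncio
    import httpx

    client = httpx.AsyncClient()   # module-level cache, like the library's @cached_property

    async def call():
        await client.get("https://example.com")

    asyncio.run(call())    # loop 1: the client's transport binds to this loop
    # asyncio.run(call())  # loop 2: typically raises RuntimeError("Event loop is closed")
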
@@ -993,7 +993,7 @@ async def rl_list_runs() -> str:
TEST_MODELS = [
    {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"},
    {"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"},
-    {"id": "minimax/minimax-m2.1", "name": "MiniMax M2.1", "scale": "large"},
+    {"id": "minimax/minimax-m2.5", "name": "MiniMax M2.5", "scale": "large"},
]

# Default test parameters - quick but representative
@@ -1353,7 +1353,7 @@ RL_CHECK_STATUS_SCHEMA = {"name": "rl_check_status", "description": "Get status
RL_STOP_TRAINING_SCHEMA = {"name": "rl_stop_training", "description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to stop"}}, "required": ["run_id"]}}
RL_GET_RESULTS_SCHEMA = {"name": "rl_get_results", "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to get results for"}}, "required": ["run_id"]}}
RL_LIST_RUNS_SCHEMA = {"name": "rl_list_runs", "description": "List all training runs (active and completed) with their status.", "parameters": {"type": "object", "properties": {}, "required": []}}
-RL_TEST_INFERENCE_SCHEMA = {"name": "rl_test_inference", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": {"type": "object", "properties": {"num_steps": {"type": "integer", "description": "Number of steps to run (default: 3, recommended max for testing)", "default": 3}, "group_size": {"type": "integer", "description": "Completions per step (default: 16, like training)", "default": 16}, "models": {"type": "array", "items": {"type": "string"}, "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.1"}}, "required": []}}
+RL_TEST_INFERENCE_SCHEMA = {"name": "rl_test_inference", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": {"type": "object", "properties": {"num_steps": {"type": "integer", "description": "Number of steps to run (default: 3, recommended max for testing)", "default": 3}, "group_size": {"type": "integer", "description": "Completions per step (default: 16, like training)", "default": 16}, "models": {"type": "array", "items": {"type": "string"}, "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.5"}}, "required": []}}

_rl_env = ["TINKER_API_KEY", "WANDB_API_KEY"]


@@ -946,6 +946,11 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult,
    client = OpenAI(
        base_url=OPENROUTER_BASE_URL,
        api_key=api_key,
+        default_headers={
+            "HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
+            "X-OpenRouter-Title": "Hermes Agent",
+            "X-OpenRouter-Categories": "productivity,cli-agent",
+        },
    )
    response = client.chat.completions.create(
        model=model,

@@ -3,16 +3,22 @@
Skills Sync -- Manifest-based seeding and updating of bundled skills.

Copies bundled skills from the repo's skills/ directory into ~/.hermes/skills/
-and uses a manifest to track which skills have been offered.
+and uses a manifest to track which skills have been synced and their origin hash.

-Behavior:
-- NEW skills (not in manifest): copied to user dir, added to manifest.
-- EXISTING skills (in manifest, present in user dir): UPDATED from bundled.
-- DELETED by user (in manifest, absent from user dir): respected -- not re-added.
+Manifest format (v2): each line is "skill_name:origin_hash" where origin_hash
+is the MD5 of the bundled skill at the time it was last synced to the user dir.
+Old v1 manifests (plain names without hashes) are auto-migrated.
+
+Update logic:
+- NEW skills (not in manifest): copied to user dir, origin hash recorded.
+- EXISTING skills (in manifest, present in user dir):
+    * If user copy matches origin hash: user hasn't modified it → safe to
+      update from bundled if bundled changed. New origin hash recorded.
+    * If user copy differs from origin hash: user customized it → SKIP.
+- DELETED by user (in manifest, absent from user dir): respected, not re-added.
+- REMOVED from bundled (in manifest, gone from repo): cleaned from manifest.

-The manifest lives at ~/.hermes/skills/.bundled_manifest and is a simple
-newline-delimited list of skill names that have been offered to the user.
+The manifest lives at ~/.hermes/skills/.bundled_manifest.
"""

import hashlib
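For illustration, a v2 manifest is plain name:hash lines; the entries below are made up:

    # Hypothetical ~/.hermes/skills/.bundled_manifest contents (v2 format):
    #   git-helpers:4b7d55aa0c1e2f3a4b5c6d7e8f9a0b1c
    #   pdf-tools:9a1f0c2e7b4d8a3c5e6f1a2b3c4d5e6f
    # A v1 manifest is the same file with bare names and no ":hash" suffix.
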
@@ -20,7 +26,7 @@ import logging
import os
import shutil
from pathlib import Path
-from typing import List, Tuple
+from typing import Dict, List, Tuple

logger = logging.getLogger(__name__)

@@ -35,27 +41,38 @@ def _get_bundled_dir() -> Path:
    return Path(__file__).parent.parent / "skills"


-def _read_manifest() -> set:
-    """Read the set of skill names already offered to the user."""
+def _read_manifest() -> Dict[str, str]:
+    """
+    Read the manifest as a dict of {skill_name: origin_hash}.
+
+    Handles both v1 (plain names) and v2 (name:hash) formats.
+    v1 entries get an empty hash string which triggers migration on next sync.
+    """
    if not MANIFEST_FILE.exists():
-        return set()
+        return {}
    try:
-        return set(
-            line.strip()
-            for line in MANIFEST_FILE.read_text(encoding="utf-8").splitlines()
-            if line.strip()
-        )
+        result = {}
+        for line in MANIFEST_FILE.read_text(encoding="utf-8").splitlines():
+            line = line.strip()
+            if not line:
+                continue
+            if ":" in line:
+                # v2 format: name:hash
+                name, _, hash_val = line.partition(":")
+                result[name.strip()] = hash_val.strip()
+            else:
+                # v1 format: plain name — empty hash triggers migration
+                result[line] = ""
+        return result
    except (OSError, IOError):
-        return set()
+        return {}


-def _write_manifest(names: set):
-    """Write the manifest file."""
+def _write_manifest(entries: Dict[str, str]):
+    """Write the manifest file in v2 format (name:hash)."""
    MANIFEST_FILE.parent.mkdir(parents=True, exist_ok=True)
-    MANIFEST_FILE.write_text(
-        "\n".join(sorted(names)) + "\n",
-        encoding="utf-8",
-    )
+    lines = [f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())]
+    MANIFEST_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")

|
||||
@@ -105,18 +122,16 @@ def sync_skills(quiet: bool = False) -> dict:
|
||||
"""
|
||||
Sync bundled skills into ~/.hermes/skills/ using the manifest.
|
||||
|
||||
- NEW skills (not in manifest): copied to user dir, added to manifest.
|
||||
- EXISTING skills (in manifest, present in user dir): updated from bundled.
|
||||
- DELETED by user (in manifest, absent from user dir): respected, not re-added.
|
||||
- REMOVED from bundled (in manifest, gone from repo): cleaned from manifest.
|
||||
|
||||
Returns:
|
||||
dict with keys: copied (list), updated (list), skipped (int),
|
||||
cleaned (list), total_bundled (int)
|
||||
user_modified (list), cleaned (list), total_bundled (int)
|
||||
"""
|
||||
bundled_dir = _get_bundled_dir()
|
||||
if not bundled_dir.exists():
|
||||
return {"copied": [], "updated": [], "skipped": 0, "cleaned": [], "total_bundled": 0}
|
||||
return {
|
||||
"copied": [], "updated": [], "skipped": 0,
|
||||
"user_modified": [], "cleaned": [], "total_bundled": 0,
|
||||
}
|
||||
|
||||
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
manifest = _read_manifest()
|
||||
@@ -125,52 +140,88 @@ def sync_skills(quiet: bool = False) -> dict:
|
||||
|
||||
copied = []
|
||||
updated = []
|
||||
user_modified = []
|
||||
skipped = 0
|
||||
|
||||
for skill_name, skill_src in bundled_skills:
|
||||
dest = _compute_relative_dest(skill_src, bundled_dir)
|
||||
bundled_hash = _dir_hash(skill_src)
|
||||
|
||||
if skill_name not in manifest:
|
||||
# New skill -- never offered before
|
||||
# ── New skill — never offered before ──
|
||||
try:
|
||||
if dest.exists():
|
||||
# User already has a skill with the same name (unlikely but possible)
|
||||
# User already has a skill with the same name — don't overwrite
|
||||
skipped += 1
|
||||
manifest[skill_name] = bundled_hash
|
||||
else:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copytree(skill_src, dest)
|
||||
copied.append(skill_name)
|
||||
manifest[skill_name] = bundled_hash
|
||||
if not quiet:
|
||||
print(f" + {skill_name}")
|
||||
except (OSError, IOError) as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to copy {skill_name}: {e}")
|
||||
manifest.add(skill_name)
|
||||
# Do NOT add to manifest — next sync should retry
|
||||
|
||||
elif dest.exists():
|
||||
# Existing skill in manifest AND on disk -- check for updates
|
||||
src_hash = _dir_hash(skill_src)
|
||||
dst_hash = _dir_hash(dest)
|
||||
if src_hash != dst_hash:
|
||||
# ── Existing skill — in manifest AND on disk ──
|
||||
origin_hash = manifest.get(skill_name, "")
|
||||
user_hash = _dir_hash(dest)
|
||||
|
||||
if not origin_hash:
|
||||
# v1 migration: no origin hash recorded. Set baseline from
|
||||
# user's current copy so future syncs can detect modifications.
|
||||
manifest[skill_name] = user_hash
|
||||
if user_hash == bundled_hash:
|
||||
skipped += 1 # already in sync
|
||||
else:
|
||||
# Can't tell if user modified or bundled changed — be safe
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
if user_hash != origin_hash:
|
||||
# User modified this skill — don't overwrite their changes
|
||||
user_modified.append(skill_name)
|
||||
if not quiet:
|
||||
print(f" ~ {skill_name} (user-modified, skipping)")
|
||||
continue
|
||||
|
||||
# User copy matches origin — check if bundled has a newer version
|
||||
if bundled_hash != origin_hash:
|
||||
try:
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(skill_src, dest)
|
||||
updated.append(skill_name)
|
||||
if not quiet:
|
||||
print(f" ↑ {skill_name} (updated)")
|
||||
# Move old copy to a backup so we can restore on failure
|
||||
backup = dest.with_suffix(".bak")
|
||||
shutil.move(str(dest), str(backup))
|
||||
try:
|
||||
shutil.copytree(skill_src, dest)
|
||||
manifest[skill_name] = bundled_hash
|
||||
updated.append(skill_name)
|
||||
if not quiet:
|
||||
print(f" ↑ {skill_name} (updated)")
|
||||
# Remove backup after successful copy
|
||||
shutil.rmtree(backup, ignore_errors=True)
|
||||
except (OSError, IOError):
|
||||
# Restore from backup
|
||||
if backup.exists() and not dest.exists():
|
||||
shutil.move(str(backup), str(dest))
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to update {skill_name}: {e}")
|
||||
else:
|
||||
skipped += 1
|
||||
skipped += 1 # bundled unchanged, user unchanged
|
||||
|
||||
else:
|
||||
# In manifest but not on disk -- user deleted it, respect that
|
||||
# ── In manifest but not on disk — user deleted it ──
|
||||
skipped += 1
|
||||
|
||||
# Clean stale manifest entries (skills removed from bundled dir)
|
||||
cleaned = sorted(manifest - bundled_names)
|
||||
manifest -= set(cleaned)
|
||||
cleaned = sorted(set(manifest.keys()) - bundled_names)
|
||||
for name in cleaned:
|
||||
del manifest[name]
|
||||
|
||||
# Also copy DESCRIPTION.md files for categories (if not already present)
|
||||
for desc_md in bundled_dir.rglob("DESCRIPTION.md"):
|
||||
@@ -189,6 +240,7 @@ def sync_skills(quiet: bool = False) -> dict:
|
||||
"copied": copied,
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"user_modified": user_modified,
|
||||
"cleaned": cleaned,
|
||||
"total_bundled": len(bundled_skills),
|
||||
}
|
||||
@@ -197,6 +249,13 @@ def sync_skills(quiet: bool = False) -> dict:
|
||||
if __name__ == "__main__":
|
||||
print("Syncing bundled skills into ~/.hermes/skills/ ...")
|
||||
result = sync_skills(quiet=False)
|
||||
print(f"\nDone: {len(result['copied'])} new, {len(result['updated'])} updated, "
|
||||
f"{result['skipped']} unchanged, {len(result['cleaned'])} cleaned from manifest, "
|
||||
f"{result['total_bundled']} total bundled.")
|
||||
parts = [
|
||||
f"{len(result['copied'])} new",
|
||||
f"{len(result['updated'])} updated",
|
||||
f"{result['skipped']} unchanged",
|
||||
]
|
||||
if result["user_modified"]:
|
||||
parts.append(f"{len(result['user_modified'])} user-modified (kept)")
|
||||
if result["cleaned"]:
|
||||
parts.append(f"{len(result['cleaned'])} cleaned from manifest")
|
||||
print(f"\nDone: {', '.join(parts)}. {result['total_bundled']} total bundled.")
|
||||
|
||||
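For illustration, the three hashes in the loop above reduce to a small decision table (a restatement, not code from the diff):

    # Restatement of the decision table implemented above:
    #   user == origin and bundled == origin -> skip (nothing changed anywhere)
    #   user == origin and bundled != origin -> update, record bundled as new origin
    #   user != origin                       -> keep user copy (user-modified)
    #   origin == "" (v1 manifest)           -> record baseline, skip this round
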
@@ -31,6 +31,9 @@ SKILL.md Format (YAML Frontmatter, agentskills.io compatible):
  description: Brief description  # Required, max 1024 chars
  version: 1.0.0                  # Optional
  license: MIT                    # Optional (agentskills.io)
+  platforms: [macos]              # Optional — restrict to specific OS platforms
+                                  #   Valid: macos, linux, windows
+                                  #   Omit to load on all platforms (default)
  compatibility: Requires X       # Optional (agentskills.io)
  metadata:                       # Optional, arbitrary key-value (agentskills.io)
    hermes:
@@ -62,6 +65,7 @@ Usage:
import json
import os
import re
+import sys
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

@@ -78,6 +82,41 @@ SKILLS_DIR = HERMES_HOME / "skills"
MAX_NAME_LENGTH = 64
MAX_DESCRIPTION_LENGTH = 1024

+# Platform identifiers for the 'platforms' frontmatter field.
+# Maps user-friendly names to sys.platform prefixes.
+_PLATFORM_MAP = {
+    "macos": "darwin",
+    "linux": "linux",
+    "windows": "win32",
+}
+
+
+def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool:
+    """Check if a skill is compatible with the current OS platform.
+
+    Skills declare platform requirements via a top-level ``platforms`` list
+    in their YAML frontmatter::
+
+        platforms: [macos]         # macOS only
+        platforms: [macos, linux]  # macOS and Linux
+
+    Valid values: ``macos``, ``linux``, ``windows``.
+
+    If the field is absent or empty the skill is compatible with **all**
+    platforms (backward-compatible default).
+    """
+    platforms = frontmatter.get("platforms")
+    if not platforms:
+        return True  # No restriction → loads everywhere
+    if not isinstance(platforms, list):
+        platforms = [platforms]
+    current = sys.platform
+    for p in platforms:
+        mapped = _PLATFORM_MAP.get(str(p).lower().strip(), str(p).lower().strip())
+        if current.startswith(mapped):
+            return True
+    return False
+
+
def check_skills_requirements() -> bool:
    """Skills are always available -- the directory is created on first use if needed."""
@@ -204,6 +243,10 @@ def _find_all_skills() -> List[Dict[str, Any]]:
        try:
            content = skill_md.read_text(encoding='utf-8')
            frontmatter, body = _parse_frontmatter(content)
+
+            # Skip skills incompatible with the current OS platform
+            if not skill_matches_platform(frontmatter):
+                continue

            name = frontmatter.get('name', skill_dir.name)[:MAX_NAME_LENGTH]

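For illustration, a hedged usage sketch of skill_matches_platform; the frontmatter values are made up:

    # Hedged usage sketch; the frontmatter dicts are made up.
    fm = {"name": "screenshot-ocr", "platforms": ["macos", "linux"]}
    skill_matches_platform(fm)   # True on darwin/linux, False on win32
    skill_matches_platform({})   # True everywhere: no 'platforms' key means no restriction
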
Some files were not shown because too many files have changed in this diff.