Mirror of https://github.com/NousResearch/hermes-agent.git
Synced 2026-05-07 19:26:56 +08:00

Compare commits: feat/modal ... add-upstre (212 commits)
| SHA1 |
|---|
| 5f9c02bb37 |
| 3dbeaea3dc |
| 26d9b5af29 |
| ef8cb9afd2 |
| 407a1e24b2 |
| e1e69dfd32 |
| 003b6e49df |
| dab2cfe566 |
| c87bd5dd87 |
| 2a67e4fa57 |
| 136a64942d |
| 9f74d1f2ec |
| 11ad4173de |
| 92cb77eaa7 |
| c5e8166c8b |
| 2b88568653 |
| 34b4fe495e |
| 4fdd6c0dac |
| 60b6abefd9 |
| 4d53b7ccaa |
| cd77c7100c |
| cf810c2950 |
| a23bcb81ce |
| d07d867718 |
| 666f2dd486 |
| 34792dd907 |
| 7ad6fc8a40 |
| f824c10429 |
| 132e5ec179 |
| 66d3e6a0c2 |
| 4a09ae2985 |
| 8c734f2f27 |
| 245d174359 |
| 77f47768dd |
| 90fa9e54ca |
| 9d3a44e0e8 |
| 932d596466 |
| d518f40e8b |
| f016cfca46 |
| b8120df860 |
| 0df7df52f3 |
| bfa27d0a68 |
| 5a20c486e3 |
| 78e19ebc95 |
| b383cafc44 |
| b10ff83566 |
| daa1f542f9 |
| d507f593d0 |
| f210510276 |
| 19b6f81ee7 |
| 76545ab365 |
| b8c3bc7841 |
| a680367568 |
| dfd37a4b31 |
| 5ee9b67d9b |
| 542faf225f |
| 5684c68121 |
| 4be783446a |
| 8d719b180a |
| bf048c8aec |
| c5a9d1ef9d |
| c7b6f423c7 |
| 6d34207167 |
| fcde9be10d |
| 3830bbda41 |
| 4447e7d71a |
| 7bccd904c7 |
| 313d522b61 |
| 9ee4fe41fe |
| 39ee3512cb |
| 42673556af |
| faab73ad58 |
| 7e36468511 |
| 9ba5d399e5 |
| 306d92a9d7 |
| 5baae0df88 |
| 24f6a193e7 |
| 8c0f8baf32 |
| d80c30cc92 |
| e64d646bad |
| b84f9e410c |
| ee5daba061 |
| 23e84de830 |
| 48e0dc8791 |
| fb0f579b16 |
| 5a711f32b1 |
| 4d34427cc7 |
| 41877183bc |
| 451a007fb1 |
| 0a82396718 |
| 5da55ea1e3 |
| 064c009deb |
| caab1cf453 |
| 55c70f3508 |
| d29249b8fa |
| f668e9fc75 |
| 74fe1e2254 |
| 348936752a |
| 69a36a3361 |
| 8712dd6d1c |
| 55a21fe37b |
| f55f625277 |
| 9dac85b069 |
| 99bd69baa8 |
| a62a137a4f |
| 82b18e8ac2 |
| 0111c9848d |
| ab9cadfeee |
| 8bf28e1441 |
| ce28f847ce |
| 5609117882 |
| b4fbb6fe10 |
| 82d7e9429e |
| e2821effb5 |
| 9742f11fda |
| 388dd4789c |
| fdebca4573 |
| 479dfc096a |
| 3c6c11b7c9 |
| bc091eb7ef |
| f75b1d21b4 |
| 94053d75a6 |
| 2a68099675 |
| 6cd3bc6640 |
| 211b55815e |
| 8ae4a6f824 |
| b98301677a |
| f2fdde5ba4 |
| 4f56e31dc7 |
| 6d3804770c |
| ab0f4126cf |
| 585f8528b2 |
| 75f523f5c0 |
| 68fbae5692 |
| 80f1dd8d37 |
| b52b37ae64 |
| d63b363cde |
| c05c60665e |
| b4873a5de7 |
| 913f8ce0a5 |
| 4a63737227 |
| 3e93db16bd |
| f863a42351 |
| dc55f493be |
| 936fda3f9e |
| ecb8148a9f |
| 2dbbedc05a |
| c30967806c |
| 145f719d30 |
| b89eb29174 |
| 3670089a42 |
| 3982fcf095 |
| 8481fdcf08 |
| 39299e2de4 |
| efec4fcaab |
| 5ce2c47d60 |
| f6f3d1de9b |
| ec0fe3242a |
| f2e24faaca |
| 8c80b96318 |
| 2387465dcc |
| 32636ecf8a |
| 6055adbe1b |
| ffd2f8dc50 |
| e93b4d1dcd |
| 014a5b712d |
| 2317d115cd |
| 8253b54be9 |
| 5c867fd79f |
| a44e041acf |
| e9f05b3524 |
| e2a834578d |
| ffc752a79e |
| 399562a7d1 |
| fec8a0da72 |
| 9f4542b3db |
| 363633e2ba |
| a41ba57a7a |
| 884c8ea70a |
| c886333d32 |
| 55b173dd03 |
| 9079a27814 |
| d7d10b14cd |
| a6499b6107 |
| 74a36b0729 |
| efc7a7b957 |
| 4f1464b3af |
| 3a41079fac |
| 5279540bb4 |
| 577da79a47 |
| 1faa9648d3 |
| ad57bf1e4b |
| d5efb82c7c |
| ea2f7ef2f6 |
| 435530018b |
| df61054a84 |
| 690b8bb563 |
| c43451a50b |
| 1e312c6582 |
| e36c8cd49a |
| 16cb6d1a6e |
| e25ad79d5d |
| 82cb1752d9 |
| 3221818b6e |
| a1c25046a9 |
| d10108f8ca |
| 8b520f9848 |
| a718aed1be |
| 5f29e7b63c |
| 5fa3e24b76 |
| ac6d747fa6 |
| ee541c84f1 |
32 .env.example
```diff
@@ -13,6 +13,38 @@ OPENROUTER_API_KEY=
 # Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
 LLM_MODEL=anthropic/claude-opus-4.6
 
+# =============================================================================
+# LLM PROVIDER (z.ai / GLM)
+# =============================================================================
+# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
+# Get your key at: https://z.ai or https://open.bigmodel.cn
+GLM_API_KEY=
+# GLM_BASE_URL=https://api.z.ai/api/paas/v4 # Override default base URL
+
+# =============================================================================
+# LLM PROVIDER (Kimi / Moonshot)
+# =============================================================================
+# Kimi Code provides access to Moonshot AI coding models (kimi-k2.5, etc.)
+# Get your key at: https://platform.kimi.ai (Kimi Code console)
+# Keys prefixed sk-kimi- use the Kimi Code API (api.kimi.com) by default.
+# Legacy keys from platform.moonshot.ai need KIMI_BASE_URL override below.
+KIMI_API_KEY=
+# KIMI_BASE_URL=https://api.kimi.com/coding/v1 # Default for sk-kimi- keys
+# KIMI_BASE_URL=https://api.moonshot.ai/v1 # For legacy Moonshot keys
+# KIMI_BASE_URL=https://api.moonshot.cn/v1 # For Moonshot China keys
+
+# =============================================================================
+# LLM PROVIDER (MiniMax)
+# =============================================================================
+# MiniMax provides access to MiniMax models (global endpoint)
+# Get your key at: https://www.minimax.io
+MINIMAX_API_KEY=
+# MINIMAX_BASE_URL=https://api.minimax.io/v1 # Override default base URL
+
+# MiniMax China endpoint (for users in mainland China)
+MINIMAX_CN_API_KEY=
+# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
+
 # =============================================================================
 # TOOL API KEYS
 # =============================================================================
```
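A minimal sketch of how these keys could select an OpenAI-compatible client. The env var names come from the block above; the helper itself and its fallback order are illustrative, not part of the repo:

```python
import os
from openai import OpenAI

def client_from_env():
    """Illustrative only: pick the first configured direct provider."""
    if os.getenv("GLM_API_KEY"):
        return OpenAI(api_key=os.environ["GLM_API_KEY"],
                      base_url=os.getenv("GLM_BASE_URL", "https://api.z.ai/api/paas/v4"))
    kimi_key = os.getenv("KIMI_API_KEY")
    if kimi_key:
        # sk-kimi- keys default to the Kimi Code endpoint; legacy Moonshot
        # keys need the KIMI_BASE_URL override (see the comments above).
        default = ("https://api.kimi.com/coding/v1" if kimi_key.startswith("sk-kimi-")
                   else "https://api.moonshot.ai/v1")
        return OpenAI(api_key=kimi_key, base_url=os.getenv("KIMI_BASE_URL", default))
    if os.getenv("MINIMAX_API_KEY"):
        return OpenAI(api_key=os.environ["MINIMAX_API_KEY"],
                      base_url=os.getenv("MINIMAX_BASE_URL", "https://api.minimax.io/v1"))
    return None
```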
4 .gitignore vendored
```diff
@@ -47,4 +47,6 @@ cli-config.yaml
 
 # Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
 skills/.hub/
 ignored/
+.worktrees/
+environments/benchmarks/evals/
```
51 AGENTS.md
```diff
@@ -44,7 +44,8 @@ hermes-agent/
 │ │ ├── docker.py # Docker container execution
 │ │ ├── ssh.py # SSH remote execution
 │ │ ├── singularity.py # Singularity/Apptainer + SIF management
-│ │ └── modal.py # Modal cloud execution
+│ │ ├── modal.py # Modal cloud execution
+│ │ └── daytona.py # Daytona cloud sandboxes
 │ ├── terminal_tool.py # Terminal orchestration (sudo, lifecycle, factory)
 │ ├── todo_tool.py # Planning & task management
 │ ├── process_registry.py # Background process management
@@ -55,7 +56,9 @@ hermes-agent/
 ├── cron/ # Scheduler implementation
 ├── environments/ # RL training environments (Atropos integration)
 ├── skills/ # Bundled skill sources
+├── optional-skills/ # Official optional skills (not activated by default)
 ├── cli.py # Interactive CLI orchestrator (HermesCLI class)
+├── hermes_state.py # SessionDB — SQLite session store (schema, titles, FTS5 search)
 ├── run_agent.py # AIAgent class (core conversation loop)
 ├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
 ├── toolsets.py # Tool groupings
@@ -96,7 +99,7 @@ The main agent is implemented in `run_agent.py`:
 class AIAgent:
     def __init__(
         self,
-        model: str = "anthropic/claude-sonnet-4",
+        model: str = "anthropic/claude-sonnet-4.6",
         api_key: str = None,
         base_url: str = "https://openrouter.ai/api/v1",
         max_iterations: int = 60, # Max tool-calling loops
@@ -202,7 +205,7 @@ Every installed skill in `~/.hermes/skills/` is automatically registered as a sl
 The skill name (from frontmatter or folder name) becomes the command: `axolotl` → `/axolotl`.
 
 Implementation (`agent/skill_commands.py`, shared between CLI and gateway):
-1. `scan_skill_commands()` scans all SKILL.md files at startup
+1. `scan_skill_commands()` scans all SKILL.md files at startup, filtering out skills incompatible with the current OS platform (via the `platforms` frontmatter field)
 2. `build_skill_invocation_message()` loads the SKILL.md content and builds a user-turn message
 3. The message includes the full skill content, a list of supporting files (not loaded), and the user's instruction
 4. Supporting files can be loaded on demand via the `skill_view` tool
```
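A simplified sketch of step 1's scan-and-filter behavior, assuming PyYAML and the `platforms` frontmatter convention; this stands in for the real `agent/skill_commands.py` rather than reproducing it:

```python
import sys
from pathlib import Path
import yaml  # assumption: PyYAML is available

def scan_skill_commands(skills_dir: Path = Path.home() / ".hermes" / "skills") -> dict:
    """Map /<skill-name> slash commands to SKILL.md paths, skipping
    skills whose `platforms` field excludes the current OS."""
    current = {"darwin": "macos", "win32": "windows"}.get(sys.platform, "linux")
    commands = {}
    for skill_md in skills_dir.glob("*/SKILL.md"):
        text = skill_md.read_text(encoding="utf-8")
        meta = {}
        if text.startswith("---"):
            meta = yaml.safe_load(text.split("---", 2)[1]) or {}
        platforms = meta.get("platforms") or []
        if platforms and current not in platforms:
            continue  # incompatible with this OS; filtered out
        name = meta.get("name", skill_md.parent.name)
        commands[f"/{name}"] = skill_md
    return commands
```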
```diff
@@ -224,6 +227,10 @@ The unified `hermes` command provides all functionality:
 |---------|-------------|
 | `hermes` | Interactive chat (default) |
 | `hermes chat -q "..."` | Single query mode |
+| `hermes -c` / `hermes --continue` | Resume the most recent session |
+| `hermes -c "my project"` | Resume a session by name (latest in lineage) |
+| `hermes --resume <session_id>` | Resume a specific session by ID or title |
+| `hermes -w` / `hermes --worktree` | Start in isolated git worktree (for parallel agents) |
 | `hermes setup` | Configure API keys and settings |
 | `hermes config` | View current configuration |
 | `hermes config edit` | Open config in editor |
@@ -237,6 +244,8 @@ The unified `hermes` command provides all functionality:
 | `hermes gateway` | Start gateway (messaging + cron scheduler) |
 | `hermes gateway setup` | Configure messaging platforms interactively |
 | `hermes gateway install` | Install gateway as system service |
+| `hermes sessions list` | List past sessions (title, preview, last active) |
+| `hermes sessions rename <id> <title>` | Rename/title a session |
 | `hermes cron list` | View scheduled jobs |
 | `hermes cron status` | Check if cron scheduler is running |
 | `hermes version` | Show version info |
@@ -421,16 +430,19 @@ The system uses `_config_version` to detect outdated configs:
 API keys are loaded from `~/.hermes/.env`:
 - `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
 - `FIRECRAWL_API_KEY` - Web search/extract tools
 - `FIRECRAWL_API_URL` - Self-hosted Firecrawl endpoint (optional)
 - `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
 - `FAL_KEY` - Image generation (FLUX model)
 - `NOUS_API_KEY` - Vision and Mixture-of-Agents tools
 
 Terminal tool configuration (in `~/.hermes/config.yaml`):
-- `terminal.backend` - Backend: local, docker, singularity, modal, or ssh
+- `terminal.backend` - Backend: local, docker, singularity, modal, daytona, or ssh
 - `terminal.cwd` - Working directory ("." = host CWD for local only; for remote backends set an absolute path inside the target, or omit to use the backend's default)
 - `terminal.docker_image` - Image for Docker backend
 - `terminal.singularity_image` - Image for Singularity backend
 - `terminal.modal_image` - Image for Modal backend
+- `terminal.daytona_image` - Image for Daytona backend
+- `DAYTONA_API_KEY` - API key for Daytona backend (in .env)
 - SSH: `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` in .env
 
 Agent behavior (in `~/.hermes/.env`):
@@ -494,7 +506,7 @@ terminal(command="pytest -v tests/", background=true)
 - `process(action="submit", session_id="proc_abc123", data="yes")` -- send + Enter
 
 **Key behaviors:**
-- Background processes execute through the configured terminal backend (local/Docker/Modal/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
+- Background processes execute through the configured terminal backend (local/Docker/Modal/Daytona/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
 - The `wait` action blocks the tool call until the process finishes, times out, or is interrupted by a new user message
 - PTY mode (`pty=true` on terminal) enables interactive CLI tools (Codex, Claude Code)
 - In RL training, background processes are auto-killed when the episode ends (`tool_context.cleanup()`)
@@ -652,6 +664,7 @@ SKILL.md files use YAML frontmatter (agentskills.io format):
 name: skill-name
 description: Brief description for listing
 version: 1.0.0
+platforms: [macos] # Optional — restrict to specific OS (macos/linux/windows)
 metadata:
   hermes:
     tags: [tag1, tag2]
@@ -660,16 +673,40 @@ metadata:
 # Skill Content...
 ```
 
-**Skills Hub** — user-driven skill search/install from online registries (GitHub, ClawHub, Claude marketplaces, LobeHub). Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills ...` CLI commands or the `/skills` slash command in chat.
+**Platform filtering** — Skills with a `platforms` field are automatically excluded from the system prompt index, `skills_list()`, and slash commands on incompatible platforms. Skills without the field load everywhere (backward compatible). See `skills/apple/` for macOS-only examples (iMessage, Reminders, Notes, FindMy).
+
+**Skills Hub** — user-driven skill search/install from online registries and official optional skills. Sources: official optional skills (shipped with repo, labeled "official"), GitHub (openai/skills, anthropics/skills, custom taps), ClawHub, Claude marketplace, LobeHub. Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills browse/search/install` CLI commands or the `/skills` slash command in chat.
 
 Key files:
 - `tools/skills_tool.py` — Agent-facing skill list/view (progressive disclosure)
 - `tools/skills_guard.py` — Security scanner (regex + LLM audit, trust-aware install policy)
-- `tools/skills_hub.py` — Source adapters (GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
+- `tools/skills_hub.py` — Source adapters (OptionalSkillSource, GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
 - `hermes_cli/skills_hub.py` — CLI subcommands + `/skills` slash command handler
 
 ---
 
+## Known Pitfalls
+
+### DO NOT use `simple_term_menu` for interactive menus
+
+`simple_term_menu` has rendering bugs in tmux, iTerm2, and other non-standard terminals. When the user scrolls with arrow keys, previously highlighted items "ghost" — duplicating upward and corrupting the display. This happens because the library uses ANSI cursor-up codes to redraw in place, and tmux/iTerm miscalculate positions when the menu is near the bottom of the viewport.
+
+**Rule:** All interactive menus in `hermes_cli/` must use `curses` (Python stdlib) instead. See `tools_config.py` for the pattern — both `_prompt_choice()` (single-select) and `_prompt_toolset_checklist()` (multi-select with space toggle) use `curses.wrapper()`. The numbered-input fallback handles Windows where curses isn't available.
+
+### DO NOT use `\033[K` (ANSI erase-to-EOL) in spinner/display code
+
+The ANSI escape `\033[K` leaks as literal `?[K` text when `prompt_toolkit`'s `patch_stdout` is active. Use space-padding instead to clear lines: `f"\r{line}{' ' * pad}"`. See `agent/display.py` `KawaiiSpinner`.
+
+### `_last_resolved_tool_names` is a process-global in `model_tools.py`
+
+The `execute_code` sandbox uses `_last_resolved_tool_names` (set by `get_tool_definitions()`) to decide which tool stubs to generate. When subagents run with restricted toolsets, they overwrite this global. After delegation returns to the parent, `execute_code` may see the child's restricted list instead of the parent's full list. This is a known bug — `execute_code` calls after delegation may fail with `ImportError: cannot import name 'patch' from 'hermes_tools'`.
+
+### Tests must not write to `~/.hermes/`
+
+The `autouse` fixture `_isolate_hermes_home` in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Every test runs in isolation. If you add a test that creates `AIAgent` instances or writes session logs, the fixture handles cleanup automatically. Never hardcode `~/.hermes/` paths in tests.
+
+---
+
 ## Testing Changes
 
 After making changes:
```
```diff
@@ -43,7 +43,9 @@ Bundled skills (in `skills/`) ship with every Hermes install. They should be **b
 - Document handling, web research, common dev workflows, system administration
 - Used regularly by a wide range of people
 
-If your skill is specialized (a niche engineering tool, a specific SaaS integration, a game), it's better suited for a **Skills Hub** — upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
+If your skill is official and useful but not universally needed (e.g., a paid service integration, a heavyweight dependency), put it in **`optional-skills/`** — it ships with the repo but isn't activated by default. Users can discover it via `hermes skills browse` (labeled "official") and install it with `hermes skills install` (no third-party warning, builtin trust).
+
+If your skill is specialized, community-contributed, or niche, it's better suited for a **Skills Hub** — upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
 
 ---
 
@@ -116,7 +118,7 @@ hermes-agent/
 ├── cli.py # HermesCLI class — interactive TUI, prompt_toolkit integration
 ├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
 ├── toolsets.py # Tool groupings and presets (hermes-cli, hermes-telegram, etc.)
-├── hermes_state.py # SQLite session database with FTS5 full-text search
+├── hermes_state.py # SQLite session database with FTS5 full-text search, session titles
 ├── batch_runner.py # Parallel batch processing for trajectory generation
 │
 ├── agent/ # Agent internals (extracted modules)
@@ -153,7 +155,7 @@ hermes-agent/
 │ ├── skill_tools.py # Skill search, load, manage
 │ └── environments/ # Terminal execution backends
 │ ├── base.py # BaseEnvironment ABC
-│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py
+│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py, daytona.py
 │
 ├── gateway/ # Messaging gateway
 │ ├── run.py # GatewayRunner — platform lifecycle, message routing, cron
@@ -168,6 +170,7 @@ hermes-agent/
 │ └── whatsapp-bridge/ # Node.js WhatsApp bridge (Baileys)
 │
 ├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
+├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
 ├── environments/ # RL training environments (Atropos integration)
 ├── tests/ # Test suite
 ├── website/ # Documentation site (hermes-agent.nousresearch.com)
@@ -215,7 +218,7 @@ User message → AIAgent._run_agent_loop()
 
 - **Self-registering tools**: Each tool file calls `registry.register()` at import time. `model_tools.py` triggers discovery by importing all tool modules.
 - **Toolset grouping**: Tools are grouped into toolsets (`web`, `terminal`, `file`, `browser`, etc.) that can be enabled/disabled per platform.
-- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search. JSON logs go to `~/.hermes/sessions/`.
+- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search and unique session titles. JSON logs go to `~/.hermes/sessions/`.
 - **Ephemeral injection**: System prompts and prefill messages are injected at API call time, never persisted to the database or logs.
 - **Provider abstraction**: The agent works with any OpenAI-compatible API. Provider resolution happens at init time (Nous Portal OAuth, OpenRouter API key, or custom endpoint).
 - **Provider routing**: When using OpenRouter, `provider_routing` in config.yaml controls provider selection (sort by throughput/latency/price, allow/ignore specific providers, data retention policies). These are injected as `extra_body.provider` in API requests.
```
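The FTS5 search mentioned in the session-persistence bullet can be illustrated generically; the table and column names below are hypothetical, not `hermes_state.py`'s actual schema (requires an SQLite build with FTS5 enabled):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE msgs USING fts5(session_id, content)")
conn.execute("INSERT INTO msgs VALUES ('s1', 'configured the Telegram gateway')")
hits = conn.execute(
    "SELECT session_id FROM msgs WHERE msgs MATCH ?", ("gateway",)
).fetchall()
print(hits)  # [('s1',)]
```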
```diff
@@ -294,9 +297,9 @@ If it's a new toolset, add it to `toolsets.py` and to the relevant platform pres
 
 ---
 
-## Adding a Bundled Skill
+## Adding a Skill
 
-Bundled skills live in `skills/` organized by category:
+Bundled skills live in `skills/` organized by category. Official optional skills use the same structure in `optional-skills/`:
 
 ```
 skills/
@@ -322,6 +325,9 @@ description: Brief description (shown in skill search results)
 version: 1.0.0
 author: Your Name
 license: MIT
+platforms: [macos, linux] # Optional — restrict to specific OS platforms
+# Valid: macos, linux, windows
+# Omit to load on all platforms (default)
 metadata:
   hermes:
     tags: [Category, Subcategory, Keywords]
@@ -348,6 +354,18 @@ Known failure modes and how to handle them.
 How the agent confirms it worked.
 ```
 
+### Platform-specific skills
+
+Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.
+
+```yaml
+platforms: [macos] # macOS only (e.g., iMessage, Apple Reminders)
+platforms: [macos, linux] # macOS and Linux
+platforms: [windows] # Windows only
+```
+
+If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
+
 ### Skill guidelines
 
 - **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).
```
21 LICENSE Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Nous Research
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
```diff
@@ -11,17 +11,17 @@
 <a href="https://nousresearch.com"><img src="https://img.shields.io/badge/Built%20by-Nous%20Research-blueviolet?style=for-the-badge" alt="Built by Nous Research"></a>
 </p>
 
-**The fully open-source AI agent that grows with you.** Install it on a machine, give it your messaging accounts, and it becomes a persistent personal agent — learning your projects, building its own skills, running tasks on a schedule, and reaching you wherever you are.
+**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
 
-Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI Codex, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
+Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
 
 <table>
 <tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
 <tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
-<tr><td><b>Grows the longer it runs</b></td><td>Persistent memory across sessions. When it solves a hard problem, it writes a skill document for next time. Skills are searchable, shareable, and compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
+<tr><td><b>A closed learning loop</b></td><td>Agent-curated memory with periodic nudges. Autonomous skill creation after complex tasks. Skills self-improve during use. FTS5 session search with LLM summarization for cross-session recall. <a href="https://github.com/plastic-labs/honcho">Honcho</a> dialectic user modeling. Compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
 <tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
 <tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
-<tr><td><b>Real sandboxing</b></td><td>Five terminal backends — local, Docker, SSH, Singularity, and Modal — with persistent workspaces and container security hardening.</td></tr>
+<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Six terminal backends — local, Docker, SSH, Daytona, Singularity, and Modal. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
 <tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
 </table>
```
129 TODO.md
@@ -1,129 +0,0 @@
|
||||
# Hermes Agent - Future Improvements
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
||||
## 3. Local Browser Control via CDP 🌐
|
||||
|
||||
**Status:** Not started (currently Browserbase cloud only)
|
||||
**Priority:** Medium
|
||||
|
||||
Support local Chrome/Chromium via Chrome DevTools Protocol alongside existing Browserbase cloud backend.
|
||||
|
||||
**What other agents do:**
|
||||
- **OpenClaw**: Full CDP-based Chrome control with snapshots, actions, uploads, profiles, file chooser, PDF save, console messages, tab management. Uses local Chrome for persistent login sessions.
|
||||
- **Cline**: Headless browser with Computer Use (click, type, scroll, screenshot, console logs)
|
||||
|
||||
**Our approach:**
|
||||
- Add a `local` backend option to `browser_tool.py` using Playwright or raw CDP
|
||||
- Config toggle: `browser.backend: local | browserbase | auto`
|
||||
- `auto` mode: try local first, fall back to Browserbase
|
||||
- Local advantages: free, persistent login sessions, no API key needed
|
||||
- Local disadvantages: no CAPTCHA solving, no stealth mode, requires Chrome installed
|
||||
- Reuse the same 10-tool interface -- just swap the backend
|
||||
- Later: Chrome profile management for persistent sessions across restarts
|
||||
|
||||
---
|
||||
|
||||
## 4. Signal Integration 📡
|
||||
|
||||
**Status:** Not started
|
||||
**Priority:** Low
|
||||
|
||||
New platform adapter using signal-cli daemon (JSON-RPC HTTP + SSE). Requires Java runtime and phone number registration.
|
||||
|
||||
**Reference:** OpenClaw has Signal support via signal-cli.
|
||||
|
||||
---
|
||||
|
||||
## 5. Plugin/Extension System 🔌
|
||||
|
||||
**Status:** Partially implemented (event hooks exist in `gateway/hooks.py`)
|
||||
**Priority:** Medium
|
||||
|
||||
Full Python plugin interface that goes beyond the current hook system.
|
||||
|
||||
**What other agents do:**
|
||||
- **OpenClaw**: Plugin SDK with tool-send capabilities, lifecycle phase hooks (before-agent-start, after-tool-call, model-override), plugin registry with install/uninstall.
|
||||
- **Pi**: Extensions are TypeScript modules that can register tools, commands, keyboard shortcuts, custom UI widgets, overlays, status lines, dialogs, compaction hooks, raw terminal input listeners. Extremely comprehensive.
|
||||
- **OpenCode**: MCP client support (stdio, SSE, StreamableHTTP), OAuth auth for MCP servers. Also has Copilot/Codex plugins.
|
||||
- **Codex**: Full MCP integration with skill dependencies.
|
||||
- **Cline**: MCP integration + lifecycle hooks with cancellation support.
|
||||
|
||||
**Our approach (phased):**
|
||||
|
||||
### Phase 1: Enhanced hooks
|
||||
- Expand the existing `gateway/hooks.py` to support more events: `before-tool-call`, `after-tool-call`, `before-response`, `context-compress`, `session-end`
|
||||
- Allow hooks to modify tool results (e.g., filter sensitive output)
|
||||
|
||||
### Phase 2: Plugin interface
|
||||
- `~/.hermes/plugins/<name>/plugin.yaml` + `handler.py`
|
||||
- Plugins can: register new tools, add CLI commands, subscribe to events, inject system prompt sections
|
||||
- `hermes plugin list|install|uninstall|create` CLI commands
|
||||
- Plugin discovery and validation on startup
|
||||
|
||||
### Phase 3: MCP support (industry standard) ✅ DONE
|
||||
- ✅ MCP client that connects to external MCP servers (stdio + HTTP/StreamableHTTP)
|
||||
- ✅ Config: `mcp_servers` in config.yaml with connection details
|
||||
- ✅ Each MCP server's tools auto-registered as a dynamic toolset
|
||||
- Future: Resources, Prompts, Progress notifications, `hermes mcp` CLI command
|
||||
|
||||
---
|
||||
|
||||
## 6. MCP (Model Context Protocol) Support 🔗 ✅ DONE
|
||||
|
||||
**Status:** Implemented (PR #301)
|
||||
**Priority:** Complete
|
||||
|
||||
Native MCP client support with stdio and HTTP/StreamableHTTP transports, auto-discovery, reconnection with exponential backoff, env var filtering, and credential stripping. See `docs/mcp.md` for full documentation.
|
||||
|
||||
**Still TODO:**
|
||||
- `hermes mcp` CLI subcommand (list/test/status)
|
||||
- `hermes tools` UI integration for MCP toolsets
|
||||
- MCP Resources and Prompts support
|
||||
- OAuth authentication for remote servers
|
||||
- Progress notifications for long-running tools
|
||||
|
||||
---
|
||||
|
||||
## 8. Filesystem Checkpointing / Rollback 🔄
|
||||
|
||||
**Status:** Not started
|
||||
**Priority:** Low-Medium
|
||||
|
||||
Automatic filesystem snapshots after each agent loop iteration so the user can roll back destructive changes to their project.
|
||||
|
||||
**What other agents do:**
|
||||
- **Cline**: Workspace checkpoints at each step with Compare/Restore UI
|
||||
- **OpenCode**: Git-backed workspace snapshots per step, with weekly gc
|
||||
- **Codex**: Sandboxed execution with commit-per-step, rollback on failure
|
||||
|
||||
**Our approach:**
|
||||
- After each tool call (or batch of tool calls in a single turn) that modifies files, create a lightweight checkpoint of the affected files
|
||||
- Git-based when the project is a repo: auto-commit to a detached/temporary branch (`hermes/checkpoints/<session>`) after each agent turn, squash or discard on session end
|
||||
- Non-git fallback: tar snapshots of changed files in `~/.hermes/checkpoints/<session_id>/`
|
||||
- `hermes rollback` CLI command to restore to a previous checkpoint
|
||||
- Agent-accessible via a `checkpoint` tool: `list` (show available restore points), `restore` (roll back to a named point), `diff` (show what changed since a checkpoint)
|
||||
- Configurable: off by default (opt-in via `config.yaml`), since auto-committing can be surprising
|
||||
- Cleanup: checkpoints expire after session ends (or configurable retention period)
|
||||
- Integration with the terminal backend: works with local, SSH, and Docker backends (snapshots happen on the execution host)
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority Order
|
||||
|
||||
### Tier 1: Next Up
|
||||
|
||||
1. ~~MCP Support -- #6~~ ✅ Done (PR #301)
|
||||
|
||||
### Tier 2: Quality of Life
|
||||
|
||||
3. Local Browser Control via CDP -- #3
|
||||
4. Plugin/Extension System -- #5
|
||||
|
||||
### Tier 3: Nice to Have
|
||||
|
||||
5. Session Branching / Checkpoints -- #7
|
||||
6. Filesystem Checkpointing / Rollback -- #8
|
||||
7. Signal Integration -- #4
|
||||
```diff
@@ -10,7 +10,9 @@ Resolution order for text tasks:
 3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
 4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
    wrapped to look like a chat.completions client)
-5. None
+5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
+   — checked via PROVIDER_REGISTRY entries with auth_type='api_key'
+6. None
 
 Resolution order for vision/multimodal tasks:
 1. OpenRouter
```
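A usage sketch for the resolver this docstring describes; the call shape follows `get_text_auxiliary_client` as defined later in this file, and the prompt content is illustrative:

```python
from agent.auxiliary_client import get_text_auxiliary_client

client, model = get_text_auxiliary_client()
if client is None:
    raise RuntimeError("no auxiliary text provider configured")
resp = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "Summarize: ..."}],
)
print(resp.choices[0].message.content)
```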
```diff
@@ -31,6 +33,14 @@ from hermes_constants import OPENROUTER_BASE_URL
 
 logger = logging.getLogger(__name__)
 
+# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
+_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
+    "zai": "glm-4.5-flash",
+    "kimi-coding": "kimi-k2-turbo-preview",
+    "minimax": "MiniMax-M2.5-highspeed",
+    "minimax-cn": "MiniMax-M2.5-highspeed",
+}
+
 # OpenRouter app attribution headers
 _OR_HEADERS = {
     "HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
```
```diff
@@ -282,12 +292,58 @@ def _read_codex_access_token() -> Optional[str]:
     return None
 
 
+def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
+    """Try each API-key provider in PROVIDER_REGISTRY order.
+
+    Returns (client, model) for the first provider whose env var is set,
+    or (None, None) if none are configured.
+    """
+    try:
+        from hermes_cli.auth import PROVIDER_REGISTRY
+    except ImportError:
+        logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
+        return None, None
+
+    for provider_id, pconfig in PROVIDER_REGISTRY.items():
+        if pconfig.auth_type != "api_key":
+            continue
+        # Check if any of the provider's env vars are set
+        api_key = ""
+        for env_var in pconfig.api_key_env_vars:
+            val = os.getenv(env_var, "").strip()
+            if val:
+                api_key = val
+                break
+        if not api_key:
+            continue
+        # Resolve base URL (with optional env-var override)
+        # Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
+        env_url = ""
+        if pconfig.base_url_env_var:
+            env_url = os.getenv(pconfig.base_url_env_var, "").strip()
+        if env_url:
+            base_url = env_url.rstrip("/")
+        elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
+            base_url = "https://api.kimi.com/coding/v1"
+        else:
+            base_url = pconfig.inference_base_url
+        model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
+        logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
+        extra = {}
+        if "api.kimi.com" in base_url.lower():
+            extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
+        return OpenAI(api_key=api_key, base_url=base_url, **extra), model
+
+    return None, None
+
+
 # ── Public API ──────────────────────────────────────────────────────────────
 
 def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
     """Return (client, model_slug) for text-only auxiliary tasks.
 
-    Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth -> (None, None).
+    Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth
+    -> direct API-key providers -> (None, None).
     """
     # 1. OpenRouter
     or_key = os.getenv("OPENROUTER_API_KEY")
```
```diff
@@ -323,7 +379,12 @@ def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
         real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
         return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
 
-    # 5. Nothing available
+    # 5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, etc.)
+    api_client, api_model = _resolve_api_key_provider()
+    if api_client is not None:
+        return api_client, api_model
+
+    # 6. Nothing available
     logger.debug("Auxiliary text client: none available")
     return None, None
```
```diff
@@ -350,6 +411,8 @@ def get_async_text_auxiliary_client():
     }
     if "openrouter" in str(sync_client.base_url).lower():
         async_kwargs["default_headers"] = dict(_OR_HEADERS)
+    elif "api.kimi.com" in str(sync_client.base_url).lower():
+        async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
     return AsyncOpenAI(**async_kwargs), model
```
```diff
@@ -7,7 +7,7 @@ protecting head and tail context.
 
 import logging
 import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from agent.auxiliary_client import get_text_auxiliary_client
 from agent.model_metadata import (
```
```diff
@@ -34,17 +34,20 @@ class ContextCompressor:
         summary_target_tokens: int = 2500,
         quiet_mode: bool = False,
         summary_model_override: str = None,
+        base_url: str = "",
     ):
         self.model = model
+        self.base_url = base_url
         self.threshold_percent = threshold_percent
         self.protect_first_n = protect_first_n
         self.protect_last_n = protect_last_n
         self.summary_target_tokens = summary_target_tokens
         self.quiet_mode = quiet_mode
 
-        self.context_length = get_model_context_length(model)
+        self.context_length = get_model_context_length(model, base_url=base_url)
         self.threshold_tokens = int(self.context_length * threshold_percent)
         self.compression_count = 0
+        self._context_probed = False  # True after a step-down from context error
 
         self.last_prompt_tokens = 0
         self.last_completion_tokens = 0
```
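A minimal usage sketch for the constructor above; the import path and argument values are assumptions, and only parameters visible in this file's hunks are used:

```python
from agent.context_compressor import ContextCompressor  # module path assumed

compressor = ContextCompressor(
    model="anthropic/claude-sonnet-4.6",
    base_url="https://openrouter.ai/api/v1",  # feeds get_model_context_length()
)
messages = [{"role": "user", "content": "..."}]  # a long conversation in practice
if len(messages) > compressor.protect_first_n + compressor.protect_last_n + 1:
    messages = compressor.compress(messages)
```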
```diff
@@ -79,11 +82,14 @@ class ContextCompressor:
             "compression_count": self.compression_count,
         }
 
-    def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> str:
-        """Generate a concise summary of conversation turns using a fast model."""
-        if not self.client:
-            return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed to save space. The assistant performed various actions and received responses."
+    def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
+        """Generate a concise summary of conversation turns.
+
+        Tries the auxiliary model first, then falls back to the user's main
+        model. Returns None if all attempts fail — the caller should drop
+        the middle turns without a summary rather than inject a useless
+        placeholder.
+        """
         parts = []
         for msg in turns_to_summarize:
             role = msg.get("role", "unknown")
```
```diff
@@ -114,28 +120,28 @@ TURNS TO SUMMARIZE:
 
 Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
 
-        try:
-            return self._call_summary_model(self.client, self.summary_model, prompt)
-        except Exception as e:
-            logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
+        # 1. Try the auxiliary model (cheap/fast)
+        if self.client:
+            try:
+                return self._call_summary_model(self.client, self.summary_model, prompt)
+            except Exception as e:
+                logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
 
-        # Fallback: try the main model's endpoint. This handles the common
-        # case where the user switched providers (e.g. OpenRouter → local LLM)
-        # but a stale API key causes the auxiliary client to pick the old
-        # provider which then fails (402, auth error, etc.).
-        fallback_client, fallback_model = self._get_fallback_client()
-        if fallback_client is not None:
-            try:
-                logger.info("Retrying context summary with fallback client (%s)", fallback_model)
-                summary = self._call_summary_model(fallback_client, fallback_model, prompt)
-                # Success — swap in the working client for future compressions
-                self.client = fallback_client
-                self.summary_model = fallback_model
-                return summary
-            except Exception as fallback_err:
-                logging.warning(f"Fallback summary model also failed: {fallback_err}")
+        # 2. Fallback: try the user's main model endpoint
+        fallback_client, fallback_model = self._get_fallback_client()
+        if fallback_client is not None:
+            try:
+                logger.info("Retrying context summary with main model (%s)", fallback_model)
+                summary = self._call_summary_model(fallback_client, fallback_model, prompt)
+                self.client = fallback_client
+                self.summary_model = fallback_model
+                return summary
+            except Exception as fallback_err:
+                logging.warning(f"Main model summary also failed: {fallback_err}")
 
-        return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed. The assistant performed tool calls and received responses."
+        # 3. All models failed — return None so the caller drops turns without a summary
+        logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
+        return None
 
     def _call_summary_model(self, client, model: str, prompt: str) -> str:
         """Make the actual LLM call to generate a summary. Raises on failure."""
```
@@ -193,10 +199,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
logger.debug("Could not build fallback auxiliary client: %s", exc)
|
||||
return None, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call / tool-result pair integrity helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_call_id(tc) -> str:
|
||||
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
|
||||
if isinstance(tc, dict):
|
||||
return tc.get("id", "")
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs after compression.
|
||||
|
||||
Two failure modes:
|
||||
1. A tool *result* references a call_id whose assistant tool_call was
|
||||
removed (summarized/truncated). The API rejects this with
|
||||
"No tool call found for function call output with call_id ...".
|
||||
2. An assistant message has tool_calls whose results were dropped.
|
||||
The API rejects this because every tool_call must be followed by
|
||||
a tool result with the matching call_id.
|
||||
|
||||
This method removes orphaned results and inserts stub results for
|
||||
orphaned calls so the message list is always well-formed.
|
||||
"""
|
||||
surviving_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid:
|
||||
surviving_call_ids.add(cid)
|
||||
|
||||
result_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
cid = msg.get("tool_call_id")
|
||||
if cid:
|
||||
result_call_ids.add(cid)
|
||||
|
||||
# 1. Remove tool results whose call_id has no matching assistant tool_call
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
|
||||
|
||||
# 2. Add stub results for assistant tool_calls whose results were dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid in missing_results:
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
messages = patched
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
|
||||
|
||||
return messages
|
||||
|
||||
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Push a compress-start boundary forward past any orphan tool results.
|
||||
|
||||
If ``messages[idx]`` is a tool result, slide forward until we hit a
|
||||
non-tool message so we don't start the summarised region mid-group.
|
||||
"""
|
||||
while idx < len(messages) and messages[idx].get("role") == "tool":
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
If the message just before ``idx`` is an assistant message with
|
||||
tool_calls, those tool results will start at ``idx`` and would be
|
||||
separated from their parent. Move backwards to include the whole
|
||||
group in the summarised region.
|
||||
"""
|
||||
if idx <= 0 or idx >= len(messages):
|
||||
return idx
|
||||
prev = messages[idx - 1]
|
||||
if prev.get("role") == "assistant" and prev.get("tool_calls"):
|
||||
# The results for this assistant turn sit at idx..idx+k.
|
||||
# Include the assistant message in the summarised region too.
|
||||
idx -= 1
|
||||
return idx
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
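To make the sanitizer's two repairs concrete, a hypothetical before/after (message shapes follow the code above):

```python
# Before: "tc_9"'s parent assistant turn was summarized away, and the
# result for "tc_1" was dropped by compression.
messages = [
    {"role": "assistant", "tool_calls": [{"id": "tc_1", "type": "function"}]},
    {"role": "tool", "tool_call_id": "tc_9", "content": "stale result"},
]
# After _sanitize_tool_pairs(messages):
#   - the orphaned "tc_9" result is removed
#   - a stub tool result for "tc_1" is inserted right after the assistant turn
```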
@@ -209,6 +316,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
# Adjust boundaries to avoid splitting tool_call/result groups.
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
@@ -216,24 +329,6 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
|
||||
# Truncation fallback when no auxiliary model is available
|
||||
if self.client is None:
|
||||
print("⚠️ Context compression: no auxiliary model available. Falling back to message truncation.")
|
||||
# Keep system message(s) at the front and the protected tail;
|
||||
# simply drop the oldest non-system messages until under threshold.
|
||||
kept = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
kept.append(msg.copy())
|
||||
else:
|
||||
break
|
||||
tail = messages[-self.protect_last_n:]
|
||||
kept.extend(m.copy() for m in tail)
|
||||
self.compression_count += 1
|
||||
if not self.quiet_mode:
|
||||
print(f" ✂️ Truncated: {len(messages)} → {len(kept)} messages (dropped middle turns)")
|
||||
return kept
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
@@ -246,13 +341,19 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
compressed.append(msg)
|
||||
|
||||
compressed.append({"role": "user", "content": summary})
|
||||
if summary:
|
||||
compressed.append({"role": "user", "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
compressed.append(messages[i].copy())
|
||||
|
||||
self.compression_count += 1
|
||||
|
||||
compressed = self._sanitize_tool_pairs(compressed)
|
||||
|
||||
if not self.quiet_mode:
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
|
||||
818 agent/insights.py Normal file
@@ -0,0 +1,818 @@
|
||||
"""
|
||||
Session Insights Engine for Hermes Agent.
|
||||
|
||||
Analyzes historical session data from the SQLite state database to produce
|
||||
comprehensive usage insights — token consumption, cost estimates, tool usage
|
||||
patterns, activity trends, model/platform breakdowns, and session metrics.
|
||||
|
||||
Inspired by Claude Code's /insights command, adapted for Hermes Agent's
|
||||
multi-platform architecture with additional cost estimation and platform
|
||||
breakdown capabilities.
|
||||
|
||||
Usage:
|
||||
from agent.insights import InsightsEngine
|
||||
engine = InsightsEngine(db)
|
||||
report = engine.generate(days=30)
|
||||
print(engine.format_terminal(report))
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
# =========================================================================
|
||||
MODEL_PRICING = {
|
||||
# OpenAI
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
# Anthropic
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
# DeepSeek
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
# Google
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Meta (via providers)
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
# MiniMax (direct provider — pricing not published externally, treat as local)
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
# Fallback: unknown/custom models get zero cost (we can't assume pricing
|
||||
# for self-hosted models, custom OAI endpoints, local inference, etc.)
|
||||
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
|
||||
def _has_known_pricing(model_name: str) -> bool:
|
||||
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING


def _get_pricing(model_name: str) -> Dict[str, float]:
    """Look up pricing for a model. Uses fuzzy matching on model name.

    Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
    we can't assume costs for self-hosted endpoints, local inference, etc.
    """
    if not model_name:
        return _DEFAULT_PRICING

    # Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
    bare = model_name.split("/")[-1].lower()

    # Exact match first
    if bare in MODEL_PRICING:
        return MODEL_PRICING[bare]

    # Fuzzy prefix match — prefer the LONGEST matching key to avoid
    # e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
    best_match = None
    best_len = 0
    for key, price in MODEL_PRICING.items():
        if bare.startswith(key) and len(key) > best_len:
            best_match = price
            best_len = len(key)
    if best_match:
        return best_match

    # Keyword heuristics (checked in most-specific-first order)
    if "opus" in bare:
        return {"input": 15.00, "output": 75.00}
    if "sonnet" in bare:
        return {"input": 3.00, "output": 15.00}
    if "haiku" in bare:
        return {"input": 0.80, "output": 4.00}
    if "gpt-4o-mini" in bare:
        return {"input": 0.15, "output": 0.60}
    if "gpt-4o" in bare:
        return {"input": 2.50, "output": 10.00}
    if "gpt-5" in bare:
        return {"input": 10.00, "output": 30.00}
    if "deepseek" in bare:
        return {"input": 0.14, "output": 0.28}
    if "gemini" in bare:
        return {"input": 0.15, "output": 0.60}

    return _DEFAULT_PRICING
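# Illustration of the longest-prefix rule with a dated snapshot name:
# _get_pricing("openai/gpt-4o-mini-2024-07-18") strips the prefix, misses an
# exact match, then picks "gpt-4o-mini" (len 11) over "gpt-4o" (len 6).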


def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost for a given model and token counts."""
    pricing = _get_pricing(model)
    return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
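# Worked example against the table above: 10,000 input + 2,000 output tokens
# on gpt-4o costs (10_000 * 2.50 + 2_000 * 10.00) / 1_000_000 = $0.045.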


def _format_duration(seconds: float) -> str:
    """Format seconds into a human-readable duration string."""
    if seconds < 60:
        return f"{seconds:.0f}s"
    minutes = seconds / 60
    if minutes < 60:
        return f"{minutes:.0f}m"
    hours = minutes / 60
    if hours < 24:
        remaining_min = int(minutes % 60)
        return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
    days = hours / 24
    return f"{days:.1f}d"


def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
    """Create simple horizontal bar chart strings from values."""
    peak = max(values) if values else 1
    if peak == 0:
        return ["" for _ in values]
    return ["█" * max(1, int(v / peak * max_width)) if v > 0 else "" for v in values]


class InsightsEngine:
    """
    Analyzes session history and produces usage insights.

    Works directly with a SessionDB instance (or raw sqlite3 connection)
    to query session and message data.
    """

    def __init__(self, db):
        """
        Initialize with a SessionDB instance.

        Args:
            db: A SessionDB instance (from hermes_state.py)
        """
        self.db = db
        self._conn = db._conn

    def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
        """
        Generate a complete insights report.

        Args:
            days: Number of days to look back (default: 30)
            source: Optional filter by source platform

        Returns:
            Dict with all computed insights
        """
        cutoff = time.time() - (days * 86400)

        # Gather raw data
        sessions = self._get_sessions(cutoff, source)
        tool_usage = self._get_tool_usage(cutoff, source)
        message_stats = self._get_message_stats(cutoff, source)

        if not sessions:
            return {
                "days": days,
                "source_filter": source,
                "empty": True,
                "overview": {},
                "models": [],
                "platforms": [],
                "tools": [],
                "activity": {},
                "top_sessions": [],
            }

        # Compute insights
        overview = self._compute_overview(sessions, message_stats)
        models = self._compute_model_breakdown(sessions)
        platforms = self._compute_platform_breakdown(sessions)
        tools = self._compute_tool_breakdown(tool_usage)
        activity = self._compute_activity_patterns(sessions)
        top_sessions = self._compute_top_sessions(sessions)

        return {
            "days": days,
            "source_filter": source,
            "empty": False,
            "generated_at": time.time(),
            "overview": overview,
            "models": models,
            "platforms": platforms,
            "tools": tools,
            "activity": activity,
            "top_sessions": top_sessions,
        }

    # =========================================================================
    # Data gathering (SQL queries)
    # =========================================================================

    # Columns we actually need (skip system_prompt, model_config blobs)
    _SESSION_COLS = ("id, source, model, started_at, ended_at, "
                     "message_count, tool_call_count, input_tokens, output_tokens")

    def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
        """Fetch sessions within the time window."""
        if source:
            cursor = self._conn.execute(
                f"""SELECT {self._SESSION_COLS} FROM sessions
                    WHERE started_at >= ? AND source = ?
                    ORDER BY started_at DESC""",
                (cutoff, source),
            )
        else:
            cursor = self._conn.execute(
                f"""SELECT {self._SESSION_COLS} FROM sessions
                    WHERE started_at >= ?
                    ORDER BY started_at DESC""",
                (cutoff,),
            )
        return [dict(row) for row in cursor.fetchall()]

    def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
        """Get tool call counts from messages.

        Uses two sources:
        1. tool_name column on 'tool' role messages (set by gateway)
        2. tool_calls JSON on 'assistant' role messages (covers CLI where
           tool_name is not populated on tool responses)
        """
        tool_counts = Counter()

        # Source 1: explicit tool_name on tool response messages
        if source:
            cursor = self._conn.execute(
                """SELECT m.tool_name, COUNT(*) as count
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ? AND s.source = ?
                     AND m.role = 'tool' AND m.tool_name IS NOT NULL
                   GROUP BY m.tool_name
                   ORDER BY count DESC""",
                (cutoff, source),
            )
        else:
            cursor = self._conn.execute(
                """SELECT m.tool_name, COUNT(*) as count
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ?
                     AND m.role = 'tool' AND m.tool_name IS NOT NULL
                   GROUP BY m.tool_name
                   ORDER BY count DESC""",
                (cutoff,),
            )
        for row in cursor.fetchall():
            tool_counts[row["tool_name"]] += row["count"]

        # Source 2: extract from tool_calls JSON on assistant messages
        # (covers CLI sessions where tool_name is NULL on tool responses)
        if source:
            cursor2 = self._conn.execute(
                """SELECT m.tool_calls
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ? AND s.source = ?
                     AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
                (cutoff, source),
            )
        else:
            cursor2 = self._conn.execute(
                """SELECT m.tool_calls
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ?
                     AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
                (cutoff,),
            )

        tool_calls_counts = Counter()
        for row in cursor2.fetchall():
            try:
                calls = row["tool_calls"]
                if isinstance(calls, str):
                    calls = json.loads(calls)
                if isinstance(calls, list):
                    for call in calls:
                        func = call.get("function", {}) if isinstance(call, dict) else {}
                        name = func.get("name")
                        if name:
                            tool_calls_counts[name] += 1
            except (json.JSONDecodeError, TypeError, AttributeError):
                continue

        # Merge: prefer tool_name source, supplement with tool_calls source
        # for tools not already counted
        if not tool_counts and tool_calls_counts:
            # No tool_name data at all — use tool_calls exclusively
            tool_counts = tool_calls_counts
        elif tool_counts and tool_calls_counts:
            # Both sources have data — use whichever has the higher count per tool
            # (they may overlap, so take the max to avoid double-counting)
            all_tools = set(tool_counts) | set(tool_calls_counts)
            merged = Counter()
            for tool in all_tools:
                merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
            tool_counts = merged

        # Convert to the expected format
        return [
            {"tool_name": name, "count": count}
            for name, count in tool_counts.most_common()
        ]
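    # Merge example: tool_name rows give {"bash": 10}, the tool_calls JSON gives
    # {"bash": 12, "web_search": 3} → max-merge yields {"bash": 12, "web_search": 3}.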

    def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
        """Get aggregate message statistics."""
        if source:
            cursor = self._conn.execute(
                """SELECT
                       COUNT(*) as total_messages,
                       SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
                       SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
                       SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ? AND s.source = ?""",
                (cutoff, source),
            )
        else:
            cursor = self._conn.execute(
                """SELECT
                       COUNT(*) as total_messages,
                       SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
                       SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
                       SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
                   FROM messages m
                   JOIN sessions s ON s.id = m.session_id
                   WHERE s.started_at >= ?""",
                (cutoff,),
            )
        row = cursor.fetchone()
        return dict(row) if row else {
            "total_messages": 0, "user_messages": 0,
            "assistant_messages": 0, "tool_messages": 0,
        }

    # =========================================================================
    # Computation
    # =========================================================================

    def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
        """Compute high-level overview statistics."""
        total_input = sum(s.get("input_tokens") or 0 for s in sessions)
        total_output = sum(s.get("output_tokens") or 0 for s in sessions)
        total_tokens = total_input + total_output
        total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
        total_messages = sum(s.get("message_count") or 0 for s in sessions)

        # Cost estimation (weighted by model)
        total_cost = 0.0
        models_with_pricing = set()
        models_without_pricing = set()
        for s in sessions:
            model = s.get("model") or ""
            inp = s.get("input_tokens") or 0
            out = s.get("output_tokens") or 0
            total_cost += _estimate_cost(model, inp, out)
            display = model.split("/")[-1] if "/" in model else (model or "unknown")
            if _has_known_pricing(model):
                models_with_pricing.add(display)
            else:
                models_without_pricing.add(display)

        # Session duration stats (guard against negative durations from clock drift)
        durations = []
        for s in sessions:
            start = s.get("started_at")
            end = s.get("ended_at")
            if start and end and end > start:
                durations.append(end - start)

        total_hours = sum(durations) / 3600 if durations else 0
        avg_duration = sum(durations) / len(durations) if durations else 0

        # Earliest and latest session
        started_timestamps = [s["started_at"] for s in sessions if s.get("started_at")]
        date_range_start = min(started_timestamps) if started_timestamps else None
        date_range_end = max(started_timestamps) if started_timestamps else None

        return {
            "total_sessions": len(sessions),
            "total_messages": total_messages,
            "total_tool_calls": total_tool_calls,
            "total_input_tokens": total_input,
            "total_output_tokens": total_output,
            "total_tokens": total_tokens,
            "estimated_cost": total_cost,
            "total_hours": total_hours,
            "avg_session_duration": avg_duration,
            "avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
            "avg_tokens_per_session": total_tokens / len(sessions) if sessions else 0,
            "user_messages": message_stats.get("user_messages") or 0,
            "assistant_messages": message_stats.get("assistant_messages") or 0,
            "tool_messages": message_stats.get("tool_messages") or 0,
            "date_range_start": date_range_start,
            "date_range_end": date_range_end,
            "models_with_pricing": sorted(models_with_pricing),
            "models_without_pricing": sorted(models_without_pricing),
        }

    def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
        """Break down usage by model."""
        model_data = defaultdict(lambda: {
            "sessions": 0, "input_tokens": 0, "output_tokens": 0,
            "total_tokens": 0, "tool_calls": 0, "cost": 0.0,
        })

        for s in sessions:
            model = s.get("model") or "unknown"
            # Normalize: strip provider prefix for display
            display_model = model.split("/")[-1] if "/" in model else model
            d = model_data[display_model]
            d["sessions"] += 1
            inp = s.get("input_tokens") or 0
            out = s.get("output_tokens") or 0
            d["input_tokens"] += inp
            d["output_tokens"] += out
            d["total_tokens"] += inp + out
            d["tool_calls"] += s.get("tool_call_count") or 0
            d["cost"] += _estimate_cost(model, inp, out)
            d["has_pricing"] = _has_known_pricing(model)

        result = [
            {"model": model, **data}
            for model, data in model_data.items()
        ]
        # Sort by tokens first, fall back to session count when tokens are 0
        result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
        return result

    def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
        """Break down usage by platform/source."""
        platform_data = defaultdict(lambda: {
            "sessions": 0, "messages": 0, "input_tokens": 0,
            "output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
        })

        for s in sessions:
            source = s.get("source") or "unknown"
            d = platform_data[source]
            d["sessions"] += 1
            d["messages"] += s.get("message_count") or 0
            inp = s.get("input_tokens") or 0
            out = s.get("output_tokens") or 0
            d["input_tokens"] += inp
            d["output_tokens"] += out
            d["total_tokens"] += inp + out
            d["tool_calls"] += s.get("tool_call_count") or 0

        result = [
            {"platform": platform, **data}
            for platform, data in platform_data.items()
        ]
        result.sort(key=lambda x: x["sessions"], reverse=True)
        return result

    def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
        """Process tool usage data into a ranked list with percentages."""
        total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
        result = []
        for t in tool_usage:
            pct = (t["count"] / total_calls * 100) if total_calls else 0
            result.append({
                "tool": t["tool_name"],
                "count": t["count"],
                "percentage": pct,
            })
        return result

    def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
        """Analyze activity patterns by day of week and hour."""
        day_counts = Counter()  # 0=Monday ... 6=Sunday
        hour_counts = Counter()
        daily_counts = Counter()  # date string -> count

        for s in sessions:
            ts = s.get("started_at")
            if not ts:
                continue
            dt = datetime.fromtimestamp(ts)
            day_counts[dt.weekday()] += 1
            hour_counts[dt.hour] += 1
            daily_counts[dt.strftime("%Y-%m-%d")] += 1

        day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
        day_breakdown = [
            {"day": day_names[i], "count": day_counts.get(i, 0)}
            for i in range(7)
        ]

        hour_breakdown = [
            {"hour": i, "count": hour_counts.get(i, 0)}
            for i in range(24)
        ]

        # Busiest day and hour
        busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
        busiest_hour = max(hour_breakdown, key=lambda x: x["count"]) if hour_breakdown else None

        # Active days (days with at least one session)
        active_days = len(daily_counts)

        # Streak calculation
        if daily_counts:
            all_dates = sorted(daily_counts.keys())
            current_streak = 1
            max_streak = 1
            for i in range(1, len(all_dates)):
                d1 = datetime.strptime(all_dates[i - 1], "%Y-%m-%d")
                d2 = datetime.strptime(all_dates[i], "%Y-%m-%d")
                if (d2 - d1).days == 1:
                    current_streak += 1
                    max_streak = max(max_streak, current_streak)
                else:
                    current_streak = 1
        else:
            max_streak = 0

        return {
            "by_day": day_breakdown,
            "by_hour": hour_breakdown,
            "busiest_day": busiest_day,
            "busiest_hour": busiest_hour,
            "active_days": active_days,
            "max_streak": max_streak,
        }
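    # Streak example: session days ["2025-01-01", "2025-01-02", "2025-01-04"]
    # give active_days=3 and max_streak=2 (the Jan 1–2 run; the gap resets it).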

    def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
        """Find notable sessions (longest, most messages, most tokens)."""
        top = []

        # Longest by duration
        sessions_with_duration = [
            s for s in sessions
            if s.get("started_at") and s.get("ended_at")
        ]
        if sessions_with_duration:
            longest = max(
                sessions_with_duration,
                key=lambda s: (s["ended_at"] - s["started_at"]),
            )
            dur = longest["ended_at"] - longest["started_at"]
            top.append({
                "label": "Longest session",
                "session_id": longest["id"][:16],
                "value": _format_duration(dur),
                "date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
            })

        # Most messages
        most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
        if (most_msgs.get("message_count") or 0) > 0:
            top.append({
                "label": "Most messages",
                "session_id": most_msgs["id"][:16],
                "value": f"{most_msgs['message_count']} msgs",
                "date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
            })

        # Most tokens
        most_tokens = max(
            sessions,
            key=lambda s: (s.get("input_tokens") or 0) + (s.get("output_tokens") or 0),
        )
        token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
        if token_total > 0:
            top.append({
                "label": "Most tokens",
                "session_id": most_tokens["id"][:16],
                "value": f"{token_total:,} tokens",
                "date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
            })

        # Most tool calls
        most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
        if (most_tools.get("tool_call_count") or 0) > 0:
            top.append({
                "label": "Most tool calls",
                "session_id": most_tools["id"][:16],
                "value": f"{most_tools['tool_call_count']} calls",
                "date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
            })

        return top

    # =========================================================================
    # Formatting
    # =========================================================================

    def format_terminal(self, report: Dict) -> str:
        """Format the insights report for terminal display (CLI)."""
        if report.get("empty"):
            days = report.get("days", 30)
            src = f" (source: {report['source_filter']})" if report.get("source_filter") else ""
            return f" No sessions found in the last {days} days{src}."

        lines = []
        o = report["overview"]
        days = report["days"]
        src_filter = report.get("source_filter")

        # Header
        lines.append("")
        lines.append(" ╔══════════════════════════════════════════════════════════╗")
        lines.append(" ║ 📊 Hermes Insights ║")
        period_label = f"Last {days} days"
        if src_filter:
            period_label += f" ({src_filter})"
        padding = 58 - len(period_label) - 2
        left_pad = padding // 2
        right_pad = padding - left_pad
        lines.append(f" ║{' ' * left_pad} {period_label} {' ' * right_pad}║")
        lines.append(" ╚══════════════════════════════════════════════════════════╝")
        lines.append("")

        # Date range
        if o.get("date_range_start") and o.get("date_range_end"):
            start_str = datetime.fromtimestamp(o["date_range_start"]).strftime("%b %d, %Y")
            end_str = datetime.fromtimestamp(o["date_range_end"]).strftime("%b %d, %Y")
            lines.append(f" Period: {start_str} — {end_str}")
            lines.append("")

        # Overview
        lines.append(" 📋 Overview")
        lines.append(" " + "─" * 56)
        lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
        lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
        lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
        cost_str = f"${o['estimated_cost']:.2f}"
        if o.get("models_without_pricing"):
            cost_str += " *"
        lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
        if o["total_hours"] > 0:
            lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
        lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
        lines.append("")

        # Model breakdown
        if report["models"]:
            lines.append(" 🤖 Models Used")
            lines.append(" " + "─" * 56)
            lines.append(f" {'Model':<30} {'Sessions':>8} {'Tokens':>12} {'Cost':>8}")
            for m in report["models"]:
                model_name = m["model"][:28]
                if m.get("has_pricing"):
                    cost_cell = f"${m['cost']:>6.2f}"
                else:
                    cost_cell = " N/A"
                lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
            if o.get("models_without_pricing"):
                lines.append(" * Cost N/A for custom/self-hosted models")
            lines.append("")

        # Platform breakdown
        if len(report["platforms"]) > 1 or (report["platforms"] and report["platforms"][0]["platform"] != "cli"):
            lines.append(" 📱 Platforms")
            lines.append(" " + "─" * 56)
            lines.append(f" {'Platform':<14} {'Sessions':>8} {'Messages':>10} {'Tokens':>14}")
            for p in report["platforms"]:
                lines.append(f" {p['platform']:<14} {p['sessions']:>8} {p['messages']:>10,} {p['total_tokens']:>14,}")
            lines.append("")

        # Tool usage
        if report["tools"]:
            lines.append(" 🔧 Top Tools")
            lines.append(" " + "─" * 56)
            lines.append(f" {'Tool':<28} {'Calls':>8} {'%':>8}")
            for t in report["tools"][:15]:  # Top 15
                lines.append(f" {t['tool']:<28} {t['count']:>8,} {t['percentage']:>7.1f}%")
            if len(report["tools"]) > 15:
                lines.append(f" ... and {len(report['tools']) - 15} more tools")
            lines.append("")

        # Activity patterns
        act = report.get("activity", {})
        if act.get("by_day"):
            lines.append(" 📅 Activity Patterns")
            lines.append(" " + "─" * 56)

            # Day of week chart
            day_values = [d["count"] for d in act["by_day"]]
            bars = _bar_chart(day_values, max_width=15)
            for i, d in enumerate(act["by_day"]):
                bar = bars[i]
                lines.append(f" {d['day']} {bar:<15} {d['count']}")

            lines.append("")

            # Peak hours (show top 5 busiest hours)
            busy_hours = sorted(act["by_hour"], key=lambda x: x["count"], reverse=True)
            busy_hours = [h for h in busy_hours if h["count"] > 0][:5]
            if busy_hours:
                hour_strs = []
                for h in busy_hours:
                    hr = h["hour"]
                    ampm = "AM" if hr < 12 else "PM"
                    display_hr = hr % 12 or 12
                    hour_strs.append(f"{display_hr}{ampm} ({h['count']})")
                lines.append(f" Peak hours: {', '.join(hour_strs)}")

            if act.get("active_days"):
                lines.append(f" Active days: {act['active_days']}")
            if act.get("max_streak") and act["max_streak"] > 1:
                lines.append(f" Best streak: {act['max_streak']} consecutive days")
            lines.append("")

        # Notable sessions
        if report.get("top_sessions"):
            lines.append(" 🏆 Notable Sessions")
            lines.append(" " + "─" * 56)
            for ts in report["top_sessions"]:
                lines.append(f" {ts['label']:<20} {ts['value']:<18} ({ts['date']}, {ts['session_id']})")
            lines.append("")

        return "\n".join(lines)

    def format_gateway(self, report: Dict) -> str:
        """Format the insights report for gateway/messaging (shorter)."""
        if report.get("empty"):
            days = report.get("days", 30)
            return f"No sessions found in the last {days} days."

        lines = []
        o = report["overview"]
        days = report["days"]

        lines.append(f"📊 **Hermes Insights** — Last {days} days\n")

        # Overview
        lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
        lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
        cost_note = ""
        if o.get("models_without_pricing"):
            cost_note = " _(excludes custom/self-hosted models)_"
        lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
        if o["total_hours"] > 0:
            lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
        lines.append("")

        # Models (top 5)
        if report["models"]:
            lines.append("**🤖 Models:**")
            for m in report["models"][:5]:
                cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
                lines.append(f" {m['model'][:25]} — {m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
            lines.append("")

        # Platforms (if multi-platform)
        if len(report["platforms"]) > 1:
            lines.append("**📱 Platforms:**")
            for p in report["platforms"]:
                lines.append(f" {p['platform']} — {p['sessions']} sessions, {p['messages']:,} msgs")
            lines.append("")

        # Tools (top 8)
        if report["tools"]:
            lines.append("**🔧 Top Tools:**")
            for t in report["tools"][:8]:
                lines.append(f" {t['tool']} — {t['count']:,} calls ({t['percentage']:.1f}%)")
            lines.append("")

        # Activity summary
        act = report.get("activity", {})
        if act.get("busiest_day") and act.get("busiest_hour"):
            hr = act["busiest_hour"]["hour"]
            ampm = "AM" if hr < 12 else "PM"
            display_hr = hr % 12 or 12
            lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
        if act.get("active_days"):
            lines.append(f"**Active days:** {act['active_days']}")
        if act.get("max_streak", 0) > 1:
            lines.append(f"**Best streak:** {act['max_streak']} consecutive days")

        return "\n".join(lines)
@@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks.
"""

import logging
import os
import re
import time
from typing import Any, Dict, List
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
import yaml

from hermes_constants import OPENROUTER_MODELS_URL

@@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600

# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
    2_000_000,
    1_000_000,
    512_000,
    200_000,
    128_000,
    64_000,
    32_000,
]

DEFAULT_CONTEXT_LENGTHS = {
    "anthropic/claude-opus-4": 200000,
    "anthropic/claude-opus-4.5": 200000,
@@ -33,6 +49,17 @@ DEFAULT_CONTEXT_LENGTHS = {
    "meta-llama/llama-3.3-70b-instruct": 131072,
    "deepseek/deepseek-chat-v3": 65536,
    "qwen/qwen-2.5-72b-instruct": 32768,
    "glm-4.7": 202752,
    "glm-5": 202752,
    "glm-4.5": 131072,
    "glm-4.5-flash": 131072,
    "kimi-k2.5": 262144,
    "kimi-k2-thinking": 262144,
    "kimi-k2-turbo-preview": 262144,
    "kimi-k2-0905-preview": 131072,
    "MiniMax-M2.5": 204800,
    "MiniMax-M2.5-highspeed": 204800,
    "MiniMax-M2.1": 204800,
}


@@ -71,17 +98,117 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
    return _model_metadata_cache or {}


def get_model_context_length(model: str) -> int:
    """Get the context length for a model (API first, then fallback defaults)."""
def _get_context_cache_path() -> Path:
    """Return path to the persistent context length cache file."""
    hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
    return hermes_home / "context_length_cache.yaml"


def _load_context_cache() -> Dict[str, int]:
    """Load the model+provider → context_length cache from disk."""
    path = _get_context_cache_path()
    if not path.exists():
        return {}
    try:
        with open(path) as f:
            data = yaml.safe_load(f) or {}
        return data.get("context_lengths", {})
    except Exception as e:
        logger.debug("Failed to load context length cache: %s", e)
        return {}


def save_context_length(model: str, base_url: str, length: int) -> None:
    """Persist a discovered context length for a model+provider combo.

    Cache key is ``model@base_url`` so the same model name served from
    different providers can have different limits.
    """
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    if cache.get(key) == length:
        return  # already stored
    cache[key] = length
    path = _get_context_cache_path()
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
        logger.info("Cached context length %s → %s tokens", key, f"{length:,}")
    except Exception as e:
        logger.debug("Failed to save context length cache: %s", e)


def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
    """Look up a previously discovered context length for model+provider."""
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    return cache.get(key)


def get_next_probe_tier(current_length: int) -> Optional[int]:
    """Return the next lower probe tier, or None if already at minimum."""
    for tier in CONTEXT_PROBE_TIERS:
        if tier < current_length:
            return tier
    return None
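# e.g. get_next_probe_tier(200_000) → 128_000; get_next_probe_tier(32_000) → None.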


def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
    """Try to extract the actual context limit from an API error message.

    Many providers include the limit in their error text, e.g.:
    - "maximum context length is 32768 tokens"
    - "context_length_exceeded: 131072"
    - "Maximum context size 32768 exceeded"
    - "model's max context length is 65536"
    """
    error_lower = error_msg.lower()
    # Pattern: look for numbers near context-related keywords
    patterns = [
        r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
        r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
        r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
        r'>\s*(\d{4,})\s*(?:max|limit|token)',  # "250000 tokens > 200000 maximum"
        r'(\d{4,})\s*(?:max(?:imum)?)\b',  # "200000 maximum"
    ]
    for pattern in patterns:
        match = re.search(pattern, error_lower)
        if match:
            limit = int(match.group(1))
            # Sanity check: must be a reasonable context length
            if 1024 <= limit <= 10_000_000:
                return limit
    return None
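# e.g. parse_context_limit_from_error("maximum context length is 32768 tokens")
# matches the first pattern and returns 32768; values outside 1024..10M are rejected.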


def get_model_context_length(model: str, base_url: str = "") -> int:
    """Get the context length for a model.

    Resolution order:
    1. Persistent cache (previously discovered via probing)
    2. OpenRouter API metadata
    3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
    4. First probe tier (2M) — will be narrowed on first context error
    """
    # 1. Check persistent cache (model+provider)
    if base_url:
        cached = get_cached_context_length(model, base_url)
        if cached is not None:
            return cached

    # 2. OpenRouter API metadata
    metadata = fetch_model_metadata()
    if model in metadata:
        return metadata[model].get("context_length", 128000)

    # 3. Hardcoded defaults (fuzzy match)
    for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
        if default_model in model or model in default_model:
            return length

    return 128000
    # 4. Unknown model — start at highest probe tier
    return CONTEXT_PROBE_TIERS[0]


def estimate_tokens_rough(text: str) -> int:

@@ -66,7 +66,8 @@ DEFAULT_AGENT_IDENTITY = (
    "range of tasks including answering questions, writing and editing code, "
    "analyzing information, creative work, and executing actions via your tools. "
    "You communicate clearly, admit uncertainty when appropriate, and prioritize "
    "being genuinely useful over being verbose unless otherwise directed below."
    "being genuinely useful over being verbose unless otherwise directed below. "
    "Be targeted and efficient in your exploration and investigations."
)

MEMORY_GUIDANCE = (
@@ -102,12 +103,24 @@ PLATFORM_HINTS = {
        "You are on a text messaging communication platform, Telegram. "
        "Please do not use markdown as it does not render. "
        "You can send media files natively: to deliver a file to the user, "
        "include MEDIA:/absolute/path/to/file in your response. Audio "
        "(.ogg) sends as voice bubbles. You can also include image URLs "
        "in markdown format  and they will be sent as native photos."
        "include MEDIA:/absolute/path/to/file in your response. Images "
        "(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
        "bubbles, and videos (.mp4) play inline. You can also include image "
        "URLs in markdown format  and they will be sent as native photos."
    ),
    "discord": (
        "You are in a Discord server or group chat communicating with your user."
        "You are in a Discord server or group chat communicating with your user. "
        "You can send media files natively: include MEDIA:/absolute/path/to/file "
        "in your response. Images (.png, .jpg, .webp) are sent as photo "
        "attachments, audio as file attachments. You can also include image URLs "
        "in markdown format  and they will be sent as attachments."
    ),
    "slack": (
        "You are in a Slack workspace communicating with your user. "
        "You can send media files natively: include MEDIA:/absolute/path/to/file "
        "in your response. Images (.png, .jpg, .webp) are uploaded as photo "
        "attachments, audio as file attachments. You can also include image URLs "
        "in markdown format  and they will be uploaded as attachments."
    ),
    "cli": (
        "You are a CLI AI Agent. Try not to use markdown but simple text "
@@ -142,12 +155,28 @@ def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
    return ""


def _skill_is_platform_compatible(skill_file: Path) -> bool:
    """Quick check if a SKILL.md is compatible with the current OS platform.

    Reads just enough to parse the ``platforms`` frontmatter field.
    Skills without the field (the vast majority) are always compatible.
    """
    try:
        from tools.skills_tool import _parse_frontmatter, skill_matches_platform
        raw = skill_file.read_text(encoding="utf-8")[:2000]
        frontmatter, _ = _parse_frontmatter(raw)
        return skill_matches_platform(frontmatter)
    except Exception:
        return True  # Err on the side of showing the skill


def build_skills_system_prompt() -> str:
    """Build a compact skill index for the system prompt.

    Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
    Includes per-skill descriptions from frontmatter so the model can
    match skills by meaning, not just name.
    Filters out skills incompatible with the current OS platform.
    """
    hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
    skills_dir = hermes_home / "skills"
@@ -159,6 +188,9 @@ def build_skills_system_prompt() -> str:
    # Each entry: (skill_name, description)
    skills_by_category: dict[str, list[tuple[str, str]]] = {}
    for skill_file in skills_dir.rglob("SKILL.md"):
        # Skip skills incompatible with the current OS platform
        if not _skill_is_platform_compatible(skill_file):
            continue
        rel_path = skill_file.relative_to(skills_dir)
        parts = rel_path.parts
        if len(parts) >= 2:

@@ -22,7 +22,7 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
    global _skill_commands
    _skill_commands = {}
    try:
        from tools.skills_tool import SKILLS_DIR, _parse_frontmatter
        from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
        if not SKILLS_DIR.exists():
            return _skill_commands
        for skill_md in SKILLS_DIR.rglob("SKILL.md"):
@@ -31,6 +31,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
            try:
                content = skill_md.read_text(encoding='utf-8')
                frontmatter, body = _parse_frontmatter(content)
                # Skip skills incompatible with the current OS platform
                if not skill_matches_platform(frontmatter):
                    continue
                name = frontmatter.get('name', skill_md.parent.name)
                description = frontmatter.get('description', '')
                if not description:

@@ -29,7 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback

from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
@@ -250,7 +249,7 @@ def _process_single_prompt(
    task_id = f"task_{prompt_index}"

    # Per-prompt container image override: if the dataset row has an 'image' field,
    # register it for this task's sandbox. Works with Docker, Modal, and Singularity.
    # register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
    container_image = prompt_data.get("image") or prompt_data.get("docker_image")
    if container_image:
        # Verify the image is accessible before spending tokens on the agent loop.
@@ -292,6 +291,7 @@ def _process_single_prompt(
        "docker_image": container_image,
        "modal_image": container_image,
        "singularity_image": f"docker://{container_image}",
        "daytona_image": container_image,
    }
    if prompt_data.get("cwd"):
        overrides["cwd"] = prompt_data["cwd"]
@@ -700,14 +700,13 @@ class BatchRunner:
            lock (Lock): Optional lock for thread-safe access
        """
        checkpoint_data["last_updated"] = datetime.now().isoformat()

        from utils import atomic_json_write
        if lock:
            with lock:
                with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
                    json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
                atomic_json_write(self.checkpoint_file, checkpoint_data)
        else:
            with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
                json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
            atomic_json_write(self.checkpoint_file, checkpoint_data)

    def _scan_completed_prompts_by_content(self) -> set:
        """
@@ -832,13 +831,15 @@ class BatchRunner:
        print(f" New batches created: {len(batches_to_process)}")
        print("=" * 70 + "\n")

        # Initialize checkpoint data (needed for saving at the end)
        checkpoint_data = {
            "run_name": self.run_name,
            "completed_prompts": [],
            "batch_stats": {},
            "last_updated": None
        }
        # Load existing checkpoint (so resume doesn't clobber prior progress)
        checkpoint_data = self._load_checkpoint()
        if checkpoint_data.get("run_name") != self.run_name:
            checkpoint_data = {
                "run_name": self.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None
            }

        # Prepare configuration for workers
        config = {
@@ -860,7 +861,7 @@ class BatchRunner:
        }

        # For backward compatibility, still track by index (but this is secondary to content matching)
        completed_prompts_set = set()
        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))

        # Aggregate statistics across all batches
        total_tool_stats = {}
@@ -869,6 +870,9 @@ class BatchRunner:

        print(f"\n🔧 Initializing {self.num_workers} worker processes...")

        # Checkpoint writes happen in the parent process; keep a lock for safety.
        checkpoint_lock = Lock()

        # Process batches in parallel
        with Pool(processes=self.num_workers) as pool:
            # Create tasks for each batch
@@ -914,6 +918,25 @@ class BatchRunner:
            for result in pool.imap_unordered(_process_batch_worker, tasks):
                results.append(result)
                progress.update(task, advance=1)

                # Incremental checkpoint update (so resume works after crash)
                try:
                    batch_num = result.get('batch_num')
                    completed = result.get('completed_prompts', []) or []
                    completed_prompts_set.update(completed)

                    if isinstance(batch_num, int):
                        checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
                            'processed': result.get('processed', 0),
                            'skipped': result.get('skipped', 0),
                            'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
                        }

                    checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
                    self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
                except Exception as ckpt_err:
                    # Don't fail the run if checkpoint write fails
                    print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
        except Exception as e:
            logger.error("Batch worker failed: %s", e, exc_info=True)
            raise
@@ -945,9 +968,12 @@ class BatchRunner:
        for key in total_reasoning_stats:
            total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)

        # Save final checkpoint
        checkpoint_data["completed_prompts"] = all_completed_prompts
        self._save_checkpoint(checkpoint_data)
        # Save final checkpoint (best-effort; incremental writes already happened)
        try:
            checkpoint_data["completed_prompts"] = all_completed_prompts
            self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
        except Exception as ckpt_err:
            print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")

        # Calculate success rates
        for tool_name in total_tool_stats:
@@ -1086,7 +1112,7 @@ def main(
    batch_size: int = None,
    run_name: str = None,
    distribution: str = "default",
    model: str = "anthropic/claude-sonnet-4-20250514",
    model: str = "anthropic/claude-sonnet-4.6",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_turns: int = 10,
@@ -1129,7 +1155,7 @@ def main(
        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
        reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
        reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
        reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
@@ -1190,7 +1216,7 @@ def main(
    providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None

    # Build reasoning_config from CLI flags
    # --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
    # --reasoning_disabled takes priority, then --reasoning_effort, then default (medium)
    reasoning_config = None
    if reasoning_disabled:
        # Completely disable reasoning/thinking tokens

@@ -13,6 +13,10 @@ model:
  # "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
  # "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
  # "nous" - Always use Nous Portal (requires: hermes login)
  # "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
  # "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
  # "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
  # "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
  # Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
  provider: "auto"

@@ -46,6 +50,16 @@ model:
  # # Data policy: "allow" (default) or "deny" to exclude providers that may store data
  # # data_collection: "deny"

# =============================================================================
# Git Worktree Isolation
# =============================================================================
# When enabled, each CLI session creates an isolated git worktree so multiple
# agents can work on the same repo concurrently without file collisions.
# Equivalent to always passing --worktree / -w on the command line.
#
# worktree: true   # Always create a worktree when in a git repo
# worktree: false  # Default — only create when -w flag is passed

# =============================================================================
# Terminal Tool Configuration
# =============================================================================
@@ -116,8 +130,23 @@ terminal:
#   timeout: 180
#   lifetime_seconds: 300
#   modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"

# -----------------------------------------------------------------------------
# OPTION 6: Daytona cloud execution
# Commands run in Daytona cloud sandboxes
# Great for: Cloud dev environments, persistent workspaces, team collaboration
# Requires: pip install daytona, DAYTONA_API_KEY env var
# -----------------------------------------------------------------------------
# terminal:
#   backend: "daytona"
#   cwd: "~"
#   timeout: 180
#   lifetime_seconds: 300
#   daytona_image: "nikolaik/python-nodejs:python3.11-nodejs20"
#   container_disk: 10240  # Daytona max is 10GB per sandbox

#
# --- Container resource limits (docker, singularity, modal -- ignored for local/ssh) ---
# --- Container resource limits (docker, singularity, modal, daytona -- ignored for local/ssh) ---
# These settings apply to all container backends. They control the resources
# allocated to the sandbox and whether its filesystem persists across sessions.
container_cpu: 1 # CPU cores
@@ -266,7 +295,7 @@ agent:
  # Reasoning effort level (OpenRouter and Nous Portal)
  # Controls how much "thinking" the model does before responding.
  # Options: "xhigh" (max), "high", "medium", "low", "minimal", "none" (disable)
  reasoning_effort: "xhigh"
  reasoning_effort: "medium"

  # Predefined personalities (use with /personality command)
  personalities:

45
cron/jobs.py
@@ -14,6 +14,8 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, List, Any

from hermes_time import now as _hermes_now

try:
    from croniter import croniter
    HAS_CRONITER = True
@@ -128,7 +130,7 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
    # Duration like "30m", "2h", "1d" → one-shot from now
    try:
        minutes = parse_duration(schedule)
        run_at = datetime.now() + timedelta(minutes=minutes)
        run_at = _hermes_now() + timedelta(minutes=minutes)
        return {
            "kind": "once",
            "run_at": run_at.isoformat(),
@@ -146,37 +148,50 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
    )


def _ensure_aware(dt: datetime) -> datetime:
    """Make a naive datetime tz-aware using the configured timezone.

    Handles backward compatibility: timestamps stored before timezone support
    are naive (server-local). We assume they were in the same timezone as
    the current configuration so comparisons work without crashing.
    """
    if dt.tzinfo is None:
        tz = _hermes_now().tzinfo
        return dt.replace(tzinfo=tz)
    return dt


def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
    """
    Compute the next run time for a schedule.

    Returns ISO timestamp string, or None if no more runs.
    """
    now = datetime.now()
    now = _hermes_now()

    if schedule["kind"] == "once":
        run_at = datetime.fromisoformat(schedule["run_at"])
        run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
        # If in the future, return it; if in the past, no more runs
        return schedule["run_at"] if run_at > now else None

    elif schedule["kind"] == "interval":
        minutes = schedule["minutes"]
        if last_run_at:
            # Next run is last_run + interval
            last = datetime.fromisoformat(last_run_at)
            last = _ensure_aware(datetime.fromisoformat(last_run_at))
            next_run = last + timedelta(minutes=minutes)
        else:
            # First run is now + interval
            next_run = now + timedelta(minutes=minutes)
        return next_run.isoformat()

    elif schedule["kind"] == "cron":
        if not HAS_CRONITER:
            return None
        cron = croniter(schedule["expr"], now)
        next_run = cron.get_next(datetime)
        return next_run.isoformat()

    return None

@@ -204,7 +219,7 @@ def save_jobs(jobs: List[Dict[str, Any]]):
    fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
            json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, JOBS_FILE)
@@ -249,7 +264,7 @@ def create_job(
    deliver = "origin" if origin else "local"

    job_id = uuid.uuid4().hex[:12]
    now = datetime.now().isoformat()
    now = _hermes_now().isoformat()

    job = {
        "id": job_id,
@@ -328,7 +343,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
    jobs = load_jobs()
    for i, job in enumerate(jobs):
        if job["id"] == job_id:
            now = datetime.now().isoformat()
            now = _hermes_now().isoformat()
            job["last_run_at"] = now
            job["last_status"] = "ok" if success else "error"
            job["last_error"] = error if not success else None
@@ -361,7 +376,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):

def get_due_jobs() -> List[Dict[str, Any]]:
    """Get all jobs that are due to run now."""
    now = datetime.now()
    now = _hermes_now()
    jobs = load_jobs()
    due = []

@@ -373,7 +388,7 @@ def get_due_jobs() -> List[Dict[str, Any]]:
        if not next_run:
            continue

        next_run_dt = datetime.fromisoformat(next_run)
        next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
        if next_run_dt <= now:
            due.append(job)

@@ -386,7 +401,7 @@ def save_job_output(job_id: str, output: str):
    job_output_dir = OUTPUT_DIR / job_id
    job_output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = job_output_dir / f"{timestamp}.md"

    with open(output_file, 'w', encoding='utf-8') as f:

@@ -27,6 +27,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add parent directory to path for imports
|
||||
@@ -174,6 +176,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
|
||||
|
||||
# Load config.yaml for model, reasoning, prefill, toolsets, provider routing
|
||||
_cfg = {}
|
||||
try:
|
||||
import yaml
|
||||
_cfg_path = str(_hermes_home / "config.yaml")
|
||||
@@ -188,6 +192,41 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reasoning config from env or config.yaml
|
||||
reasoning_config = None
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
if effort and effort.lower() != "none":
|
||||
valid = ("xhigh", "high", "medium", "low", "minimal")
|
||||
if effort.lower() in valid:
|
||||
reasoning_config = {"enabled": True, "effort": effort.lower()}
|
||||
elif effort.lower() == "none":
|
||||
reasoning_config = {"enabled": False}
|
||||
|
||||
# Prefill messages from env or config.yaml
|
||||
prefill_messages = None
|
||||
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
|
||||
if prefill_file:
|
||||
import json as _json
|
||||
pfpath = Path(prefill_file).expanduser()
|
||||
if not pfpath.is_absolute():
|
||||
pfpath = _hermes_home / pfpath
|
||||
if pfpath.exists():
|
||||
try:
|
||||
with open(pfpath, "r", encoding="utf-8") as _pf:
|
||||
prefill_messages = _json.load(_pf)
|
||||
if not isinstance(prefill_messages, list):
|
||||
prefill_messages = None
|
||||
except Exception:
|
||||
prefill_messages = None
|
||||
|
||||
# Max iterations
|
||||
max_iterations = _cfg.get("agent", {}).get("max_turns") or _cfg.get("max_turns") or 90
|
||||
|
||||
# Provider routing
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
format_runtime_provider_error,
|
||||
@@ -206,8 +245,15 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
base_url=runtime.get("base_url"),
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
@@ -219,7 +265,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
@@ -241,7 +287,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
@@ -280,6 +326,7 @@ def tick(verbose: bool = True) -> int:
|
||||
_LOCK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Cross-platform file locking: fcntl on Unix, msvcrt on Windows
|
||||
lock_fd = None
|
||||
try:
|
||||
lock_fd = open(_LOCK_FILE, "w")
|
||||
if fcntl:
|
||||
@@ -288,17 +335,19 @@ def tick(verbose: bool = True) -> int:
|
||||
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
|
||||
except (OSError, IOError):
|
||||
logger.debug("Tick skipped — another instance holds the lock")
|
||||
if lock_fd is not None:
|
||||
lock_fd.close()
|
||||
return 0
|
||||
|
||||
try:
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
logger.info("%s - No jobs due", datetime.now().strftime('%H:%M:%S'))
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
logger.info("%s - %s job(s) due", datetime.now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
|
||||
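The helpers this diff leans on are small. A minimal sketch of what `hermes_time.now` and `_ensure_aware` are assumed to provide: timezone-aware timestamps that compare safely against datetimes parsed from older, naive job files. The real `hermes_time` module may differ:

```python
from datetime import datetime, timezone

def now() -> datetime:
    # Timezone-aware "now" in the local timezone.
    return datetime.now(timezone.utc).astimezone()

def _ensure_aware(dt: datetime) -> datetime:
    # Naive datetimes raise TypeError when compared against aware ones;
    # attach the local timezone so ISO strings from older job files keep working.
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=now().tzinfo)
```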
docs/send_file_integration_map.md (new file, 345 lines)
@@ -0,0 +1,345 @@
# send_file Integration Map — Hermes Agent Codebase Deep Dive

## 1. environments/tool_context.py — Base64 File Transfer Implementation

### upload_file() (lines 153-205)
- Reads local file as raw bytes, base64-encodes to ASCII string
- Creates parent dirs in sandbox via `self.terminal(f"mkdir -p {parent}")`
- **Chunk size:** 60,000 chars (~60KB per shell command)
- **Small files (<=60KB b64):** Single `printf '%s' '{b64}' | base64 -d > {remote_path}`
- **Large files:** Writes chunks to `/tmp/_hermes_upload.b64` via `printf >> append`, then `base64 -d` to target
- **Error handling:** Checks local file exists; returns `{exit_code, output}`
- **Size limits:** No explicit cap; chunking exists because shell argument length is limited (~2MB on Linux), so anything above ~45KB raw (~60KB base64) takes the chunked path
- **No theoretical max** — but very large files would be slow (many terminal round trips)

### download_file() (lines 234-278)
- Runs `base64 {remote_path}` inside sandbox, captures stdout
- Strips output, base64-decodes to raw bytes
- Writes to host filesystem with parent dir creation
- **Error handling:** Checks exit code, empty output, decode errors
- Returns `{success: bool, bytes: int}` or `{success: false, error: str}`
- **Size limit:** Bounded by terminal output buffer (practical limit ~few MB via base64 terminal output)

### Promotion potential:
- These methods work via `self.terminal()` — they're environment-agnostic
- Could be directly lifted into a new tool that operates on the agent's current sandbox
- For send_file, this `download_file()` pattern is the key: it extracts files from sandbox → host
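To make that concrete, here is a minimal sketch of the download pattern described above. It assumes only a `terminal(cmd)` callable returning `{"exit_code": int, "output": str}`; the real method's signature and error strings may differ:

```python
import base64
from pathlib import Path

def download_via_terminal(terminal, remote_path: str, local_path: str) -> dict:
    # Encode in the sandbox, decode on the host: works on any backend
    # that can run shell commands, at the cost of round-trip overhead.
    result = terminal(f"base64 {remote_path}")
    if result.get("exit_code", 1) != 0:
        return {"success": False, "error": result.get("output", "")}
    encoded = result.get("output", "").strip()
    if not encoded:
        return {"success": False, "error": "empty output"}
    try:
        raw = base64.b64decode(encoded)  # b64decode ignores the newlines base64(1) emits
    except Exception as exc:
        return {"success": False, "error": f"decode failed: {exc}"}
    dest = Path(local_path)
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_bytes(raw)
    return {"success": True, "bytes": len(raw)}
```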
## 2. tools/environments/base.py — BaseEnvironment Interface

### Current methods:
- `execute(command, cwd, timeout, stdin_data)` → `{output, returncode}`
- `cleanup()` — release resources
- `stop()` — alias for cleanup
- `_prepare_command()` — sudo transformation
- `_build_run_kwargs()` — subprocess kwargs
- `_timeout_result()` — standard timeout dict

### What would need to be added for file transfer:
- **Nothing required at this level.** File transfer can be implemented via `execute()` (base64 over terminal, like ToolContext does) or via environment-specific methods.
- Optional: `upload_file(local_path, remote_path)` and `download_file(remote_path, local_path)` methods could be added to BaseEnvironment for optimized per-backend transfers, but the base64-over-terminal approach already works universally.

## 3. tools/environments/docker.py — Docker Container Details

### Container ID tracking:
- `self._container_id` stored at init from `self._inner.container_id`
- Inner is `minisweagent.environments.docker.DockerEnvironment`
- Container ID is a standard Docker container hash

### docker cp feasibility:
- **YES**, `docker cp` could be used for optimized file transfer (see the sketch below):
  - `docker cp {container_id}:{remote_path} {local_path}` (download)
  - `docker cp {local_path} {container_id}:{remote_path}` (upload)
- Much faster than base64-over-terminal for large files
- Container ID is directly accessible via `env._container_id` or `env._inner.container_id`
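A minimal sketch of the `docker cp` extraction path, assuming the wrapper exposes `_container_id` as this map indicates:

```python
import subprocess
from pathlib import Path

def docker_extract(container_id: str, remote_path: str, local_path: str) -> bool:
    # docker cp streams the file straight out of the container's filesystem;
    # no base64 round trip and no size-dependent chunking needed.
    Path(local_path).parent.mkdir(parents=True, exist_ok=True)
    proc = subprocess.run(
        ["docker", "cp", f"{container_id}:{remote_path}", local_path],
        capture_output=True, text=True,
    )
    return proc.returncode == 0
```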
### Volumes mounted:
- **Persistent mode:** Bind mounts at `~/.hermes/sandboxes/docker/{task_id}/workspace` → `/workspace` and `.../home` → `/root`
- **Ephemeral mode:** tmpfs at `/workspace` (10GB), `/home` (1GB), `/root` (1GB)
- **User volumes:** From `config.yaml docker_volumes` (arbitrary `-v` mounts)
- **Security tmpfs:** `/tmp` (512MB), `/var/tmp` (256MB), `/run` (64MB)

### Direct host access for persistent mode:
- If persistent, files at `/workspace/foo.txt` are just `~/.hermes/sandboxes/docker/{task_id}/workspace/foo.txt` on host — no transfer needed!

## 4. tools/environments/ssh.py — SSH Connection Management

### Connection management:
- Uses SSH ControlMaster for persistent connection
- Control socket at `/tmp/hermes-ssh/{user}@{host}:{port}.sock`
- ControlPersist=300 (5 min keepalive)
- BatchMode=yes (non-interactive)
- Stores: `self.host`, `self.user`, `self.port`, `self.key_path`

### SCP/SFTP feasibility:
- **YES**, SCP can piggyback on the ControlMaster socket (see the sketch below):
  - `scp -o ControlPath={socket} {user}@{host}:{remote} {local}` (download)
  - `scp -o ControlPath={socket} {local} {user}@{host}:{remote}` (upload)
- Same SSH key and connection reuse — zero additional auth
- Would be much faster than base64-over-terminal for large files
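And the matching SCP path, reusing the ControlMaster socket so no new authentication happens. The `control_socket` attribute name is an assumption based on section 10 below:

```python
import subprocess

def scp_extract(env, remote_path: str, local_path: str) -> bool:
    # Riding the existing ControlMaster socket lets scp skip the handshake
    # entirely and reuse the already-open SSH connection.
    proc = subprocess.run(
        ["scp", "-o", f"ControlPath={env.control_socket}", "-P", str(env.port),
         f"{env.user}@{env.host}:{remote_path}", local_path],
        capture_output=True, text=True,
    )
    return proc.returncode == 0
```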
## 5. tools/environments/modal.py — Modal Sandbox Filesystem

### Filesystem API exposure:
- **Not directly.** The inner `SwerexModalEnvironment` wraps Modal's sandbox
- The sandbox object is accessible at: `env._inner.deployment._sandbox`
- Modal's Python SDK exposes `sandbox.open()` for file I/O — but only via async API
- Currently only used for `snapshot_filesystem()` during cleanup
- **Could use:** `sandbox.open(path, "rb")` to read files or `sandbox.open(path, "wb")` to write
- **Alternative:** Base64-over-terminal already works via `execute()` — simpler, no SDK dependency

## 6. gateway/platforms/base.py — MEDIA: Tag Flow (Complete)

### extract_media() (lines 587-620):
- **Pattern:** `MEDIA:\S+` — extracts file paths after MEDIA: prefix
- **Voice flag:** `[[audio_as_voice]]` global directive sets `is_voice=True` for all media in message
- Returns `List[Tuple[str, bool]]` (path, is_voice) and cleaned content
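In regex terms, the extraction described above amounts to something like the following sketch (the `[[audio_as_voice]]` handling is inferred from this map, not copied from the source):

```python
import re
from typing import List, Tuple

MEDIA_RE = re.compile(r"MEDIA:(\S+)")

def extract_media_sketch(content: str) -> Tuple[List[Tuple[str, bool]], str]:
    # One global directive flips is_voice for every media item in the message.
    is_voice = "[[audio_as_voice]]" in content
    content = content.replace("[[audio_as_voice]]", "")
    paths = MEDIA_RE.findall(content)
    cleaned = MEDIA_RE.sub("", content).strip()
    return [(path, is_voice) for path in paths], cleaned
```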
### _process_message_background() media routing (lines 752-786):
- After extracting MEDIA tags, routes by file extension:
  - `.ogg .opus .mp3 .wav .m4a` → `send_voice()`
  - `.mp4 .mov .avi .mkv .3gp` → `send_video()`
  - `.jpg .jpeg .png .webp .gif` → `send_image_file()`
  - **Everything else** → `send_document()`
- This routing already supports arbitrary files!

### send_* method inventory (base class):
- `send(chat_id, content, reply_to, metadata)` — ABSTRACT, text
- `send_image(chat_id, image_url, caption, reply_to)` — URL-based images
- `send_animation(chat_id, animation_url, caption, reply_to)` — GIF animations
- `send_voice(chat_id, audio_path, caption, reply_to)` — voice messages
- `send_video(chat_id, video_path, caption, reply_to)` — video files
- `send_document(chat_id, file_path, caption, file_name, reply_to)` — generic files
- `send_image_file(chat_id, image_path, caption, reply_to)` — local image files
- `send_typing(chat_id)` — typing indicator
- `edit_message(chat_id, message_id, content)` — edit sent messages

### What's missing:
- **Telegram:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **Discord:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **Slack:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **WhatsApp:** Has `send_document` and `send_image_file` via bridge — COMPLETE.
- The base class defaults just send "📎 File: /path" as text — useless for actual file delivery.

## 7. gateway/platforms/telegram.py — Send Method Analysis

### Implemented send methods:
- `send()` — MarkdownV2 text with fallback to plain
- `send_voice()` — `.ogg`/`.opus` as `send_voice()`, others as `send_audio()`
- `send_image()` — URL-based via `send_photo()`
- `send_image_file()` — local file via `send_photo(photo=open(path, 'rb'))` ✅
- `send_animation()` — GIF via `send_animation()`
- `send_typing()` — "typing" chat action
- `edit_message()` — edit text messages

### MISSING:
- **`send_document()` NOT overridden** — Need to add `self._bot.send_document(chat_id, document=open(file_path, 'rb'), ...)`
- **`send_video()` NOT overridden** — Need to add `self._bot.send_video(...)`

## 8. gateway/platforms/discord.py — Send Method Analysis

### Implemented send methods:
- `send()` — text messages with chunking
- `send_voice()` — discord.File attachment
- `send_image()` — downloads URL, creates discord.File attachment
- `send_image_file()` — local file via discord.File attachment ✅
- `send_typing()` — channel.typing()
- `edit_message()` — edit text messages

### MISSING:
- **`send_document()` NOT overridden** — Need to add discord.File attachment
- **`send_video()` NOT overridden** — Need to add discord.File attachment

## 9. gateway/run.py — User File Attachment Handling

### Current attachment flow:
1. **Telegram photos** (lines 509-529): Download via `photo.get_file()` → `cache_image_from_bytes()` → vision auto-analysis
2. **Telegram voice** (lines 532-541): Download → `cache_audio_from_bytes()` → STT transcription
3. **Telegram audio** (lines 542-551): Same pattern
4. **Telegram documents** (lines 553-617): Extension validation against `SUPPORTED_DOCUMENT_TYPES`, 20MB limit, content injection for text files
5. **Discord attachments** (lines 717-751): Content-type detection, image/audio caching, URL fallback for other types
6. **Gateway run.py** (lines 818-883): Auto-analyzes images with vision, transcribes audio, enriches document messages with context notes

### Key insight: Files are always cached to the host filesystem first, then processed. The agent sees local file paths.

## 10. tools/terminal_tool.py — Terminal Tool & Environment Interaction

### How it manages environments:
- Global dict `_active_environments: Dict[str, Any]` keyed by task_id
- Per-task creation locks prevent duplicate sandbox creation
- Auto-cleanup thread kills idle environments after `TERMINAL_LIFETIME_SECONDS`
- `_get_env_config()` reads all TERMINAL_* env vars for backend selection
- `_create_environment()` factory creates the right backend type

### Could send_file piggyback?
- **YES.** send_file needs access to the same environment to extract files from sandboxes.
- It can reuse `_active_environments[task_id]` to get the environment (see the sketch below), then:
  - Docker: Use `docker cp` via `env._container_id`
  - SSH: Use `scp` via `env.control_socket`
  - Local: Just read the file directly
  - Modal: Use base64-over-terminal via `env.execute()`
- The file_tools.py module already does this with `ShellFileOperations` — read_file/write_file/search/patch all share the same env instance.
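A sketch of that lookup, assuming the module-level dict keeps its current name (hypothetical helper, not existing code):

```python
def get_active_env(task_id: str):
    # Borrow the exact sandbox instance the agent's terminal tool is
    # already using, so send_file sees the same filesystem state.
    from tools.terminal_tool import _active_environments
    env = _active_environments.get(task_id)
    if env is None:
        raise RuntimeError(f"No active sandbox for task {task_id}")
    return env
```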
## 11. tools/tts_tool.py — Working Example of File Delivery

### Flow:
1. Generate audio file to `~/.hermes/audio_cache/tts_TIMESTAMP.{ogg,mp3}`
2. Return JSON with `media_tag: "MEDIA:/path/to/file"`
3. For Telegram voice: prepend `[[audio_as_voice]]` directive
4. The LLM includes the MEDIA tag in its response text
5. `BasePlatformAdapter._process_message_background()` calls `extract_media()` to find the tag
6. Routes by extension → `send_voice()` for audio files
7. Platform adapter sends the file natively

### Key pattern: Tool saves file to host → returns MEDIA: path → LLM echoes it → gateway extracts → platform delivers

## 12. tools/image_generation_tool.py — Working Example of Image Delivery

### Flow:
1. Call FAL.ai API → get image URL
2. Return JSON with `image: "https://fal.media/..."` URL
3. The LLM includes the URL in image markdown: `![description](url)`
4. `BasePlatformAdapter.extract_images()` finds `![...](...)` image-markdown patterns
5. Routes through `send_image()` (URL) or `send_animation()` (GIF)
6. Platform downloads and sends natively

### Key difference from TTS: Images are URL-based, not local files. The gateway downloads at send time.

---

# INTEGRATION MAP: Where send_file Hooks In

## Architecture Decision: MEDIA: Tag Protocol vs. New Tool

The MEDIA: tag protocol is already the established pattern for file delivery. Two options:

### Option A: Pure MEDIA: Tag (Minimal Change)
- No new tool needed
- Agent downloads file from sandbox to host using terminal (base64)
- Saves to known location (e.g., `~/.hermes/file_cache/`)
- Includes `MEDIA:/path` in response text
- Existing routing in `_process_message_background()` handles delivery
- **Problem:** Agent has to do the base64 dance manually and know about the MEDIA: convention

### Option B: Dedicated send_file Tool (Recommended)
- New tool that the agent calls with `(file_path, caption?)`
- Tool handles the sandbox → host extraction automatically
- Returns MEDIA: tag that gets routed through existing pipeline
- Much cleaner agent experience

## Implementation Plan for Option B

### Files to CREATE:

1. **`tools/send_file_tool.py`** — The new tool (see the sketch after this list)
   - Accepts: `file_path` (path in sandbox), `caption` (optional)
   - Detects environment backend from `_active_environments`
   - Extracts file from sandbox:
     - **local:** `shutil.copy()` or direct path
     - **docker:** `docker cp {container_id}:{path} {local_cache}/`
     - **ssh:** `scp -o ControlPath=... {user}@{host}:{path} {local_cache}/`
     - **modal:** base64-over-terminal via `env.execute("base64 {path}")`
   - Saves to `~/.hermes/file_cache/{uuid}_{filename}`
   - Returns: `MEDIA:/cached/path` in response for gateway to pick up
   - Register with `registry.register(name="send_file", toolset="file", ...)`
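A compact sketch of what that tool could look like. Everything here is assembled from the attributes this map identified (`_active_environments`, `env._container_id`, `env.control_socket`, `env.execute()`); treat it as a plan illustration, not the final implementation:

```python
import base64
import shutil
import subprocess
import uuid
from pathlib import Path

FILE_CACHE = Path.home() / ".hermes" / "file_cache"

def send_file(env, backend: str, file_path: str, caption: str = "") -> dict:
    # Extract a file from the agent's sandbox into the host-side cache,
    # then hand back a MEDIA: tag for the gateway's existing routing.
    FILE_CACHE.mkdir(parents=True, exist_ok=True)
    dest = FILE_CACHE / f"{uuid.uuid4().hex[:8]}_{Path(file_path).name}"
    if backend == "local":
        shutil.copy(file_path, dest)  # sandbox path IS a host path
    elif backend == "docker":
        subprocess.run(
            ["docker", "cp", f"{env._container_id}:{file_path}", str(dest)],
            check=True,
        )
    elif backend == "ssh":
        subprocess.run(
            ["scp", "-o", f"ControlPath={env.control_socket}",
             f"{env.user}@{env.host}:{file_path}", str(dest)],
            check=True,
        )
    else:
        # modal and any other backend: universal base64-over-terminal path
        out = env.execute(f"base64 {file_path}")["output"]
        dest.write_bytes(base64.b64decode(out.strip()))
    return {"success": True, "media_tag": f"MEDIA:{dest}", "caption": caption}
```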
### Files to MODIFY:

2. **`gateway/platforms/telegram.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
    with open(file_path, "rb") as f:
        msg = await self._bot.send_document(
            chat_id=int(chat_id), document=f,
            caption=caption, filename=file_name or os.path.basename(file_path))
    return SendResult(success=True, message_id=str(msg.message_id))

async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
    with open(image_path, "rb") as f:
        msg = await self._bot.send_photo(chat_id=int(chat_id), photo=f, caption=caption)
    return SendResult(success=True, message_id=str(msg.message_id))

async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
    with open(video_path, "rb") as f:
        msg = await self._bot.send_video(chat_id=int(chat_id), video=f, caption=caption)
    return SendResult(success=True, message_id=str(msg.message_id))
```

3. **`gateway/platforms/discord.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
    channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
    with open(file_path, "rb") as f:
        file = discord.File(io.BytesIO(f.read()), filename=file_name or os.path.basename(file_path))
    msg = await channel.send(content=caption, file=file)
    return SendResult(success=True, message_id=str(msg.id))

async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
    # Same pattern as send_document; Discord previews image attachments inline
    channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
    with open(image_path, "rb") as f:
        file = discord.File(io.BytesIO(f.read()), filename=os.path.basename(image_path))
    msg = await channel.send(content=caption, file=file)
    return SendResult(success=True, message_id=str(msg.id))

async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
    # Same pattern; Discord renders video attachments inline
    channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
    with open(video_path, "rb") as f:
        file = discord.File(io.BytesIO(f.read()), filename=os.path.basename(video_path))
    msg = await channel.send(content=caption, file=file)
    return SendResult(success=True, message_id=str(msg.id))
```

4. **`toolsets.py`** — Add `"send_file"` to the `_HERMES_CORE_TOOLS` list

5. **`agent/prompt_builder.py`** — Update platform hints to mention the send_file tool

### Code that can be REUSED (zero rewrite):

- `BasePlatformAdapter.extract_media()` — Already extracts MEDIA: tags
- `BasePlatformAdapter._process_message_background()` — Already routes by extension
- `ToolContext.download_file()` — Base64-over-terminal extraction pattern
- `tools/terminal_tool.py` `_active_environments` dict — Environment access
- `tools/registry.py` — Tool registration infrastructure
- `gateway/platforms/base.py` send_document/send_image_file/send_video signatures — Already defined

### Code that needs to be WRITTEN from scratch:

1. `tools/send_file_tool.py` (~150 lines):
   - File extraction from each environment backend type
   - Local file cache management
   - Registry registration

2. Telegram `send_document` + `send_image_file` + `send_video` overrides (~40 lines)
3. Discord `send_document` + `send_image_file` + `send_video` overrides (~50 lines)

### Total effort: ~240 lines of new code, ~5 lines of config changes

## Key Environment-Specific Extract Strategies

| Backend     | Extract Method                | Speed    | Complexity |
|-------------|-------------------------------|----------|------------|
| local       | shutil.copy / direct path     | Instant  | None       |
| docker      | `docker cp container:path .`  | Fast     | Low        |
| docker+vol  | Direct host path access       | Instant  | None       |
| ssh         | `scp -o ControlPath=...`      | Fast     | Low        |
| modal       | base64-over-terminal          | Moderate | Medium     |
| singularity | Direct path (overlay mount)   | Fast     | Low        |

## Data Flow Summary

```
Agent calls send_file(file_path="/workspace/output.pdf", caption="Here's the report")
        │
        ▼
send_file_tool.py:
  1. Get environment from _active_environments[task_id]
  2. Detect backend type (docker/ssh/modal/local)
  3. Extract file to ~/.hermes/file_cache/{uuid}_{filename}
  4. Return: '{"success": true, "media_tag": "MEDIA:/home/user/.hermes/file_cache/abc123_output.pdf"}'
        │
        ▼
LLM includes MEDIA: tag in its response text
        │
        ▼
BasePlatformAdapter._process_message_background():
  1. extract_media(response) → finds MEDIA:/path
  2. Checks extension: .pdf → send_document()
  3. Calls platform-specific send_document(chat_id, file_path, caption)
        │
        ▼
TelegramAdapter.send_document() / DiscordAdapter.send_document():
  Opens file, sends via platform API as native document attachment
  User receives downloadable file in chat
```
@@ -40,7 +40,7 @@ This directory contains the integration layer between **hermes-agent's** tool-ca
- `evaluate_log()` for saving eval results to JSON + samples.jsonl

**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
-- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, ssh, singularity)
+- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, daytona, ssh, singularity)
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
@@ -195,8 +195,12 @@ environments/
│ └── hermes_swe_env.py
│
└── benchmarks/ # Evaluation benchmarks
-   └── terminalbench_2/
-       └── terminalbench2_env.py
+   ├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
+   │   └── terminalbench2_env.py
+   ├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
+   │   └── tblite_env.py
+   └── yc_bench/ # Long-horizon strategic benchmark
+       └── yc_bench_env.py
```

## Concrete Environments
@@ -324,7 +328,7 @@ For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
| `distribution` | Probabilistic toolset distribution name | `None` |
| `max_agent_turns` | Max LLM calls per rollout | `30` |
| `agent_temperature` | Sampling temperature | `1.0` |
-| `terminal_backend` | `local`, `docker`, `modal`, `ssh`, `singularity` | `local` |
+| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
| `system_prompt` | System message for the agent | `None` |
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |

@@ -18,9 +18,14 @@ Benchmarks (eval-only):
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
"""

-from environments.agent_loop import AgentResult, HermesAgentLoop
-from environments.tool_context import ToolContext
-from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
+try:
+    from environments.agent_loop import AgentResult, HermesAgentLoop
+    from environments.tool_context import ToolContext
+    from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
+except ImportError:
+    # atroposlib not installed — environments are unavailable but
+    # submodules like tool_call_parsers can still be imported directly.
+    pass

__all__ = [
    "AgentResult",
@@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Set
from model_tools import handle_function_call

# Thread pool for running sync tool calls that internally use asyncio.run()
-# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
+# (e.g., mini-swe-agent's modal/docker/daytona backends). Running them in a separate
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
@@ -249,23 +249,62 @@ class HermesAgentLoop:
        reasoning = _extract_reasoning_from_message(assistant_msg)
        reasoning_per_turn.append(reasoning)

-       # Check for tool calls -- standard OpenAI spec
+       # Check for tool calls -- standard OpenAI spec.
+       # Fallback: if response has no structured tool_calls but content
+       # contains raw tool call tags (e.g. <tool_call>), parse them using
+       # hermes-agent's standalone parsers. This handles the case where
+       # ManagedServer's ToolCallTranslator couldn't parse because vLLM
+       # isn't installed.
+       if (
+           not assistant_msg.tool_calls
+           and assistant_msg.content
+           and self.tool_schemas
+           and "<tool_call>" in (assistant_msg.content or "")
+       ):
+           try:
+               from environments.tool_call_parsers import get_parser
+               fallback_parser = get_parser("hermes")
+               parsed_content, parsed_calls = fallback_parser.parse(
+                   assistant_msg.content
+               )
+               if parsed_calls:
+                   assistant_msg.tool_calls = parsed_calls
+                   if parsed_content is not None:
+                       assistant_msg.content = parsed_content
+                   logger.debug(
+                       "Fallback parser extracted %d tool calls from raw content",
+                       len(parsed_calls),
+                   )
+           except Exception:
+               pass  # Fall through to no tool calls

        if assistant_msg.tool_calls:
+           # Normalize tool calls to dicts — they may come as objects
+           # (OpenAI API) or dicts (vLLM ToolCallTranslator).
+           def _tc_to_dict(tc):
+               if isinstance(tc, dict):
+                   return {
+                       "id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
+                       "type": "function",
+                       "function": {
+                           "name": tc.get("function", {}).get("name", tc.get("name", "")),
+                           "arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
+                       },
+                   }
+               return {
+                   "id": tc.id,
+                   "type": "function",
+                   "function": {
+                       "name": tc.function.name,
+                       "arguments": tc.function.arguments,
+                   },
+               }

            # Build the assistant message dict for conversation history
            msg_dict: Dict[str, Any] = {
                "role": "assistant",
                "content": assistant_msg.content or "",
-               "tool_calls": [
-                   {
-                       "id": tc.id,
-                       "type": "function",
-                       "function": {
-                           "name": tc.function.name,
-                           "arguments": tc.function.arguments,
-                       },
-                   }
-                   for tc in assistant_msg.tool_calls
-               ],
+               "tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
            }

            # Preserve reasoning_content for multi-turn chat template handling
@@ -278,8 +317,13 @@ class HermesAgentLoop:

            # Execute each tool call via hermes-agent's dispatch
            for tc in assistant_msg.tool_calls:
-               tool_name = tc.function.name
-               tool_args_raw = tc.function.arguments
+               # Handle both object (OpenAI) and dict (vLLM) formats
+               if isinstance(tc, dict):
+                   tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
+                   tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
+               else:
+                   tool_name = tc.function.name
+                   tool_args_raw = tc.function.arguments

                # Validate tool name
                if tool_name not in self.valid_tool_names:
@@ -336,7 +380,7 @@ class HermesAgentLoop:
                    tool_elapsed = _time.monotonic() - tool_submit_time
                else:
                    # Run tool calls in a thread pool so backends that
-                   # use asyncio.run() internally (modal, docker) get
+                   # use asyncio.run() internally (modal, docker, daytona) get
                    # a clean event loop instead of deadlocking.
                    loop = asyncio.get_event_loop()
                    # Capture current tool_name/args for the lambda
@@ -390,10 +434,11 @@ class HermesAgentLoop:
                    pass

                # Add tool response to conversation
+               tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
                messages.append(
                    {
                        "role": "tool",
-                       "tool_call_id": tc.id,
+                       "tool_call_id": tc_id,
                        "content": tool_result,
                    }
                )
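For context, the "hermes" fallback parsing the diff above relies on amounts to pulling `<tool_call>{...}</tool_call>` JSON blobs out of raw text. A hypothetical illustration (the real parser lives in `environments/tool_call_parsers` and may differ):

```python
import json
import re
import uuid

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_raw_tool_calls(content: str):
    # Returns (cleaned_content, tool_calls) shaped like OpenAI tool_calls dicts.
    calls = []
    for blob in TOOL_CALL_RE.findall(content):
        try:
            payload = json.loads(blob)
        except json.JSONDecodeError:
            continue  # skip malformed blobs instead of failing the turn
        calls.append({
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": payload.get("name", ""),
                "arguments": json.dumps(payload.get("arguments", {})),
            },
        })
    cleaned = TOOL_CALL_RE.sub("", content).strip() or None
    return cleaned, calls
```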
environments/benchmarks/tblite/local.yaml (new file, 38 lines)
@@ -0,0 +1,38 @@
# OpenThoughts-TBLite Evaluation -- Docker Backend (Local Compute)
#
# Runs tasks in Docker containers on the local machine.
# Sandboxed like Modal but no cloud costs. Good for dev/testing.
#
# Usage:
#   python environments/benchmarks/tblite/tblite_env.py evaluate \
#     --config environments/benchmarks/tblite/local.yaml
#
#   # Override concurrency:
#   python environments/benchmarks/tblite/tblite_env.py evaluate \
#     --config environments/benchmarks/tblite/local.yaml \
#     --env.eval_concurrency 4

env:
  enabled_toolsets: ["terminal", "file"]
  max_agent_turns: 60
  max_token_length: 32000
  agent_temperature: 0.8
  terminal_backend: "docker"
  terminal_timeout: 300
  tool_pool_size: 16
  dataset_name: "NousResearch/openthoughts-tblite"
  test_timeout: 600
  task_timeout: 1200
  eval_concurrency: 8  # max 8 tasks at once
  tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
  use_wandb: false
  wandb_name: "openthoughts-tblite-local"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite-local"

openai:
  base_url: "https://openrouter.ai/api/v1"
  model_name: "anthropic/claude-sonnet-4"
  server_type: "openai"
  health_check: false
  # api_key loaded from OPENROUTER_API_KEY in .env
environments/benchmarks/tblite/local_vllm.yaml (new file, 40 lines)
@@ -0,0 +1,40 @@
# OpenThoughts-TBLite Evaluation -- Local vLLM Backend
#
# Runs against a local vLLM server with Docker sandboxes.
#
# Start the vLLM server from the atropos directory:
#   python -m example_trainer.vllm_api_server \
#     --model Qwen/Qwen3-4B-Instruct-2507 \
#     --port 9001 \
#     --gpu-memory-utilization 0.8 \
#     --max-model-len=32000
#
# Then run:
#   python environments/benchmarks/tblite/tblite_env.py evaluate \
#     --config environments/benchmarks/tblite/local_vllm.yaml

env:
  enabled_toolsets: ["terminal", "file"]
  max_agent_turns: 60
  max_token_length: 16000
  agent_temperature: 0.6
  terminal_backend: "docker"
  terminal_timeout: 300
  tool_pool_size: 16
  dataset_name: "NousResearch/openthoughts-tblite"
  test_timeout: 600
  task_timeout: 1200
  eval_concurrency: 8
  tool_call_parser: "hermes"
  system_prompt: "You are an expert terminal agent. You MUST use the provided tools to complete tasks. Use the terminal tool to run shell commands, read_file to read files, write_file to write files, search_files to search, and patch to edit files. Do NOT write out solutions as text - execute them using the tools. Always start by exploring the environment with terminal commands."
  tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
  use_wandb: false
  wandb_name: "tblite-qwen3-4b-instruct"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/tblite-qwen3-4b-local"

openai:
  base_url: "http://localhost:9001"
  model_name: "Qwen/Qwen3-4B-Instruct-2507"
  server_type: "vllm"
  health_check: false
@@ -118,6 +118,14 @@ class TerminalBench2EvalConfig(HermesAgentEnvConfig):
        "Tasks exceeding this are scored as FAIL. Default 30 minutes.",
    )

+   # --- Eval concurrency ---
+   eval_concurrency: int = Field(
+       default=0,
+       description="Maximum number of tasks to evaluate in parallel. "
+       "0 means unlimited (all tasks run concurrently). "
+       "Set to 8 for local backends to avoid overwhelming the machine.",
+   )

# Tasks that cannot run properly on Modal and are excluded from scoring.
MODAL_INCOMPATIBLE_TASKS = {
@@ -429,8 +437,13 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
                "error": "no_image",
            }

-       # --- 2. Register per-task Modal image override ---
-       register_task_env_overrides(task_id, {"modal_image": modal_image})
+       # --- 2. Register per-task image override ---
+       # Set both modal_image and docker_image so the task image is used
+       # regardless of which backend is configured.
+       register_task_env_overrides(task_id, {
+           "modal_image": modal_image,
+           "docker_image": modal_image,
+       })
        logger.info(
            "Task %s: registered image override for task_id %s",
            task_name, task_id[:8],
@@ -445,17 +458,37 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
        messages.append({"role": "user", "content": self.format_prompt(eval_item)})

        # --- 4. Run agent loop ---
-       agent = HermesAgentLoop(
-           server=self.server,
-           tool_schemas=tools,
-           valid_tool_names=valid_names,
-           max_turns=self.config.max_agent_turns,
-           task_id=task_id,
-           temperature=self.config.agent_temperature,
-           max_tokens=self.config.max_token_length,
-           extra_body=self.config.extra_body,
-       )
-       result = await agent.run(messages)
+       # Use ManagedServer (Phase 2) for vLLM/SGLang backends to get
+       # token-level tracking via /generate. Falls back to direct
+       # ServerManager (Phase 1) for OpenAI endpoints.
+       if self._use_managed_server():
+           async with self.server.managed_server(
+               tokenizer=self.tokenizer,
+               preserve_think_blocks=bool(self.config.thinking_mode),
+           ) as managed:
+               agent = HermesAgentLoop(
+                   server=managed,
+                   tool_schemas=tools,
+                   valid_tool_names=valid_names,
+                   max_turns=self.config.max_agent_turns,
+                   task_id=task_id,
+                   temperature=self.config.agent_temperature,
+                   max_tokens=self.config.max_token_length,
+                   extra_body=self.config.extra_body,
+               )
+               result = await agent.run(messages)
+       else:
+           agent = HermesAgentLoop(
+               server=self.server,
+               tool_schemas=tools,
+               valid_tool_names=valid_names,
+               max_turns=self.config.max_agent_turns,
+               task_id=task_id,
+               temperature=self.config.agent_temperature,
+               max_tokens=self.config.max_token_length,
+               extra_body=self.config.extra_body,
+           )
+           result = await agent.run(messages)

        # --- 5. Verify -- run test suite in the agent's sandbox ---
        # Skip verification if the agent produced no meaningful output
@@ -655,13 +688,19 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):

    async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
        """
-       Wrap rollout_and_score_eval with a per-task wall-clock timeout.
+       Wrap rollout_and_score_eval with a per-task wall-clock timeout
+       and optional concurrency limit via semaphore.

        If the task exceeds task_timeout seconds, it's automatically scored
        as FAIL. This prevents any single task from hanging indefinitely.
        """
        task_name = item.get("task_name", "unknown")
        category = item.get("category", "unknown")

+       # Acquire concurrency semaphore if configured
+       if self._eval_semaphore:
+           await self._eval_semaphore.acquire()

        try:
            return await asyncio.wait_for(
                self.rollout_and_score_eval(item),
@@ -679,6 +718,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
            }
            self._save_result(out)
            return out
+       finally:
+           if self._eval_semaphore:
+               self._eval_semaphore.release()

    async def evaluate(self, *args, **kwargs) -> None:
        """
@@ -696,6 +738,13 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
        """
        start_time = time.time()

+       # Set up concurrency limit if configured
+       if self.config.eval_concurrency > 0:
+           self._eval_semaphore = asyncio.Semaphore(self.config.eval_concurrency)
+           print(f" Eval concurrency: {self.config.eval_concurrency} tasks at a time")
+       else:
+           self._eval_semaphore = None

        # Route all logging through tqdm.write() so the progress bar stays
        # pinned at the bottom while log lines scroll above it.
        from tqdm import tqdm
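The concurrency pattern the diff adds composes a semaphore (bounding in-flight tasks) with `asyncio.wait_for` (bounding each task's wall clock). A generic, self-contained illustration of the same shape, not the env's exact code:

```python
import asyncio

async def run_bounded(coros, limit: int = 8, timeout: float = 1200.0):
    # limit <= 0 means "unlimited", matching eval_concurrency=0 above.
    sem = asyncio.Semaphore(limit) if limit > 0 else None

    async def one(coro):
        if sem:
            await sem.acquire()
        try:
            return await asyncio.wait_for(coro, timeout=timeout)
        except asyncio.TimeoutError:
            return {"score": 0, "error": "task_timeout"}
        finally:
            if sem:
                sem.release()

    return await asyncio.gather(*(one(c) for c in coros))
```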
environments/benchmarks/yc_bench/README.md (new file, 115 lines)
@@ -0,0 +1,115 @@
# YC-Bench: Long-Horizon Agent Benchmark

[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.

Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.

## Setup

```bash
# Install yc-bench (optional dependency)
pip install "hermes-agent[yc-bench]"

# Or install from source
git clone https://github.com/collinear-ai/yc-bench
cd yc-bench && pip install -e .

# Verify
yc-bench --help
```

## Running

```bash
# From the repo root:
bash environments/benchmarks/yc_bench/run_eval.sh

# Or directly:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
  --config environments/benchmarks/yc_bench/default.yaml

# Override model:
bash environments/benchmarks/yc_bench/run_eval.sh \
  --openai.model_name anthropic/claude-opus-4-20250514

# Quick single-preset test:
bash environments/benchmarks/yc_bench/run_eval.sh \
  --env.presets '["fast_test"]' --env.seeds '[1]'
```

## How It Works

### Architecture

```
HermesAgentLoop (our agent)
  -> terminal tool -> subprocess("yc-bench company status") -> JSON output
  -> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
  -> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
  -> ... (100-500 turns per run)
```

The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.

### Simulation Mechanics

- **4 skill domains**: research, inference, data_environment, training
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
- **Throughput splitting**: `effective_rate = base_rate / N` for an employee with N active tasks
- **Financial pressure**: Monthly payroll, bankruptcy = game over
- **Deterministic**: SHA256-based RNG — same seed + preset = same world

### Difficulty Presets

| Preset     | Employees | Tasks | Focus |
|------------|-----------|-------|-------|
| tutorial   | 3         | 50    | Basic loop mechanics |
| easy       | 5         | 100   | Throughput awareness |
| **medium** | 5         | 150   | Prestige climbing + domain specialisation |
| **hard**   | 7         | 200   | Precise ETA reasoning |
| nightmare  | 8         | 300   | Sustained perfection under payroll pressure |
| fast_test  | (varies)  | (varies) | Quick validation (~50 turns) |

Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.

### Scoring

```
composite = 0.5 × survival + 0.5 × normalised_funds
```

- **Survival** (binary): Did the company avoid bankruptcy?
- **Normalised funds** (0.0-1.0): Log-scale relative to the initial $250K capital
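As a worked example of the log-scale normalisation (it matches `_compute_composite_score` in `yc_bench_env.py` below): a company that survives while merely preserving its initial capital scores about 0.57.

```python
import math

ratio = 1.0                                            # final funds == initial $250K
funds_score = math.log1p(ratio) / math.log1p(100.0)    # ≈ 0.15
composite = 0.5 * 1.0 + 0.5 * funds_score              # ≈ 0.57 for a survivor
```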
## Configuration

Key fields in `default.yaml`:

| Field | Default | Description |
|-------|---------|-------------|
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
| `max_agent_turns` | 200 | Max LLM calls per run |
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
| `survival_weight` | 0.5 | Weight of survival in composite score |
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
| `horizon_years` | null | Override horizon (null = auto from preset) |

## Cost & Time Estimates

Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:

| Preset | Turns | Time | Est. Cost |
|--------|-------|------|-----------|
| fast_test | ~50 | 5-10 min | $1-5 |
| medium | ~200 | 20-40 min | $5-15 |
| hard | ~300 | 30-60 min | $10-25 |

Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.

## References

- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
environments/benchmarks/yc_bench/__init__.py (new file, 0 lines)

environments/benchmarks/yc_bench/default.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
# YC-Bench Evaluation -- Default Configuration
#
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Usage:
#   python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
#     --config environments/benchmarks/yc_bench/default.yaml
#
#   # Override model:
#   python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
#     --config environments/benchmarks/yc_bench/default.yaml \
#     --openai.model_name anthropic/claude-opus-4-20250514

env:
  enabled_toolsets: ["terminal"]
  max_agent_turns: 200
  max_token_length: 32000
  agent_temperature: 0.0
  terminal_backend: "local"
  terminal_timeout: 60
  presets: ["fast_test", "medium", "hard"]
  seeds: [1, 2, 3]
  run_timeout: 3600  # 60 min wall-clock per run, auto-FAIL if exceeded
  survival_weight: 0.5  # weight of binary survival in composite score
  funds_weight: 0.5  # weight of normalised final funds in composite score
  db_dir: "/tmp/yc_bench_dbs"
  company_name: "BenchCo"
  start_date: "01/01/2025"  # MM/DD/YYYY (yc-bench convention)
  tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
  use_wandb: true
  wandb_name: "yc-bench"
  ensure_scores_are_not_same: false
  data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"

openai:
  base_url: "https://openrouter.ai/api/v1"
  model_name: "anthropic/claude-sonnet-4.6"
  server_type: "openai"
  health_check: false
  # api_key loaded from OPENROUTER_API_KEY in .env
environments/benchmarks/yc_bench/run_eval.sh (new executable file, 34 lines)
@@ -0,0 +1,34 @@
#!/bin/bash

# YC-Bench Evaluation
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Run from repo root:
#   bash environments/benchmarks/yc_bench/run_eval.sh
#
# Override model:
#   bash environments/benchmarks/yc_bench/run_eval.sh \
#     --openai.model_name anthropic/claude-opus-4-20250514
#
# Run a single preset:
#   bash environments/benchmarks/yc_bench/run_eval.sh \
#     --env.presets '["fast_test"]' --env.seeds '[1]'

set -euo pipefail

mkdir -p logs evals/yc-bench
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"

echo "YC-Bench Evaluation"
echo "Log: $LOG_FILE"
echo ""

PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
  python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
  --config environments/benchmarks/yc_bench/default.yaml \
  "$@" \
  2>&1 | tee "$LOG_FILE"

echo ""
echo "Log saved to: $LOG_FILE"
environments/benchmarks/yc_bench/yc_bench_env.py (new file, 847 lines)
@@ -0,0 +1,847 @@
|
||||
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install hermes-agent[yc-bench]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
|
||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
||||
business simulation. You manage the company exclusively through the `yc-bench`
|
||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
||||
without going bankrupt, while **maximising final funds**.
|
||||
|
||||
## Simulation Mechanics
|
||||
|
||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
||||
**data_environment**, and **training**. Each has its own prestige level
|
||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
||||
Running out of funds = bankruptcy = game over.
|
||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
||||
Time only advances when you call `yc-bench sim resume`.
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
1. Browse market tasks with `market browse`
|
||||
2. Accept a task with `task accept` (this sets its deadline)
|
||||
3. Assign employees with `task assign`
|
||||
4. Dispatch with `task dispatch` to start work
|
||||
5. Call `sim resume` to advance time and let employees make progress
|
||||
6. Tasks complete when all domain requirements are fulfilled
|
||||
|
||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
||||
task fail, so cancel only as a last resort.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### Observe
|
||||
- `yc-bench company status` -- funds, prestige, runway
|
||||
- `yc-bench employee list` -- skills, salary, active tasks
|
||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
||||
- `yc-bench report monthly` -- monthly P&L
|
||||
|
||||
### Act
|
||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
||||
- `yc-bench sim resume` -- advance simulation clock
|
||||
|
||||
### Memory (persists across context truncation)
|
||||
- `yc-bench scratchpad read` -- read your persistent notes
|
||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
||||
- `yc-bench scratchpad clear` -- clear notes
|
||||
|
||||
## Strategy Guidelines
|
||||
|
||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
||||
2. **Focus employees** -- assigning one employee to many tasks halves their
|
||||
throughput per additional task. Keep assignments concentrated.
|
||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
||||
employee assignments. This persists even if conversation context is truncated.
|
||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
||||
Accept high-reward tasks before payroll dates.
|
||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
||||
into prestige loss, locking you out of profitable contracts.
|
||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
||||
|
||||
## Your Turn
|
||||
|
||||
Each turn:
|
||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
||||
2. Check for completed tasks and pending deadlines.
|
||||
3. Browse market for profitable tasks within your prestige level.
|
||||
4. Accept, assign, and dispatch tasks strategically.
|
||||
5. Call `yc-bench sim resume` to advance time.
|
||||
6. Repeat until the simulation ends.
|
||||
|
||||
Think step by step before acting."""

# Starting funds in cents ($250,000)
INITIAL_FUNDS_CENTS = 25_000_000

# Default horizon per preset (years)
_PRESET_HORIZONS = {
    "tutorial": 1,
    "easy": 1,
    "medium": 1,
    "hard": 1,
    "nightmare": 1,
    "fast_test": 1,
    "default": 3,
    "high_reward": 1,
}


# =============================================================================
# Configuration
# =============================================================================

class YCBenchEvalConfig(HermesAgentEnvConfig):
    """
    Configuration for the YC-Bench evaluation environment.

    Extends HermesAgentEnvConfig with YC-Bench-specific settings for
    preset selection, seed control, scoring, and simulation parameters.
    """

    presets: List[str] = Field(
        default=["fast_test", "medium", "hard"],
        description="YC-Bench preset names to evaluate.",
    )
    seeds: List[int] = Field(
        default=[1, 2, 3],
        description="Random seeds -- each preset x seed = one run.",
    )
    run_timeout: int = Field(
        default=3600,
        description="Maximum wall-clock seconds per run. Default 60 minutes.",
    )
    survival_weight: float = Field(
        default=0.5,
        description="Weight of survival (0/1) in composite score.",
    )
    funds_weight: float = Field(
        default=0.5,
        description="Weight of normalised final funds in composite score.",
    )
    db_dir: str = Field(
        default="/tmp/yc_bench_dbs",
        description="Directory for per-run SQLite databases.",
    )
    horizon_years: Optional[int] = Field(
        default=None,
        description=(
            "Simulation horizon in years. If None (default), inferred from "
            "preset name (1 year for most, 3 for 'default')."
        ),
    )
    company_name: str = Field(
        default="BenchCo",
        description="Name of the simulated company.",
    )
    start_date: str = Field(
        default="01/01/2025",
        description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
    )


# =============================================================================
# Scoring helpers
# =============================================================================

def _read_final_score(db_path: str) -> Dict[str, Any]:
    """
    Read final game state from a YC-Bench SQLite database.

    Returns dict with final_funds_cents (int), survived (bool),
    terminal_reason (str).

    Note: yc-bench table names are plural -- 'companies' not 'company',
    'sim_events' not 'simulation_log'.
    """
    if not os.path.exists(db_path):
        logger.warning("DB not found at %s", db_path)
        return {
            "final_funds_cents": 0,
            "survived": False,
            "terminal_reason": "db_missing",
        }

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()

        # Read final funds from the 'companies' table
        cur.execute("SELECT funds_cents FROM companies LIMIT 1")
        row = cur.fetchone()
        funds = row[0] if row else 0

        # Determine terminal reason from 'sim_events' table
        terminal_reason = "unknown"
        try:
            cur.execute(
                "SELECT event_type FROM sim_events "
                "WHERE event_type IN ('bankruptcy', 'horizon_end') "
                "ORDER BY scheduled_at DESC LIMIT 1"
            )
            event_row = cur.fetchone()
            if event_row:
                terminal_reason = event_row[0]
        except sqlite3.OperationalError:
            # Table may not exist if simulation didn't progress
            pass

        survived = funds >= 0 and terminal_reason != "bankruptcy"
        return {
            "final_funds_cents": funds,
            "survived": survived,
            "terminal_reason": terminal_reason,
        }

    except Exception as e:
        logger.error("Failed to read DB %s: %s", db_path, e)
        return {
            "final_funds_cents": 0,
            "survived": False,
            "terminal_reason": f"db_error: {e}",
        }
    finally:
        if conn:
            conn.close()


def _compute_composite_score(
    final_funds_cents: int,
    survived: bool,
    survival_weight: float = 0.5,
    funds_weight: float = 0.5,
    initial_funds_cents: int = INITIAL_FUNDS_CENTS,
) -> float:
    """
    Compute composite score from survival and final funds.

    Score = survival_weight * survival_score
          + funds_weight * normalised_funds_score

    Normalised funds uses log-scale relative to initial capital:
      - funds <= 0: 0.0
      - funds == initial: ~0.15
      - funds == 10x: ~0.52
      - funds == 100x: 1.0
    """
    survival_score = 1.0 if survived else 0.0

    if final_funds_cents <= 0:
        funds_score = 0.0
    else:
        max_ratio = 100.0
        ratio = final_funds_cents / max(initial_funds_cents, 1)
        funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)

    return survival_weight * survival_score + funds_weight * funds_score
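
# Illustrative sanity check of the scoring curve above (a sketch, not used by
# the environment; values assume the default 50/50 weights and the $250k
# INITIAL_FUNDS_CENTS defined earlier):
def _example_scores() -> None:
    # Survived with exactly the initial capital: 0.5 + 0.5 * ~0.15 ~= 0.575
    assert abs(_compute_composite_score(INITIAL_FUNDS_CENTS, True) - 0.575) < 0.01
    # 10x the initial capital: 0.5 + 0.5 * ~0.52 ~= 0.76
    assert abs(_compute_composite_score(10 * INITIAL_FUNDS_CENTS, True) - 0.76) < 0.01
    # 100x caps the normalised funds term at 1.0, so the composite is exactly 1.0
    assert _compute_composite_score(100 * INITIAL_FUNDS_CENTS, True) == 1.0
    # Bankruptcy zeroes both terms
    assert _compute_composite_score(-500_000, False) == 0.0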


# =============================================================================
# Main Environment
# =============================================================================

class YCBenchEvalEnv(HermesAgentBaseEnv):
    """
    YC-Bench long-horizon agent benchmark environment (eval-only).

    Each eval item is a (preset, seed) pair. The environment initialises the
    simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
    a competing built-in agent loop). The HermesAgentLoop then drives the
    interaction by calling individual yc-bench CLI commands via the terminal tool.

    After the agent loop ends, the SQLite DB is read to extract the final score.

    Scoring:
        composite = 0.5 * survival + 0.5 * normalised_funds
    """

    name = "yc-bench"
    env_config_cls = YCBenchEvalConfig

    @classmethod
    def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
        env_config = YCBenchEvalConfig(
            enabled_toolsets=["terminal"],
            disabled_toolsets=None,
            distribution=None,
            max_agent_turns=200,
            max_token_length=32000,
            agent_temperature=0.0,
            system_prompt=YC_BENCH_SYSTEM_PROMPT,
            terminal_backend="local",
            terminal_timeout=60,
            presets=["fast_test", "medium", "hard"],
            seeds=[1, 2, 3],
            run_timeout=3600,
            survival_weight=0.5,
            funds_weight=0.5,
            db_dir="/tmp/yc_bench_dbs",
            eval_handling=EvalHandlingEnum.STOP_TRAIN,
            group_size=1,
            steps_per_eval=1,
            total_steps=1,
            tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
            use_wandb=True,
            wandb_name="yc-bench",
            ensure_scores_are_not_same=False,
        )

        server_configs = [
            APIServerConfig(
                base_url="https://openrouter.ai/api/v1",
                model_name="anthropic/claude-sonnet-4.6",
                server_type="openai",
                api_key=os.getenv("OPENROUTER_API_KEY", ""),
                health_check=False,
            )
        ]

        return env_config, server_configs

    # =========================================================================
    # Setup
    # =========================================================================

    async def setup(self):
        """Verify yc-bench is installed and build the eval matrix."""
        # Verify yc-bench CLI is available
        try:
            result = subprocess.run(
                ["yc-bench", "--help"], capture_output=True, text=True, timeout=10
            )
            if result.returncode != 0:
                raise FileNotFoundError
        except (FileNotFoundError, subprocess.TimeoutExpired):
            raise RuntimeError(
                "yc-bench CLI not found. Install with:\n"
                '  pip install "hermes-agent[yc-bench]"\n'
                "Or: git clone https://github.com/collinear-ai/yc-bench "
                "&& cd yc-bench && pip install -e ."
            )
        print("yc-bench CLI verified.")

        # Build eval matrix: preset x seed
        self.all_eval_items = [
            {"preset": preset, "seed": seed}
            for preset in self.config.presets
            for seed in self.config.seeds
        ]
        self.iter = 0

        os.makedirs(self.config.db_dir, exist_ok=True)
        self.eval_metrics: List[Tuple[str, float]] = []

        # Streaming JSONL log for crash-safe result persistence
        log_dir = os.path.join(os.path.dirname(__file__), "logs")
        os.makedirs(log_dir, exist_ok=True)
        run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
        self._streaming_file = open(self._streaming_path, "w")
        self._streaming_lock = threading.Lock()

        print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
        for item in self.all_eval_items:
            print(f"  preset={item['preset']!r} seed={item['seed']}")
        print(f"Streaming results to: {self._streaming_path}\n")

    def _save_result(self, result: Dict[str, Any]):
        """Write a single run result to the streaming JSONL file immediately."""
        if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
            return
        with self._streaming_lock:
            self._streaming_file.write(
                json.dumps(result, ensure_ascii=False, default=str) + "\n"
            )
            self._streaming_file.flush()

    # =========================================================================
    # Training pipeline stubs (eval-only -- not used)
    # =========================================================================

    async def get_next_item(self):
        item = self.all_eval_items[self.iter % len(self.all_eval_items)]
        self.iter += 1
        return item

    def format_prompt(self, item: Dict[str, Any]) -> str:
        preset = item["preset"]
        seed = item["seed"]
        return (
            f"A new YC-Bench simulation has been initialized "
            f"(preset='{preset}', seed={seed}).\n"
            f"Your company '{self.config.company_name}' is ready.\n\n"
            "Begin by calling:\n"
            "1. `yc-bench company status` -- see your starting funds and prestige\n"
            "2. `yc-bench employee list` -- see your team and their skills\n"
            "3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
            "you can take\n\n"
            "Then accept 2-3 tasks, assign employees, dispatch them, and call "
            "`yc-bench sim resume` to advance time. Repeat this loop until the "
            "simulation ends (horizon reached or bankruptcy)."
        )

    async def compute_reward(self, item, result, ctx) -> float:
        return 0.0

    async def collect_trajectories(self, item):
        return None, []

    async def score(self, rollout_group_data):
        return None

    # =========================================================================
    # Per-run evaluation
    # =========================================================================

    async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
        """
        Evaluate a single (preset, seed) run.

        1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
        2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
        3. Runs HermesAgentLoop with terminal tool
        4. Reads SQLite DB to compute final score
        5. Returns result dict with survival, funds, and composite score
        """
        preset = eval_item["preset"]
        seed = eval_item["seed"]
        run_id = str(uuid.uuid4())[:8]
        run_key = f"{preset}_seed{seed}_{run_id}"

        from tqdm import tqdm
        tqdm.write(f"  [START] preset={preset!r} seed={seed} (run_id={run_id})")
        run_start = time.time()

        # Isolated DB per run -- prevents cross-run state leakage
        db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
        os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
        os.environ["YC_BENCH_EXPERIMENT"] = preset

        # Determine horizon: explicit config override > preset lookup > default 1
        horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)

        try:
            # ----------------------------------------------------------
            # Step 1: Initialise the simulation via CLI
            # IMPORTANT: We use `sim init`, NOT `yc-bench run`.
            # `yc-bench run` starts yc-bench's own LLM agent loop (via
            # LiteLLM), which would compete with our HermesAgentLoop.
            # `sim init` just sets up the world and returns.
            # ----------------------------------------------------------
            init_cmd = [
                "yc-bench", "sim", "init",
                "--seed", str(seed),
                "--start-date", self.config.start_date,
                "--company-name", self.config.company_name,
                "--horizon-years", str(horizon),
            ]
            init_result = subprocess.run(
                init_cmd, capture_output=True, text=True, timeout=30,
            )
            if init_result.returncode != 0:
                error_msg = (init_result.stderr or init_result.stdout).strip()
                raise RuntimeError(f"yc-bench sim init failed: {error_msg}")

            tqdm.write(f"  Simulation initialized (horizon={horizon}yr)")

            # ----------------------------------------------------------
            # Step 2: Run the HermesAgentLoop
            # ----------------------------------------------------------
            tools, valid_names = self._resolve_tools_for_group()

            messages: List[Dict[str, Any]] = [
                {"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
                {"role": "user", "content": self.format_prompt(eval_item)},
            ]

            agent = HermesAgentLoop(
                server=self.server,
                tool_schemas=tools,
                valid_tool_names=valid_names,
                max_turns=self.config.max_agent_turns,
                task_id=run_id,
                temperature=self.config.agent_temperature,
                max_tokens=self.config.max_token_length,
                extra_body=self.config.extra_body,
            )
            result = await agent.run(messages)

            # ----------------------------------------------------------
            # Step 3: Read final score from the simulation DB
            # ----------------------------------------------------------
            score_data = _read_final_score(db_path)
            final_funds = score_data["final_funds_cents"]
            survived = score_data["survived"]
            terminal_reason = score_data["terminal_reason"]

            composite = _compute_composite_score(
                final_funds_cents=final_funds,
                survived=survived,
                survival_weight=self.config.survival_weight,
                funds_weight=self.config.funds_weight,
            )

            elapsed = time.time() - run_start
            status = "SURVIVED" if survived else "BANKRUPT"
            if final_funds >= 0:
                funds_str = f"${final_funds / 100:,.0f}"
            else:
                funds_str = f"-${abs(final_funds) / 100:,.0f}"

            tqdm.write(
                f"  [{status}] preset={preset!r} seed={seed} "
                f"funds={funds_str} score={composite:.3f} "
                f"turns={result.turns_used} ({elapsed:.0f}s)"
            )

            out = {
                "preset": preset,
                "seed": seed,
                "survived": survived,
                "final_funds_cents": final_funds,
                "final_funds_usd": final_funds / 100,
                "terminal_reason": terminal_reason,
                "composite_score": composite,
                "turns_used": result.turns_used,
                "finished_naturally": result.finished_naturally,
                "elapsed_seconds": elapsed,
                "db_path": db_path,
                "messages": result.messages,
            }
            self._save_result(out)
            return out

        except Exception as e:
            elapsed = time.time() - run_start
            logger.error("Run %s failed: %s", run_key, e, exc_info=True)
            tqdm.write(
                f"  [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
            )
            out = {
                "preset": preset,
                "seed": seed,
                "survived": False,
                "final_funds_cents": 0,
                "final_funds_usd": 0.0,
                "terminal_reason": f"error: {e}",
                "composite_score": 0.0,
                "turns_used": 0,
                "error": str(e),
                "elapsed_seconds": elapsed,
            }
            self._save_result(out)
            return out

    # =========================================================================
    # Evaluate
    # =========================================================================

    async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
        """Wrap a single rollout with a wall-clock timeout."""
        preset = item["preset"]
        seed = item["seed"]
        try:
            return await asyncio.wait_for(
                self.rollout_and_score_eval(item),
                timeout=self.config.run_timeout,
            )
        except asyncio.TimeoutError:
            from tqdm import tqdm
            tqdm.write(
                f"  [TIMEOUT] preset={preset!r} seed={seed} "
                f"(exceeded {self.config.run_timeout}s)"
            )
            out = {
                "preset": preset,
                "seed": seed,
                "survived": False,
                "final_funds_cents": 0,
                "final_funds_usd": 0.0,
                "terminal_reason": f"timeout ({self.config.run_timeout}s)",
                "composite_score": 0.0,
                "turns_used": 0,
                "error": "timeout",
            }
            self._save_result(out)
            return out

    async def evaluate(self, *args, **kwargs) -> None:
        """
        Run YC-Bench evaluation over all (preset, seed) combinations.

        Runs sequentially -- each run is 100-500 turns, so parallelising
        would be prohibitively expensive and would cause env var conflicts
        (DATABASE_URL is process-global).
        """
        start_time = time.time()
        from tqdm import tqdm

        # --- tqdm-compatible logging handler (TB2 pattern) ---
        class _TqdmHandler(logging.Handler):
            def emit(self, record):
                try:
                    tqdm.write(self.format(record))
                except Exception:
                    self.handleError(record)

        root = logging.getLogger()
        handler = _TqdmHandler()
        handler.setFormatter(
            logging.Formatter("%(levelname)s %(name)s: %(message)s")
        )
        root.handlers = [handler]
        for noisy in ("httpx", "openai"):
            logging.getLogger(noisy).setLevel(logging.WARNING)

        # --- Print config summary ---
        print(f"\n{'='*60}")
        print("Starting YC-Bench Evaluation")
        print(f"{'='*60}")
        print(f"  Presets: {self.config.presets}")
        print(f"  Seeds: {self.config.seeds}")
        print(f"  Total runs: {len(self.all_eval_items)}")
        print(f"  Max turns/run: {self.config.max_agent_turns}")
        print(f"  Run timeout: {self.config.run_timeout}s")
        print(f"{'='*60}\n")

        results = []
        pbar = tqdm(
            total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
        )

        try:
            for item in self.all_eval_items:
                result = await self._run_with_timeout(item)
                results.append(result)
                survived_count = sum(1 for r in results if r.get("survived"))
                pbar.set_postfix_str(
                    f"survived={survived_count}/{len(results)}"
                )
                pbar.update(1)

        except (KeyboardInterrupt, asyncio.CancelledError):
            tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
            pbar.close()
            try:
                from tools.terminal_tool import cleanup_all_environments
                cleanup_all_environments()
            except Exception:
                pass
            if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
                self._streaming_file.close()
            return

        pbar.close()
        end_time = time.time()

        # --- Compute metrics ---
        valid = [r for r in results if r is not None]
        if not valid:
            print("Warning: No valid results.")
            return

        total = len(valid)
        survived_total = sum(1 for r in valid if r.get("survived"))
        survival_rate = survived_total / total if total else 0.0
        avg_score = (
            sum(r.get("composite_score", 0) for r in valid) / total
            if total
            else 0.0
        )

        preset_results: Dict[str, List[Dict]] = defaultdict(list)
        for r in valid:
            preset_results[r["preset"]].append(r)

        eval_metrics = {
            "eval/survival_rate": survival_rate,
            "eval/avg_composite_score": avg_score,
            "eval/total_runs": total,
            "eval/survived_runs": survived_total,
            "eval/evaluation_time_seconds": end_time - start_time,
        }

        for preset, items in sorted(preset_results.items()):
            ps = sum(1 for r in items if r.get("survived"))
            pt = len(items)
            pa = (
                sum(r.get("composite_score", 0) for r in items) / pt
                if pt
                else 0
            )
            key = preset.replace("-", "_")
            eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
            eval_metrics[f"eval/avg_score_{key}"] = pa

        self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]

        # --- Print summary ---
        print(f"\n{'='*60}")
        print("YC-Bench Evaluation Results")
        print(f"{'='*60}")
        print(
            f"Overall survival rate: {survival_rate:.1%} "
            f"({survived_total}/{total})"
        )
        print(f"Average composite score: {avg_score:.4f}")
        print(f"Evaluation time: {end_time - start_time:.1f}s")

        print("\nPer-preset breakdown:")
        for preset, items in sorted(preset_results.items()):
            ps = sum(1 for r in items if r.get("survived"))
            pt = len(items)
            pa = (
                sum(r.get("composite_score", 0) for r in items) / pt
                if pt
                else 0
            )
            print(f"  {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
            for r in items:
                status = "SURVIVED" if r.get("survived") else "BANKRUPT"
                funds = r.get("final_funds_usd", 0)
                print(
                    f"    seed={r['seed']} [{status}] "
                    f"${funds:,.0f} "
                    f"score={r.get('composite_score', 0):.3f}"
                )

        print(f"{'='*60}\n")

        # --- Log results ---
        samples = [
            {k: v for k, v in r.items() if k != "messages"} for r in valid
        ]

        try:
            await self.evaluate_log(
                metrics=eval_metrics,
                samples=samples,
                start_time=start_time,
                end_time=end_time,
                generation_parameters={
                    "temperature": self.config.agent_temperature,
                    "max_tokens": self.config.max_token_length,
                    "max_agent_turns": self.config.max_agent_turns,
                },
            )
        except Exception as e:
            print(f"Error logging results: {e}")

        # --- Cleanup (TB2 pattern) ---
        if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
            self._streaming_file.close()
            print(f"Results saved to: {self._streaming_path}")

        try:
            from tools.terminal_tool import cleanup_all_environments
            cleanup_all_environments()
        except Exception:
            pass

        try:
            from environments.agent_loop import _tool_executor
            _tool_executor.shutdown(wait=False, cancel_futures=True)
        except Exception:
            pass

    # =========================================================================
    # Wandb logging
    # =========================================================================

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log YC-Bench-specific metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}
        for k, v in self.eval_metrics:
            wandb_metrics[k] = v
        self.eval_metrics = []
        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    YCBenchEvalEnv.cli()

@@ -114,8 +114,8 @@ class HermesAgentEnvConfig(BaseEnvConfig):
    # --- Terminal backend ---
    terminal_backend: str = Field(
        default="local",
        description="Terminal backend: 'local', 'docker', 'modal', 'ssh', 'singularity'. "
        "Modal recommended for production RL (cloud isolation per rollout).",
        description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
        "Modal or Daytona recommended for production RL (cloud isolation per rollout).",
    )
    terminal_timeout: int = Field(
        default=120,

@@ -229,6 +229,12 @@ class HermesAgentBaseEnv(BaseEnv):
        from environments.agent_loop import resize_tool_pool
        resize_tool_pool(config.tool_pool_size)

        # Set tool_parser on the ServerManager so ManagedServer uses it
        # for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
        if hasattr(self.server, 'tool_parser'):
            self.server.tool_parser = config.tool_call_parser
            print(f"🔧 Tool parser: {config.tool_call_parser}")

        # Current group's resolved tools (set in collect_trajectories)
        self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None

@@ -466,22 +472,14 @@ class HermesAgentBaseEnv(BaseEnv):
        # Run the agent loop
        result: AgentResult
        if self._use_managed_server():
            # Phase 2: ManagedServer with parser -- exact tokens + logprobs
            # Load the tool call parser from registry based on config
            from environments.tool_call_parsers import get_parser
            try:
                tc_parser = get_parser(self.config.tool_call_parser)
            except KeyError:
                logger.warning(
                    "Tool call parser '%s' not found, falling back to 'hermes'",
                    self.config.tool_call_parser,
                )
                tc_parser = get_parser("hermes")

            # Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs
            # tool_parser is set on ServerManager in __init__ and passed through
            # to ManagedServer, which uses ToolCallTranslator for bidirectional
            # translation between raw text and OpenAI tool_calls.
            try:
                async with self.server.managed_server(
                    tokenizer=self.tokenizer,
                    tool_call_parser=tc_parser,
                    preserve_think_blocks=bool(self.config.thinking_mode),
                ) as managed:
                    agent = HermesAgentLoop(
                        server=managed,

@@ -114,11 +114,27 @@ def _patch_swerex_modal():
        self._worker = _AsyncWorker()
        self._worker.start()

        # Pre-build a modal.Image with pip fix for Modal's legacy image builder.
        # Modal requires `python -m pip` to work during image build, but some
        # task images (e.g., TBLite's broken-python) have intentionally broken pip.
        # Fix: remove stale pip dist-info and reinstall via ensurepip before Modal
        # tries to use it. This is a no-op for images where pip already works.
        import modal as _modal
        image_spec = self.config.image
        if isinstance(image_spec, str):
            image_spec = _modal.Image.from_registry(
                image_spec,
                setup_dockerfile_commands=[
                    "RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
                    "python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
                ],
            )

        # Create AND start the deployment entirely on the worker's loop/thread
        # so all gRPC channels and async state are bound to that loop
        async def _create_and_start():
            deployment = ModalDeployment(
                image=self.config.image,
                image=image_spec,
                startup_timeout=self.config.startup_timeout,
                runtime_timeout=self.config.runtime_timeout,
                deployment_timeout=self.config.deployment_timeout,

@@ -35,7 +35,8 @@ class DeepSeekV31ToolCallParser(ToolCallParser):

    # Regex captures: function_name, function_arguments
    PATTERN = re.compile(
        r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>"
        r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>",
        re.DOTALL,
    )

    def parse(self, text: str) -> ParseResult:
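
# Illustrative input this pattern should match (inferred from the regex above,
# so treat it as a sketch rather than the exact model output):
#   <|tool▁call▁begin|>get_weather<|tool▁sep|>{"city": "Paris"}<|tool▁call▁end|>
# The added re.DOTALL lets function_arguments span multiple lines, e.g.
# pretty-printed JSON emitted by the model.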

@@ -38,7 +38,8 @@ class DeepSeekV3ToolCallParser(ToolCallParser):

    # Regex captures: type, function_name, function_arguments
    PATTERN = re.compile(
        r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>"
        r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>",
        re.DOTALL,
    )

    def parse(self, text: str) -> ParseResult:
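
# The V3 format additionally wraps the arguments in a ```json fence, so an
# illustrative match (again inferred from the regex, not a verbatim sample) is:
#   <|tool▁call▁begin|>function<|tool▁sep|>get_weather
#   ```json
#   {"city": "Paris"}
#   ```<|tool▁call▁end|>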

@@ -44,7 +44,7 @@ _tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
    """
    Run a tool call in a thread pool executor so backends that use asyncio.run()
    internally (modal, docker) get a clean event loop.
    internally (modal, docker, daytona) get a clean event loop.

    If we're already in an async context, executes handle_function_call() in a
    disposable worker thread and blocks for the result.

@@ -95,7 +95,7 @@ class ToolContext:
        backend = os.getenv("TERMINAL_ENV", "local")
        logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])

        # Run via thread helper so modal/docker backends' asyncio.run() doesn't deadlock
        # Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
        result = _run_tool_in_thread(
            "terminal",
            {"command": command, "timeout": timeout},

@@ -701,6 +701,8 @@ class BasePlatformAdapter(ABC):

        # Extract image URLs and send them as native platform attachments
        images, text_content = self.extract_images(response)
        if images:
            logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))

        # Send the text portion first (if any remains after extractions)
        if text_content:

@@ -727,10 +729,13 @@ class BasePlatformAdapter(ABC):
        human_delay = self._get_human_delay()

        # Send extracted images as native attachments
        if images:
            logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
            for image_url, alt_text in images:
                if human_delay > 0:
                    await asyncio.sleep(human_delay)
                try:
                    logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
                    # Route animated GIFs through send_animation for proper playback
                    if self._is_animation_url(image_url):
                        img_result = await self.send_animation(

@@ -745,9 +750,9 @@ class BasePlatformAdapter(ABC):
                            caption=alt_text if alt_text else None,
                        )
                    if not img_result.success:
                        print(f"[{self.name}] Failed to send image: {img_result.error}")
                        logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
                except Exception as img_err:
                    print(f"[{self.name}] Error sending image: {img_err}")
                    logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)

        # Send extracted media files — route by file type
        _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}

@@ -267,6 +267,43 @@ class DiscordAdapter(BasePlatformAdapter):
            print(f"[{self.name}] Failed to send audio: {e}")
            return await super().send_voice(chat_id, audio_path, caption, reply_to)

    async def send_image_file(
        self,
        chat_id: str,
        image_path: str,
        caption: Optional[str] = None,
        reply_to: Optional[str] = None,
    ) -> SendResult:
        """Send a local image file natively as a Discord file attachment."""
        if not self._client:
            return SendResult(success=False, error="Not connected")

        try:
            import io

            channel = self._client.get_channel(int(chat_id))
            if not channel:
                channel = await self._client.fetch_channel(int(chat_id))
            if not channel:
                return SendResult(success=False, error=f"Channel {chat_id} not found")

            if not os.path.exists(image_path):
                return SendResult(success=False, error=f"Image file not found: {image_path}")

            filename = os.path.basename(image_path)

            with open(image_path, "rb") as f:
                file = discord.File(io.BytesIO(f.read()), filename=filename)
            msg = await channel.send(
                content=caption if caption else None,
                file=file,
            )
            return SendResult(success=True, message_id=str(msg.id))

        except Exception as e:
            print(f"[{self.name}] Failed to send local image: {e}")
            return await super().send_image_file(chat_id, image_path, caption, reply_to)

    async def send_image(
        self,
        chat_id: str,

@@ -179,6 +179,35 @@ class SlackAdapter(BasePlatformAdapter):
        """Slack doesn't have a direct typing indicator API for bots."""
        pass

    async def send_image_file(
        self,
        chat_id: str,
        image_path: str,
        caption: Optional[str] = None,
        reply_to: Optional[str] = None,
    ) -> SendResult:
        """Send a local image file to Slack by uploading it."""
        if not self._app:
            return SendResult(success=False, error="Not connected")

        try:
            import os
            if not os.path.exists(image_path):
                return SendResult(success=False, error=f"Image file not found: {image_path}")

            result = await self._app.client.files_upload_v2(
                channel=chat_id,
                file=image_path,
                filename=os.path.basename(image_path),
                initial_comment=caption or "",
                thread_ts=reply_to,
            )
            return SendResult(success=True, raw_response=result)

        except Exception as e:
            print(f"[{self.name}] Failed to send local image: {e}")
            return await super().send_image_file(chat_id, image_path, caption, reply_to)

    async def send_image(
        self,
        chat_id: str,

@@ -8,10 +8,13 @@ Uses python-telegram-bot library for:
"""

import asyncio
import logging
import os
import re
from typing import Dict, List, Optional, Any

logger = logging.getLogger(__name__)

try:
    from telegram import Update, Bot, Message
    from telegram.ext import (

@@ -73,6 +76,19 @@ def _escape_mdv2(text: str) -> str:
    return _MDV2_ESCAPE_RE.sub(r'\\\1', text)


def _strip_mdv2(text: str) -> str:
    """Strip MarkdownV2 escape backslashes to produce clean plain text.

    Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
    doesn't show stray asterisks from header/bold conversion.
    """
    # Remove escape backslashes before special characters
    cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
    # Remove MarkdownV2 bold markers that format_message converted from **bold**
    cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
    return cleaned
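
# Illustrative round-trip (a sketch based on the two substitutions above):
#   _strip_mdv2(r"*Status:* all checks passed\!")  ->  "Status: all checks passed!"
# The escape backslash before "!" is dropped first, then the *bold* markers.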


class TelegramAdapter(BasePlatformAdapter):
    """
    Telegram bot adapter.

@@ -199,9 +215,13 @@ class TelegramAdapter(BasePlatformAdapter):
            except Exception as md_error:
                # Markdown parsing failed, try plain text
                if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
                    logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
                    # Strip MDV2 escape backslashes so the user doesn't
                    # see raw backslashes littered through the message.
                    plain_chunk = _strip_mdv2(chunk)
                    msg = await self._bot.send_message(
                        chat_id=int(chat_id),
                        text=chunk,
                        text=plain_chunk,
                        parse_mode=None,  # Plain text
                        reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
                        message_thread_id=int(thread_id) if thread_id else None,

@@ -286,6 +306,34 @@ class TelegramAdapter(BasePlatformAdapter):
            print(f"[{self.name}] Failed to send voice/audio: {e}")
            return await super().send_voice(chat_id, audio_path, caption, reply_to)

    async def send_image_file(
        self,
        chat_id: str,
        image_path: str,
        caption: Optional[str] = None,
        reply_to: Optional[str] = None,
    ) -> SendResult:
        """Send a local image file natively as a Telegram photo."""
        if not self._bot:
            return SendResult(success=False, error="Not connected")

        try:
            import os
            if not os.path.exists(image_path):
                return SendResult(success=False, error=f"Image file not found: {image_path}")

            with open(image_path, "rb") as image_file:
                msg = await self._bot.send_photo(
                    chat_id=int(chat_id),
                    photo=image_file,
                    caption=caption[:1024] if caption else None,
                    reply_to_message_id=int(reply_to) if reply_to else None,
                )
            return SendResult(success=True, message_id=str(msg.message_id))
        except Exception as e:
            print(f"[{self.name}] Failed to send local image: {e}")
            return await super().send_image_file(chat_id, image_path, caption, reply_to)

    async def send_image(
        self,
        chat_id: str,

@@ -293,12 +341,16 @@ class TelegramAdapter(BasePlatformAdapter):
        caption: Optional[str] = None,
        reply_to: Optional[str] = None,
    ) -> SendResult:
        """Send an image natively as a Telegram photo."""
        """Send an image natively as a Telegram photo.

        Tries URL-based send first (fast, works for <5MB images).
        Falls back to downloading and uploading as file (supports up to 10MB).
        """
        if not self._bot:
            return SendResult(success=False, error="Not connected")

        try:
            # Telegram can send photos directly from URLs
            # Telegram can send photos directly from URLs (up to ~5MB)
            msg = await self._bot.send_photo(
                chat_id=int(chat_id),
                photo=image_url,

@@ -307,9 +359,26 @@ class TelegramAdapter(BasePlatformAdapter):
            )
            return SendResult(success=True, message_id=str(msg.message_id))
        except Exception as e:
            print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
            # Fallback: send as text link
            return await super().send_image(chat_id, image_url, caption, reply_to)
            logger.warning("[%s] URL-based send_photo failed (%s), trying file upload", self.name, e)
            # Fallback: download and upload as file (supports up to 10MB)
            try:
                import httpx
                async with httpx.AsyncClient(timeout=30.0) as client:
                    resp = await client.get(image_url)
                    resp.raise_for_status()
                    image_data = resp.content

                msg = await self._bot.send_photo(
                    chat_id=int(chat_id),
                    photo=image_data,
                    caption=caption[:1024] if caption else None,
                    reply_to_message_id=int(reply_to) if reply_to else None,
                )
                return SendResult(success=True, message_id=str(msg.message_id))
            except Exception as e2:
                logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
                # Final fallback: send URL as text
                return await super().send_image(chat_id, image_url, caption, reply_to)

    async def send_animation(
        self,

@@ -28,6 +28,41 @@ from typing import Dict, List, Optional, Any

logger = logging.getLogger(__name__)


def _kill_port_process(port: int) -> None:
    """Kill any process listening on the given TCP port."""
    try:
        if _IS_WINDOWS:
            # Use netstat to find the PID bound to this port, then taskkill
            result = subprocess.run(
                ["netstat", "-ano", "-p", "TCP"],
                capture_output=True, text=True, timeout=5,
            )
            for line in result.stdout.splitlines():
                parts = line.split()
                if len(parts) >= 5 and parts[3] == "LISTENING":
                    local_addr = parts[1]
                    if local_addr.endswith(f":{port}"):
                        try:
                            subprocess.run(
                                ["taskkill", "/PID", parts[4], "/F"],
                                capture_output=True, timeout=5,
                            )
                        except subprocess.SubprocessError:
                            pass
        else:
            result = subprocess.run(
                ["fuser", f"{port}/tcp"],
                capture_output=True, timeout=5,
            )
            if result.returncode == 0:
                subprocess.run(
                    ["fuser", "-k", f"{port}/tcp"],
                    capture_output=True, timeout=5,
                )
    except Exception:
        pass

import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

@@ -145,21 +180,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
        self._session_path.mkdir(parents=True, exist_ok=True)

        # Kill any orphaned bridge from a previous gateway run
        try:
            result = subprocess.run(
                ["fuser", f"{self._bridge_port}/tcp"],
                capture_output=True, timeout=5,
            )
            if result.returncode == 0:
                # Port is in use — kill the process
                subprocess.run(
                    ["fuser", "-k", f"{self._bridge_port}/tcp"],
                    capture_output=True, timeout=5,
                )
                import time
                time.sleep(2)
        except Exception:
            pass
        _kill_port_process(self._bridge_port)
        import time
        time.sleep(1)

        # Start the bridge process in its own process group.
        # Route output to a log file so QR codes, errors, and reconnection

@@ -293,13 +316,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
            print(f"[{self.name}] Error stopping bridge: {e}")

        # Also kill any orphaned bridge processes on our port
        try:
            subprocess.run(
                ["fuser", "-k", f"{self._bridge_port}/tcp"],
                capture_output=True, timeout=5,
            )
        except Exception:
            pass
        _kill_port_process(self._bridge_port)

        self._running = False
        self._bridge_process = None

gateway/run.py

@@ -66,6 +66,7 @@ if _config_path.exists():
        "docker_image": "TERMINAL_DOCKER_IMAGE",
        "singularity_image": "TERMINAL_SINGULARITY_IMAGE",
        "modal_image": "TERMINAL_MODAL_IMAGE",
        "daytona_image": "TERMINAL_DAYTONA_IMAGE",
        "ssh_host": "TERMINAL_SSH_HOST",
        "ssh_user": "TERMINAL_SSH_USER",
        "ssh_port": "TERMINAL_SSH_PORT",

@@ -74,6 +75,7 @@ if _config_path.exists():
        "container_memory": "TERMINAL_CONTAINER_MEMORY",
        "container_disk": "TERMINAL_CONTAINER_DISK",
        "container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
        "sandbox_dir": "TERMINAL_SANDBOX_DIR",
    }
    for _cfg_key, _env_var in _terminal_env_map.items():
        if _cfg_key in _terminal_cfg:

@@ -92,6 +94,11 @@ if _config_path.exists():
    if _agent_cfg and isinstance(_agent_cfg, dict):
        if "max_turns" in _agent_cfg:
            os.environ["HERMES_MAX_ITERATIONS"] = str(_agent_cfg["max_turns"])
    # Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
    # HERMES_TIMEZONE from .env takes precedence (already in os.environ).
    _tz_cfg = _cfg.get("timezone", "")
    if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
        os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
except Exception:
    pass  # Non-fatal; gateway can still run with .env values

@@ -101,11 +108,13 @@ os.environ["HERMES_QUIET"] = "1"
# Enable interactive exec approval for dangerous commands on messaging platforms
os.environ["HERMES_EXEC_ASK"] = "1"

# Set terminal working directory for messaging platforms
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
# This is separate from CLI which uses the directory where `hermes` is run
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
os.environ["TERMINAL_CWD"] = messaging_cwd
# Set terminal working directory for messaging platforms.
# If the user set an explicit path in config.yaml (not "." or "auto"),
# respect it. Otherwise use MESSAGING_CWD or default to home directory.
_configured_cwd = os.environ.get("TERMINAL_CWD", "")
if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"):
    messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
    os.environ["TERMINAL_CWD"] = messaging_cwd

from gateway.config import (
    Platform,

@@ -172,7 +181,6 @@ class GatewayRunner:
        self.session_store = SessionStore(
            self.config.sessions_dir, self.config,
            has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
            on_auto_reset=self._flush_memories_before_reset,
        )
        self.delivery_router = DeliveryRouter(self.config)
        self._running = False

@@ -203,15 +211,14 @@ class GatewayRunner:
        from gateway.hooks import HookRegistry
        self.hooks = HookRegistry()

    def _flush_memories_before_reset(self, old_entry):
        """Prompt the agent to save memories/skills before an auto-reset.

        Called synchronously by SessionStore before destroying an expired session.
        Loads the transcript, gives the agent a real turn with memory + skills
        tools, and explicitly asks it to preserve anything worth keeping.
    def _flush_memories_for_session(self, old_session_id: str):
        """Prompt the agent to save memories/skills before context is lost.

        Synchronous worker — meant to be called via run_in_executor from
        an async context so it doesn't block the event loop.
        """
        try:
            history = self.session_store.load_transcript(old_entry.session_id)
            history = self.session_store.load_transcript(old_session_id)
            if not history or len(history) < 4:
                return

@@ -225,7 +232,7 @@ class GatewayRunner:
                max_iterations=8,
                quiet_mode=True,
                enabled_toolsets=["memory", "skills"],
                session_id=old_entry.session_id,
                session_id=old_session_id,
            )

            # Build conversation history from transcript

@@ -254,9 +261,14 @@ class GatewayRunner:
                user_message=flush_prompt,
                conversation_history=msgs,
            )
            logger.info("Pre-reset save completed for session %s", old_entry.session_id)
            logger.info("Pre-reset memory flush completed for session %s", old_session_id)
        except Exception as e:
            logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
            logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)

    async def _async_flush_memories(self, old_session_id: str):
        """Run the sync memory flush in a thread pool so it won't block the event loop."""
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)

    @staticmethod
    def _load_prefill_messages() -> List[Dict[str, Any]]:

@@ -324,7 +336,7 @@ class GatewayRunner:

        Checks HERMES_REASONING_EFFORT env var first, then agent.reasoning_effort
        in config.yaml. Valid: "xhigh", "high", "medium", "low", "minimal", "none".
        Returns None to use default (xhigh).
        Returns None to use default (medium).
        """
        effort = os.getenv("HERMES_REASONING_EFFORT", "")
        if not effort:

@@ -345,7 +357,7 @@ class GatewayRunner:
        valid = ("xhigh", "high", "medium", "low", "minimal")
        if effort in valid:
            return {"enabled": True, "effort": effort}
        logger.warning("Unknown reasoning_effort '%s', using default (xhigh)", effort)
        logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
        return None

    @staticmethod

@@ -458,10 +470,50 @@ class GatewayRunner:
        # Check if we're restarting after a /update command
        await self._send_update_notification()

        # Start background session expiry watcher for proactive memory flushing
        asyncio.create_task(self._session_expiry_watcher())

        logger.info("Press Ctrl+C to stop")

        return True

    async def _session_expiry_watcher(self, interval: int = 300):
        """Background task that proactively flushes memories for expired sessions.

        Runs every `interval` seconds (default 5 min). For each session that
        has expired according to its reset policy, flushes memories in a thread
        pool and marks the session so it won't be flushed again.

        This means memories are already saved by the time the user sends their
        next message, so there's no blocking delay.
        """
        await asyncio.sleep(60)  # initial delay — let the gateway fully start
        while self._running:
            try:
                self.session_store._ensure_loaded()
                for key, entry in list(self.session_store._entries.items()):
                    if entry.session_id in self.session_store._pre_flushed_sessions:
                        continue  # already flushed this session
                    if not self.session_store._is_session_expired(entry):
                        continue  # session still active
                    # Session has expired — flush memories in the background
                    logger.info(
                        "Session %s expired (key=%s), flushing memories proactively",
                        entry.session_id, key,
                    )
                    try:
                        await self._async_flush_memories(entry.session_id)
                        self.session_store._pre_flushed_sessions.add(entry.session_id)
                    except Exception as e:
                        logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
            except Exception as e:
                logger.debug("Session expiry watcher error: %s", e)
            # Sleep in small increments so we can stop quickly
            for _ in range(interval):
                if not self._running:
                    break
                await asyncio.sleep(1)

    async def stop(self) -> None:
        """Stop the gateway and disconnect all adapters."""
        logger.info("Stopping gateway...")

@@ -658,7 +710,8 @@ class GatewayRunner:
        # Emit command:* hook for any recognized slash command
        _known_commands = {"new", "reset", "help", "status", "stop", "model",
                           "personality", "retry", "undo", "sethome", "set-home",
                           "compress", "usage", "reload-mcp", "update"}
                           "compress", "usage", "insights", "reload-mcp", "update",
                           "title"}
        if command and command in _known_commands:
            await self.hooks.emit(f"command:{command}", {
                "platform": source.platform.value if source.platform else "",

@@ -682,6 +735,9 @@ class GatewayRunner:
        if command == "model":
            return await self._handle_model_command(event)

        if command == "provider":
            return await self._handle_provider_command(event)

        if command == "personality":
            return await self._handle_personality_command(event)

@@ -700,11 +756,17 @@ class GatewayRunner:
        if command == "usage":
            return await self._handle_usage_command(event)

        if command == "insights":
            return await self._handle_insights_command(event)

        if command == "reload-mcp":
            return await self._handle_reload_mcp_command(event)

        if command == "update":
            return await self._handle_update_command(event)

        if command == "title":
            return await self._handle_title_command(event)

        # Skill slash commands: /skill-name loads the skill and sends to agent
        if command:

@@ -779,6 +841,167 @@ class GatewayRunner:
        # Load conversation history from transcript
        history = self.session_store.load_transcript(session_entry.session_id)

        # -----------------------------------------------------------------
        # Session hygiene: auto-compress pathologically large transcripts
        #
        # Long-lived gateway sessions can accumulate enough history that
        # every new message rehydrates an oversized transcript, causing
        # repeated truncation/context failures. Detect this early and
        # compress proactively — before the agent even starts. (#628)
        # -----------------------------------------------------------------
        if history and len(history) >= 4:
            from agent.model_metadata import estimate_messages_tokens_rough

            # Read thresholds from config.yaml → session_hygiene section
            _hygiene_cfg = {}
            try:
                _hyg_cfg_path = _hermes_home / "config.yaml"
                if _hyg_cfg_path.exists():
                    import yaml as _hyg_yaml
                    with open(_hyg_cfg_path) as _hyg_f:
                        _hyg_data = _hyg_yaml.safe_load(_hyg_f) or {}
                    _hygiene_cfg = _hyg_data.get("session_hygiene", {})
                    if not isinstance(_hygiene_cfg, dict):
                        _hygiene_cfg = {}
            except Exception:
                pass

            _compress_token_threshold = int(
                _hygiene_cfg.get("auto_compress_tokens", 100_000)
            )
            _compress_msg_threshold = int(
                _hygiene_cfg.get("auto_compress_messages", 200)
            )
            _warn_token_threshold = int(
                _hygiene_cfg.get("warn_tokens", 200_000)
            )
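
            # Illustrative config.yaml snippet for these thresholds (keys taken
            # from the lookups above; values shown are the defaults — a sketch,
            # all three settings are optional):
            #
            #   session_hygiene:
            #     auto_compress_tokens: 100000    # compress past this estimated size
            #     auto_compress_messages: 200     # ...or past this many messages
            #     warn_tokens: 200000             # warn if still above this afterwards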
|
||||
|
||||
_msg_count = len(history)
|
||||
_approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
_needs_compress = (
|
||||
_approx_tokens >= _compress_token_threshold
|
||||
or _msg_count >= _compress_msg_threshold
|
||||
)
|
||||
|
||||
if _needs_compress:
|
||||
logger.info(
|
||||
"Session hygiene: %s messages, ~%s tokens — auto-compressing "
|
||||
"(thresholds: %s msgs / %s tokens)",
|
||||
_msg_count, f"{_approx_tokens:,}",
|
||||
_compress_msg_threshold, f"{_compress_token_threshold:,}",
|
||||
)
|
||||
|
||||
_hyg_adapter = self.adapters.get(source.platform)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"🗜️ Session is large ({_msg_count} messages, "
|
||||
f"~{_approx_tokens:,} tokens). Auto-compressing..."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from run_agent import AIAgent
|
||||
|
||||
_hyg_runtime = _resolve_runtime_agent_kwargs()
|
||||
if _hyg_runtime.get("api_key"):
|
||||
_hyg_msgs = [
|
||||
{"role": m.get("role"), "content": m.get("content")}
|
||||
for m in history
|
||||
if m.get("role") in ("user", "assistant")
|
||||
and m.get("content")
|
||||
]
|
||||
|
||||
if len(_hyg_msgs) >= 4:
|
||||
_hyg_agent = AIAgent(
|
||||
**_hyg_runtime,
|
||||
max_iterations=4,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: _hyg_agent._compress_context(
|
||||
_hyg_msgs, "",
|
||||
approx_tokens=_approx_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
self.session_store.rewrite_transcript(
|
||||
session_entry.session_id, _compressed
|
||||
)
|
||||
history = _compressed
|
||||
_new_count = len(_compressed)
|
||||
_new_tokens = estimate_messages_tokens_rough(
|
||||
_compressed
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Session hygiene: compressed %s → %s msgs, "
|
||||
"~%s → ~%s tokens",
|
||||
_msg_count, _new_count,
|
||||
f"{_approx_tokens:,}", f"{_new_tokens:,}",
|
||||
)
|
||||
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"🗜️ Compressed: {_msg_count} → "
|
||||
f"{_new_count} messages, "
|
||||
f"~{_approx_tokens:,} → "
|
||||
f"~{_new_tokens:,} tokens"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Still too large after compression — warn user
|
||||
if _new_tokens >= _warn_token_threshold:
|
||||
logger.warning(
|
||||
"Session hygiene: still ~%s tokens after "
|
||||
"compression — suggesting /reset",
|
||||
f"{_new_tokens:,}",
|
||||
)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
"⚠️ Session is still very large "
|
||||
"after compression "
|
||||
f"(~{_new_tokens:,} tokens). "
|
||||
"Consider using /reset to start "
|
||||
"fresh if you experience issues."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Session hygiene auto-compress failed: %s", e
|
||||
)
|
||||
# Compression failed and session is dangerously large
|
||||
if _approx_tokens >= _warn_token_threshold:
|
||||
_hyg_adapter = self.adapters.get(source.platform)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"⚠️ Session is very large "
|
||||
f"({_msg_count} messages, "
|
||||
f"~{_approx_tokens:,} tokens) and "
|
||||
"auto-compression failed. Consider "
|
||||
"using /compress or /reset to avoid "
|
||||
"issues."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# First-message onboarding -- only on the very first interaction ever
|
||||
if not history and not self.session_store.has_any_sessions():
|
||||
context_prompt += (
@@ -1003,33 +1226,12 @@ class GatewayRunner:
        # Get existing session key
        session_key = self.session_store._generate_session_key(source)

        # Memory flush before reset: load the old transcript and let a
        # temporary agent save memories before the session is wiped.
        # Flush memories in the background (fire-and-forget) so the user
        # gets the "Session reset!" response immediately.
        try:
            old_entry = self.session_store._entries.get(session_key)
            if old_entry:
                old_history = self.session_store.load_transcript(old_entry.session_id)
                if old_history:
                    from run_agent import AIAgent
                    loop = asyncio.get_event_loop()
                    _flush_kwargs = _resolve_runtime_agent_kwargs()
                    def _do_flush():
                        tmp_agent = AIAgent(
                            **_flush_kwargs,
                            max_iterations=5,
                            quiet_mode=True,
                            enabled_toolsets=["memory"],
                            session_id=old_entry.session_id,
                        )
                        # Build simple message list from transcript
                        msgs = []
                        for m in old_history:
                            role = m.get("role")
                            content = m.get("content")
                            if role in ("user", "assistant") and content:
                                msgs.append({"role": role, "content": content})
                        tmp_agent.flush_memories(msgs)
                    await loop.run_in_executor(None, _do_flush)
                    asyncio.create_task(self._async_flush_memories(old_entry.session_id))
        except Exception as e:
            logger.debug("Gateway memory flush on reset failed: %s", e)

@@ -1096,13 +1298,16 @@ class GatewayRunner:
            "`/reset` — Reset conversation history",
            "`/status` — Show session info",
            "`/stop` — Interrupt the running agent",
            "`/model [name]` — Show or change the model",
            "`/model [provider:model]` — Show/change model (or switch provider)",
            "`/provider` — Show available providers and auth status",
            "`/personality [name]` — Set a personality",
            "`/retry` — Retry your last message",
            "`/undo` — Remove the last exchange",
            "`/sethome` — Set this chat as the home channel",
            "`/compress` — Compress conversation context",
            "`/title [name]` — Set or show the session title",
            "`/usage` — Show token usage for this session",
            "`/insights [days]` — Show usage insights and analytics",
            "`/reload-mcp` — Reload MCP servers from config",
            "`/update` — Update Hermes Agent to the latest version",
            "`/help` — Show this message",

@@ -1121,13 +1326,20 @@ class GatewayRunner:
    async def _handle_model_command(self, event: MessageEvent) -> str:
        """Handle /model command - show or change the current model."""
        import yaml
        from hermes_cli.models import (
            parse_model_input,
            validate_requested_model,
            curated_models_for_provider,
            normalize_provider,
            _PROVIDER_LABELS,
        )

        args = event.get_command_args().strip()
        config_path = _hermes_home / 'config.yaml'

        # Resolve current model the same way the agent init does:
        # env vars first, then config.yaml always overrides.
        # Resolve current model and provider from config
        current = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
        current_provider = "openrouter"
        try:
            if config_path.exists():
                with open(config_path) as f:
@@ -1137,39 +1349,164 @@ class GatewayRunner:
                    current = model_cfg
                elif isinstance(model_cfg, dict):
                    current = model_cfg.get("default", current)
                    current_provider = model_cfg.get("provider", current_provider)
        except Exception:
            pass

        # Resolve "auto" to the actual provider using credential detection
        current_provider = normalize_provider(current_provider)
        if current_provider == "auto":
            try:
                from hermes_cli.auth import resolve_provider as _resolve_provider
                current_provider = _resolve_provider(current_provider)
            except Exception:
                current_provider = "openrouter"

        if not args:
            return f"🤖 **Current model:** `{current}`\n\nTo change: `/model provider/model-name`"
            provider_label = _PROVIDER_LABELS.get(current_provider, current_provider)
            lines = [
                f"🤖 **Current model:** `{current}`",
                f"**Provider:** {provider_label}",
                "",
            ]
            curated = curated_models_for_provider(current_provider)
            if curated:
                lines.append(f"**Available models ({provider_label}):**")
                for mid, desc in curated:
                    marker = " ←" if mid == current else ""
                    label = f" _{desc}_" if desc else ""
                    lines.append(f"• `{mid}`{label}{marker}")
                lines.append("")
            lines.append("To change: `/model model-name`")
            lines.append("Switch provider: `/model provider:model-name`")
            return "\n".join(lines)

        if "/" not in args:
            return (
                f"🤖 Invalid model format: `{args}`\n\n"
                f"Use `provider/model-name` format, e.g.:\n"
                f"• `anthropic/claude-sonnet-4`\n"
                f"• `google/gemini-2.5-pro`\n"
                f"• `openai/gpt-4o`"
            )
        # Parse provider:model syntax
        target_provider, new_model = parse_model_input(args, current_provider)
        provider_changed = target_provider != current_provider

        # Write to config.yaml (source of truth), same pattern as CLI save_config_value.
        # Resolve credentials for the target provider (for API probe)
        api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
        base_url = "https://openrouter.ai/api/v1"
        if provider_changed:
            try:
                from hermes_cli.runtime_provider import resolve_runtime_provider
                runtime = resolve_runtime_provider(requested=target_provider)
                api_key = runtime.get("api_key", "")
                base_url = runtime.get("base_url", "")
            except Exception as e:
                provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
                return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
        else:
            # Use current provider's base_url from config or registry
            try:
                from hermes_cli.runtime_provider import resolve_runtime_provider
                runtime = resolve_runtime_provider(requested=current_provider)
                api_key = runtime.get("api_key", "")
                base_url = runtime.get("base_url", "")
            except Exception:
                pass

        # Validate the model against the live API
        try:
            validation = validate_requested_model(
                new_model,
                target_provider,
                api_key=api_key,
                base_url=base_url,
            )
        except Exception:
            validation = {"accepted": True, "persist": True, "recognized": False, "message": None}

        if not validation.get("accepted"):
            msg = validation.get("message", "Invalid model")
            tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
            return f"⚠️ {msg}{tip}"

        # Persist to config only if validation approves
        if validation.get("persist"):
            try:
                user_config = {}
                if config_path.exists():
                    with open(config_path) as f:
                        user_config = yaml.safe_load(f) or {}
                if "model" not in user_config or not isinstance(user_config["model"], dict):
                    user_config["model"] = {}
                user_config["model"]["default"] = new_model
                if provider_changed:
                    user_config["model"]["provider"] = target_provider
                with open(config_path, 'w') as f:
                    yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
            except Exception as e:
                return f"⚠️ Failed to save model change: {e}"

        # Set env vars so the next agent run picks up the change
        os.environ["HERMES_MODEL"] = new_model
        if provider_changed:
            os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider

        provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
        provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""

        warning = ""
        if validation.get("message"):
            warning = f"\n⚠️ {validation['message']}"

        if validation.get("persist"):
            persist_note = "saved to config"
        else:
            persist_note = "this session only — will revert on restart"
        return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}\n_(takes effect on next message)_"

    async def _handle_provider_command(self, event: MessageEvent) -> str:
        """Handle /provider command - show available providers."""
        import yaml
        from hermes_cli.models import (
            list_available_providers,
            normalize_provider,
            _PROVIDER_LABELS,
        )

        # Resolve current provider from config
        current_provider = "openrouter"
        config_path = _hermes_home / 'config.yaml'
        try:
            user_config = {}
            if config_path.exists():
                with open(config_path) as f:
                    user_config = yaml.safe_load(f) or {}
            if "model" not in user_config or not isinstance(user_config["model"], dict):
                user_config["model"] = {}
            user_config["model"]["default"] = args
            with open(config_path, 'w') as f:
                yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
        except Exception as e:
            return f"⚠️ Failed to save model change: {e}"
                    cfg = yaml.safe_load(f) or {}
            model_cfg = cfg.get("model", {})
            if isinstance(model_cfg, dict):
                current_provider = model_cfg.get("provider", current_provider)
        except Exception:
            pass

        # Also set env var so code reading it before the next agent init sees the update.
        os.environ["HERMES_MODEL"] = args
        current_provider = normalize_provider(current_provider)
        if current_provider == "auto":
            try:
                from hermes_cli.auth import resolve_provider as _resolve_provider
                current_provider = _resolve_provider(current_provider)
            except Exception:
                current_provider = "openrouter"

        return f"🤖 Model changed to `{args}`\n_(takes effect on next message)_"
        current_label = _PROVIDER_LABELS.get(current_provider, current_provider)

        lines = [
            f"🔌 **Current provider:** {current_label} (`{current_provider}`)",
            "",
            "**Available providers:**",
        ]

        providers = list_available_providers()
        for p in providers:
            marker = " ← active" if p["id"] == current_provider else ""
            auth = "✅" if p["authenticated"] else "❌"
            aliases = f" _(also: {', '.join(p['aliases'])})_" if p["aliases"] else ""
            lines.append(f"{auth} `{p['id']}` — {p['label']}{aliases}{marker}")

        lines.append("")
        lines.append("Switch: `/model provider:model-name`")
        lines.append("Setup: `hermes setup`")
        return "\n".join(lines)

    async def _handle_personality_command(self, event: MessageEvent) -> str:
        """Handle /personality command - list or set a personality."""
@@ -1253,8 +1590,7 @@ class GatewayRunner:
        )

        # Let the normal message handler process it
        await self._handle_message(retry_event)
        return None  # Response sent through normal flow
        return await self._handle_message(retry_event)

    async def _handle_undo_command(self, event: MessageEvent) -> str:
        """Handle /undo command - remove the last user/assistant exchange."""
@@ -1360,6 +1696,40 @@ class GatewayRunner:
            logger.warning("Manual compress failed: %s", e)
            return f"Compression failed: {e}"

    async def _handle_title_command(self, event: MessageEvent) -> str:
        """Handle /title command — set or show the current session's title."""
        source = event.source
        session_entry = self.session_store.get_or_create_session(source)
        session_id = session_entry.session_id

        if not self._session_db:
            return "Session database not available."

        title_arg = event.get_command_args().strip()
        if title_arg:
            # Sanitize the title before setting
            try:
                sanitized = self._session_db.sanitize_title(title_arg)
            except ValueError as e:
                return f"⚠️ {e}"
            if not sanitized:
                return "⚠️ Title is empty after cleanup. Please use printable characters."
            # Set the title
            try:
                if self._session_db.set_session_title(session_id, sanitized):
                    return f"✏️ Session title set: **{sanitized}**"
                else:
                    return "Session not found in database."
            except ValueError as e:
                return f"⚠️ {e}"
        else:
            # Show the current title
            title = self._session_db.get_session_title(session_id)
            if title:
                return f"📌 Session title: **{title}**"
            else:
                return "No title set. Usage: `/title My Session Name`"

    async def _handle_usage_command(self, event: MessageEvent) -> str:
        """Handle /usage command -- show token usage for the session's last agent run."""
        source = event.source
@@ -1397,6 +1767,53 @@ class GatewayRunner:
        )
        return "No usage data available for this session."

    async def _handle_insights_command(self, event: MessageEvent) -> str:
        """Handle /insights command -- show usage insights and analytics."""
        import asyncio as _asyncio

        args = event.get_command_args().strip()
        days = 30
        source = None

        # Parse simple args: /insights 7 or /insights --days 7
        if args:
            parts = args.split()
            i = 0
            while i < len(parts):
                if parts[i] == "--days" and i + 1 < len(parts):
                    try:
                        days = int(parts[i + 1])
                    except ValueError:
                        return f"Invalid --days value: {parts[i + 1]}"
                    i += 2
                elif parts[i] == "--source" and i + 1 < len(parts):
                    source = parts[i + 1]
                    i += 2
                elif parts[i].isdigit():
                    days = int(parts[i])
                    i += 1
                else:
                    i += 1

        try:
            from hermes_state import SessionDB
            from agent.insights import InsightsEngine

            loop = _asyncio.get_event_loop()

            def _run_insights():
                db = SessionDB()
                engine = InsightsEngine(db)
                report = engine.generate(days=days, source=source)
                result = engine.format_gateway(report)
                db.close()
                return result

            return await loop.run_in_executor(None, _run_insights)
        except Exception as e:
            logger.error("Insights command error: %s", e, exc_info=True)
            return f"Error generating insights: {e}"

    async def _handle_reload_mcp_command(self, event: MessageEvent) -> str:
        """Handle /reload-mcp command -- disconnect and reconnect all MCP servers."""
        loop = asyncio.get_event_loop()
@@ -2041,7 +2458,7 @@ class GatewayRunner:
        os.environ["HERMES_SESSION_KEY"] = session_key or ""

        # Read from env var or use default (same as CLI)
        max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
        max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))

        # Map platform enum to the platform hint key the agent understands.
        # Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
@@ -2381,14 +2798,85 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
    logger.info("Cron ticker stopped")


async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool:
    """
    Start the gateway and run until interrupted.

    This is the main entry point for running the gateway.
    Returns True if the gateway ran successfully, False if it failed to start.
    A False return causes a non-zero exit code so systemd can auto-restart.

    Args:
        config: Optional gateway configuration override.
        replace: If True, kill any existing gateway instance before starting.
            Useful for systemd services to avoid restart-loop deadlocks
            when the previous process hasn't fully exited yet.
    """
    # ── Duplicate-instance guard ──────────────────────────────────────
    # Prevent two gateways from running under the same HERMES_HOME.
    # The PID file is scoped to HERMES_HOME, so future multi-profile
    # setups (each profile using a distinct HERMES_HOME) will naturally
    # allow concurrent instances without tripping this guard.
    import time as _time
    from gateway.status import get_running_pid, remove_pid_file
    existing_pid = get_running_pid()
    if existing_pid is not None and existing_pid != os.getpid():
        if replace:
            logger.info(
                "Replacing existing gateway instance (PID %d) with --replace.",
                existing_pid,
            )
            try:
                os.kill(existing_pid, signal.SIGTERM)
            except ProcessLookupError:
                pass  # Already gone
            except PermissionError:
                logger.error(
                    "Permission denied killing PID %d. Cannot replace.",
                    existing_pid,
                )
                return False
            # Wait up to 10 seconds for the old process to exit
            for _ in range(20):
                try:
                    os.kill(existing_pid, 0)
                    _time.sleep(0.5)
                except (ProcessLookupError, PermissionError):
                    break  # Process is gone
            else:
                # Still alive after 10s — force kill
                logger.warning(
                    "Old gateway (PID %d) did not exit after SIGTERM, sending SIGKILL.",
                    existing_pid,
                )
                try:
                    os.kill(existing_pid, signal.SIGKILL)
                    _time.sleep(0.5)
                except (ProcessLookupError, PermissionError):
                    pass
            remove_pid_file()
        else:
            hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
            logger.error(
                "Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
                "Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
                existing_pid, hermes_home,
            )
            print(
                f"\n❌ Gateway already running (PID {existing_pid}).\n"
                f"   Use 'hermes gateway restart' to replace it,\n"
                f"   or 'hermes gateway stop' to kill it first.\n"
                f"   Or use 'hermes gateway run --replace' to auto-replace.\n"
            )
            return False

    # Sync bundled skills on gateway start (fast -- skips unchanged)
    try:
        from tools.skills_sync import sync_skills
        sync_skills(quiet=True)
    except Exception:
        pass

    # Configure rotating file log so gateway output is persisted for debugging
    log_dir = _hermes_home / 'logs'
    log_dir.mkdir(parents=True, exist_ok=True)

@@ -311,7 +311,9 @@ class SessionStore:
        self._entries: Dict[str, SessionEntry] = {}
        self._loaded = False
        self._has_active_processes_fn = has_active_processes_fn
        self._on_auto_reset = on_auto_reset  # callback(old_entry) before auto-reset
        # on_auto_reset is deprecated — memory flush now runs proactively
        # via the background session expiry watcher in GatewayRunner.
        self._pre_flushed_sessions: set = set()  # session_ids already flushed by watcher

        # Initialize SQLite session database
        self._db = None
@@ -353,6 +355,44 @@ class SessionStore:
        """Generate a session key from a source."""
        return build_session_key(source)

    def _is_session_expired(self, entry: SessionEntry) -> bool:
        """Check if a session has expired based on its reset policy.

        Works from the entry alone — no SessionSource needed.
        Used by the background expiry watcher to proactively flush memories.
        Sessions with active background processes are never considered expired.
        """
        if self._has_active_processes_fn:
            if self._has_active_processes_fn(entry.session_key):
                return False

        policy = self.config.get_reset_policy(
            platform=entry.platform,
            session_type=entry.chat_type,
        )

        if policy.mode == "none":
            return False

        now = datetime.now()

        if policy.mode in ("idle", "both"):
            idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
            if now > idle_deadline:
                return True

        if policy.mode in ("daily", "both"):
            today_reset = now.replace(
                hour=policy.at_hour,
                minute=0, second=0, microsecond=0,
            )
            if now.hour < policy.at_hour:
                today_reset -= timedelta(days=1)
            if entry.updated_at < today_reset:
                return True

        return False
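A worked example of the daily branch may help: with `at_hour=4`, a check at 02:00 treats the most recent reset boundary as 04:00 *yesterday*, so only sessions untouched since then count as expired. A small sketch with hypothetical values, mirroring the same arithmetic:

# Hypothetical values — traces the daily-reset arithmetic above.
from datetime import datetime, timedelta

at_hour = 4
now = datetime(2025, 1, 10, 2, 0)            # 02:00, before today's 04:00 boundary
today_reset = now.replace(hour=at_hour, minute=0, second=0, microsecond=0)
if now.hour < at_hour:
    today_reset -= timedelta(days=1)         # boundary becomes Jan 9, 04:00
updated_at = datetime(2025, 1, 9, 3, 0)      # last activity before that boundary
assert updated_at < today_reset              # -> session counts as expired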

    def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
        """
        Check if a session should be reset based on policy.
@@ -439,13 +479,11 @@ class SessionStore:
                self._save()
                return entry
            else:
                # Session is being auto-reset — flush memories before destroying
                # Session is being auto-reset. The background expiry watcher
                # should have already flushed memories proactively; discard
                # the marker so it doesn't accumulate.
                was_auto_reset = True
                if self._on_auto_reset:
                    try:
                        self._on_auto_reset(entry)
                    except Exception as e:
                        logger.debug("Auto-reset callback failed: %s", e)
                self._pre_flushed_sessions.discard(entry.session_id)
                if self._db:
                    try:
                        self._db.end_session(entry.session_id, "session_reset")

@@ -3,37 +3,59 @@ Gateway runtime status helpers.

Provides PID-file based detection of whether the gateway daemon is running,
used by send_message's check_fn to gate availability in the CLI.

The PID file lives at ``{HERMES_HOME}/gateway.pid``. HERMES_HOME defaults to
``~/.hermes`` but can be overridden via the environment variable. This means
separate HERMES_HOME directories naturally get separate PID files — a property
that will be useful when we add named profiles (multiple agents running
concurrently under distinct configurations).
"""

import os
from pathlib import Path
from typing import Optional

_PID_FILE = Path.home() / ".hermes" / "gateway.pid"

def _get_pid_path() -> Path:
    """Return the path to the gateway PID file, respecting HERMES_HOME."""
    home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
    return home / "gateway.pid"


def write_pid_file() -> None:
    """Write the current process PID to the gateway PID file."""
    _PID_FILE.parent.mkdir(parents=True, exist_ok=True)
    _PID_FILE.write_text(str(os.getpid()))
    pid_path = _get_pid_path()
    pid_path.parent.mkdir(parents=True, exist_ok=True)
    pid_path.write_text(str(os.getpid()))


def remove_pid_file() -> None:
    """Remove the gateway PID file if it exists."""
    try:
        _PID_FILE.unlink(missing_ok=True)
        _get_pid_path().unlink(missing_ok=True)
    except Exception:
        pass


def get_running_pid() -> Optional[int]:
    """Return the PID of a running gateway instance, or ``None``.

    Checks the PID file and verifies the process is actually alive.
    Cleans up stale PID files automatically.
    """
    pid_path = _get_pid_path()
    if not pid_path.exists():
        return None
    try:
        pid = int(pid_path.read_text().strip())
        os.kill(pid, 0)  # signal 0 = existence check, no actual signal sent
        return pid
    except (ValueError, ProcessLookupError, PermissionError):
        # Stale PID file — process is gone
        remove_pid_file()
        return None


def is_gateway_running() -> bool:
    """Check if the gateway daemon is currently running."""
    if not _PID_FILE.exists():
        return False
    try:
        pid = int(_PID_FILE.read_text().strip())
        os.kill(pid, 0)  # signal 0 = existence check, no actual signal sent
        return True
    except (ValueError, ProcessLookupError, PermissionError):
        # Stale PID file -- process is gone
        remove_pid_file()
        return False
    return get_running_pid() is not None

@@ -72,15 +72,19 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120

@dataclass
class ProviderConfig:
    """Describes a known OAuth provider."""
    """Describes a known inference provider."""
    id: str
    name: str
    auth_type: str  # "oauth_device_code" or "api_key"
    auth_type: str  # "oauth_device_code", "oauth_external", or "api_key"
    portal_base_url: str = ""
    inference_base_url: str = ""
    client_id: str = ""
    scope: str = ""
    extra: Dict[str, Any] = field(default_factory=dict)
    # For API-key providers: env vars to check (in priority order)
    api_key_env_vars: tuple = ()
    # Optional env var for base URL override
    base_url_env_var: str = ""


PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
@@ -99,9 +103,118 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
        auth_type="oauth_external",
        inference_base_url=DEFAULT_CODEX_BASE_URL,
    ),
    "zai": ProviderConfig(
        id="zai",
        name="Z.AI / GLM",
        auth_type="api_key",
        inference_base_url="https://api.z.ai/api/paas/v4",
        api_key_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
        base_url_env_var="GLM_BASE_URL",
    ),
    "kimi-coding": ProviderConfig(
        id="kimi-coding",
        name="Kimi / Moonshot",
        auth_type="api_key",
        inference_base_url="https://api.moonshot.ai/v1",
        api_key_env_vars=("KIMI_API_KEY",),
        base_url_env_var="KIMI_BASE_URL",
    ),
    "minimax": ProviderConfig(
        id="minimax",
        name="MiniMax",
        auth_type="api_key",
        inference_base_url="https://api.minimax.io/v1",
        api_key_env_vars=("MINIMAX_API_KEY",),
        base_url_env_var="MINIMAX_BASE_URL",
    ),
    "minimax-cn": ProviderConfig(
        id="minimax-cn",
        name="MiniMax (China)",
        auth_type="api_key",
        inference_base_url="https://api.minimaxi.com/v1",
        api_key_env_vars=("MINIMAX_CN_API_KEY",),
        base_url_env_var="MINIMAX_CN_BASE_URL",
    ),
}


# =============================================================================
# Kimi Code Endpoint Detection
# =============================================================================

# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
# KIMI_BASE_URL explicitly.
KIMI_CODE_BASE_URL = "https://api.kimi.com/coding/v1"


def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) -> str:
    """Return the correct Kimi base URL based on the API key prefix.

    If the user has explicitly set KIMI_BASE_URL, that always wins.
    Otherwise, sk-kimi- prefixed keys route to api.kimi.com/coding/v1.
    """
    if env_override:
        return env_override
    if api_key.startswith("sk-kimi-"):
        return KIMI_CODE_BASE_URL
    return default_url
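The routing is easiest to see with concrete inputs (key values hypothetical, behavior taken directly from the function above):

# Hypothetical keys — illustrates the prefix-based routing above.
_resolve_kimi_base_url("sk-kimi-abc123", "https://api.moonshot.ai/v1", "")
# -> "https://api.kimi.com/coding/v1"   (Kimi Code key detected by prefix)
_resolve_kimi_base_url("sk-legacy-xyz", "https://api.moonshot.ai/v1", "")
# -> "https://api.moonshot.ai/v1"       (legacy key keeps the default)
_resolve_kimi_base_url("sk-kimi-abc123", "https://api.moonshot.ai/v1", "https://example.test/v1")
# -> "https://example.test/v1"          (explicit KIMI_BASE_URL always wins)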


# =============================================================================
# Z.AI Endpoint Detection
# =============================================================================

# Z.AI has separate billing for general vs coding plans, and global vs China
# endpoints. A key that works on one may return "Insufficient balance" on
# another. We probe at setup time and store the working endpoint.

ZAI_ENDPOINTS = [
    # (id, base_url, default_model, label)
    ("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
    ("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
    ("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
    ("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
]


def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
    """Probe z.ai endpoints to find one that accepts this API key.

    Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
    first working endpoint, or None if all fail.
    """
    for ep_id, base_url, model, label in ZAI_ENDPOINTS:
        try:
            resp = httpx.post(
                f"{base_url}/chat/completions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": model,
                    "stream": False,
                    "max_tokens": 1,
                    "messages": [{"role": "user", "content": "ping"}],
                },
                timeout=timeout,
            )
            if resp.status_code == 200:
                logger.debug("Z.AI endpoint probe: %s (%s) OK", ep_id, base_url)
                return {
                    "id": ep_id,
                    "base_url": base_url,
                    "model": model,
                    "label": label,
                }
            logger.debug("Z.AI endpoint probe: %s returned %s", ep_id, resp.status_code)
        except Exception as exc:
            logger.debug("Z.AI endpoint probe: %s failed: %s", ep_id, exc)
    return None
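At setup time the probe would be consumed roughly like the sketch below. Only the probe itself appears in this diff; the persistence step (storing the working base URL) is an assumption about how a caller might use the result:

# Sketch — how a setup step might consume the probe; persistence is assumed.
import os

api_key = os.environ["GLM_API_KEY"]
endpoint = detect_zai_endpoint(api_key)
if endpoint is None:
    print("No z.ai endpoint accepted this key — check plan/region.")
else:
    # e.g. {"id": "coding-global", "base_url": "https://api.z.ai/api/coding/paas/v4",
    #       "model": "glm-4.7", "label": "Global (Coding Plan)"}
    os.environ["GLM_BASE_URL"] = endpoint["base_url"]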


# =============================================================================
# Error Types
# =============================================================================
@@ -355,10 +468,19 @@ def resolve_provider(
    1. active_provider in auth.json with valid credentials
    2. Explicit CLI api_key/base_url -> "openrouter"
    3. OPENAI_API_KEY or OPENROUTER_API_KEY env vars -> "openrouter"
    4. Fallback: "openrouter"
    4. Provider-specific API keys (GLM, Kimi, MiniMax) -> that provider
    5. Fallback: "openrouter"
    """
    normalized = (requested or "auto").strip().lower()

    # Normalize provider aliases
    _PROVIDER_ALIASES = {
        "glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
        "kimi": "kimi-coding", "moonshot": "kimi-coding",
        "minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
    }
    normalized = _PROVIDER_ALIASES.get(normalized, normalized)

    if normalized in {"openrouter", "custom"}:
        return "openrouter"
    if normalized in PROVIDER_REGISTRY:
@@ -387,6 +509,14 @@ def resolve_provider(
    if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
        return "openrouter"

    # Auto-detect API-key providers by checking their env vars
    for pid, pconfig in PROVIDER_REGISTRY.items():
        if pconfig.auth_type != "api_key":
            continue
        for env_var in pconfig.api_key_env_vars:
            if os.getenv(env_var, "").strip():
                return pid

    return "openrouter"


@@ -1230,6 +1360,42 @@ def get_codex_auth_status() -> Dict[str, Any]:
    }


def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
    """Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
    pconfig = PROVIDER_REGISTRY.get(provider_id)
    if not pconfig or pconfig.auth_type != "api_key":
        return {"configured": False}

    api_key = ""
    key_source = ""
    for env_var in pconfig.api_key_env_vars:
        val = os.getenv(env_var, "").strip()
        if val:
            api_key = val
            key_source = env_var
            break

    env_url = ""
    if pconfig.base_url_env_var:
        env_url = os.getenv(pconfig.base_url_env_var, "").strip()

    if provider_id == "kimi-coding":
        base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
    elif env_url:
        base_url = env_url
    else:
        base_url = pconfig.inference_base_url

    return {
        "configured": bool(api_key),
        "provider": provider_id,
        "name": pconfig.name,
        "key_source": key_source,
        "base_url": base_url,
        "logged_in": bool(api_key),  # compat with OAuth status shape
    }


def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
    """Generic auth status dispatcher."""
    target = provider_id or get_active_provider()
@@ -1237,9 +1403,54 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
        return get_nous_auth_status()
    if target == "openai-codex":
        return get_codex_auth_status()
    # API-key providers
    pconfig = PROVIDER_REGISTRY.get(target)
    if pconfig and pconfig.auth_type == "api_key":
        return get_api_key_provider_status(target)
    return {"logged_in": False}


def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
    """Resolve API key and base URL for an API-key provider.

    Returns dict with: provider, api_key, base_url, source.
    """
    pconfig = PROVIDER_REGISTRY.get(provider_id)
    if not pconfig or pconfig.auth_type != "api_key":
        raise AuthError(
            f"Provider '{provider_id}' is not an API-key provider.",
            provider=provider_id,
            code="invalid_provider",
        )

    api_key = ""
    key_source = ""
    for env_var in pconfig.api_key_env_vars:
        val = os.getenv(env_var, "").strip()
        if val:
            api_key = val
            key_source = env_var
            break

    env_url = ""
    if pconfig.base_url_env_var:
        env_url = os.getenv(pconfig.base_url_env_var, "").strip()

    if provider_id == "kimi-coding":
        base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
    elif env_url:
        base_url = env_url.rstrip("/")
    else:
        base_url = pconfig.inference_base_url

    return {
        "provider": provider_id,
        "api_key": api_key,
        "base_url": base_url.rstrip("/"),
        "source": key_source or "default",
    }
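A short usage sketch of the resolver, with a hypothetical key value; the returned shape follows directly from the function above:

# Hypothetical key — shows the resolved credential shape.
import os

os.environ["KIMI_API_KEY"] = "sk-kimi-abc123"
creds = resolve_api_key_provider_credentials("kimi-coding")
# -> {"provider": "kimi-coding",
#     "api_key": "sk-kimi-abc123",
#     "base_url": "https://api.kimi.com/coding/v1",   # routed by the sk-kimi- prefix
#     "source": "KIMI_API_KEY"}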


# =============================================================================
# External credential detection
# =============================================================================

@@ -1,10 +1,15 @@
"""Welcome banner, ASCII art, and skills summary for the CLI.
"""Welcome banner, ASCII art, skills summary, and update check for the CLI.

Pure display functions with no HermesCLI state dependency.
"""

import json
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional

from rich.console import Console
from rich.panel import Panel
@@ -13,6 +18,8 @@ from rich.table import Table
from prompt_toolkit import print_formatted_text as _pt_print
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI

logger = logging.getLogger(__name__)


# =========================================================================
# ANSI building blocks for conversation display
@@ -95,15 +102,93 @@ def get_available_skills() -> Dict[str, List[str]]:
    return skills_by_category


# =========================================================================
# Update check
# =========================================================================

# Cache update check results for 6 hours to avoid repeated git fetches
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600


def check_for_updates() -> Optional[int]:
    """Check how many commits behind origin/main the local repo is.

    Does a ``git fetch`` at most once every 6 hours (cached to
    ``~/.hermes/.update_check``). Returns the number of commits behind,
    or ``None`` if the check fails or isn't applicable.
    """
    hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
    repo_dir = hermes_home / "hermes-agent"
    cache_file = hermes_home / ".update_check"

    # Must be a git repo
    if not (repo_dir / ".git").exists():
        return None

    # Read cache
    now = time.time()
    try:
        if cache_file.exists():
            cached = json.loads(cache_file.read_text())
            if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS:
                return cached.get("behind")
    except Exception:
        pass

    # Fetch latest refs (fast — only downloads ref metadata, no files)
    try:
        subprocess.run(
            ["git", "fetch", "origin", "--quiet"],
            capture_output=True, timeout=10,
            cwd=str(repo_dir),
        )
    except Exception:
        pass  # Offline or timeout — use stale refs, that's fine

    # Count commits behind
    try:
        result = subprocess.run(
            ["git", "rev-list", "--count", "HEAD..origin/main"],
            capture_output=True, text=True, timeout=5,
            cwd=str(repo_dir),
        )
        if result.returncode == 0:
            behind = int(result.stdout.strip())
        else:
            behind = None
    except Exception:
        behind = None

    # Write cache
    try:
        cache_file.write_text(json.dumps({"ts": now, "behind": behind}))
    except Exception:
        pass

    return behind
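Callers treat `None` and `0` the same way — nothing to report — so typical use is the guarded pattern the banner code further down also follows:

# Minimal caller sketch — mirrors how the banner consumes the result.
behind = check_for_updates()
if behind:  # None and 0 both mean "nothing to show"
    print(f"{behind} commit(s) behind origin/main — run 'hermes update'")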


# =========================================================================
# Welcome banner
# =========================================================================

def _format_context_length(tokens: int) -> str:
    """Format a token count for display (e.g. 128000 → '128K', 1000000 → '1M')."""
    if tokens >= 1_000_000:
        val = tokens / 1_000_000
        return f"{val:g}M"
    elif tokens >= 1_000:
        val = tokens / 1_000
        return f"{val:g}K"
    return str(tokens)
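Concretely, with the `:g` formatting used above:

# Sample outputs of _format_context_length with :g formatting.
_format_context_length(128000)    # -> "128K"
_format_context_length(200000)    # -> "200K"
_format_context_length(1000000)   # -> "1M"
_format_context_length(1048576)   # -> "1.04858M" (:g keeps six significant digits)
_format_context_length(512)       # -> "512"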


def build_welcome_banner(console: Console, model: str, cwd: str,
                         tools: List[dict] = None,
                         enabled_toolsets: List[str] = None,
                         session_id: str = None,
                         get_toolset_for_tool=None):
                         get_toolset_for_tool=None,
                         context_length: int = None):
    """Build and print a welcome banner with caduceus on left and info on right.

    Args:
@@ -114,6 +199,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
        enabled_toolsets: List of enabled toolset names.
        session_id: Session identifier.
        get_toolset_for_tool: Callable to map tool name -> toolset name.
        context_length: Model's context window size in tokens.
    """
    from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
    if get_toolset_for_tool is None:
@@ -135,7 +221,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
    model_short = model.split("/")[-1] if "/" in model else model
    if len(model_short) > 28:
        model_short = model_short[:25] + "..."
    left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
    ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
    left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
    left_lines.append(f"[dim #B8860B]{cwd}[/]")
    if session_id:
        left_lines.append(f"[dim #8B8682]Session: {session_id}[/]")
@@ -245,6 +332,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
        summary_parts.append("/help for commands")
    right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")

    # Update check — show if behind origin/main
    try:
        behind = check_for_updates()
        if behind and behind > 0:
            commits_word = "commit" if behind == 1 else "commits"
            right_lines.append(
                f"[bold yellow]⚠ {behind} {commits_word} behind[/]"
                f"[dim yellow] — run [bold]hermes update[/bold] to update[/]"
            )
    except Exception:
        pass  # Never break the banner over an update check

    right_content = "\n".join(right_lines)
    layout_table.add_row(left_content, right_content)

hermes_cli/clipboard.py (new file, 352 lines)
@@ -0,0 +1,352 @@
"""Clipboard image extraction for macOS, Linux, and WSL2.

Provides a single function `save_clipboard_image(dest)` that checks the
system clipboard for image data, saves it to *dest* as PNG, and returns
True on success. No external Python dependencies — uses only OS-level
CLI tools that ship with the platform (or are commonly installed).

Platform support:
    macOS — osascript (always available), pngpaste (if installed)
    WSL2  — powershell.exe via .NET System.Windows.Forms.Clipboard
    Linux — wl-paste (Wayland), xclip (X11)
"""

import base64
import logging
import os
import subprocess
import sys
from pathlib import Path

logger = logging.getLogger(__name__)

# Cache WSL detection (checked once per process)
_wsl_detected: bool | None = None


def save_clipboard_image(dest: Path) -> bool:
    """Extract an image from the system clipboard and save it as PNG.

    Returns True if an image was found and saved, False otherwise.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if sys.platform == "darwin":
        return _macos_save(dest)
    return _linux_save(dest)


def has_clipboard_image() -> bool:
    """Quick check: does the clipboard currently contain an image?

    Lighter than save_clipboard_image — doesn't extract or write anything.
    """
    if sys.platform == "darwin":
        return _macos_has_image()
    if _is_wsl():
        return _wsl_has_image()
    if os.environ.get("WAYLAND_DISPLAY"):
        return _wayland_has_image()
    return _xclip_has_image()
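A typical call site — e.g. behind the CLI's `/paste` command — would look like this sketch (the destination path is hypothetical):

# Hypothetical call site for the clipboard helpers above.
from pathlib import Path

dest = Path.home() / ".hermes" / "attachments" / "clipboard.png"
if has_clipboard_image() and save_clipboard_image(dest):
    print(f"Attached clipboard image: {dest}")
else:
    print("No image found on the clipboard.")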


# ── macOS ────────────────────────────────────────────────────────────────

def _macos_save(dest: Path) -> bool:
    """Try pngpaste first (fast, handles more formats), fall back to osascript."""
    return _macos_pngpaste(dest) or _macos_osascript(dest)


def _macos_has_image() -> bool:
    """Check if macOS clipboard contains image data."""
    try:
        info = subprocess.run(
            ["osascript", "-e", "clipboard info"],
            capture_output=True, text=True, timeout=3,
        )
        return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
    except Exception:
        return False


def _macos_pngpaste(dest: Path) -> bool:
    """Use pngpaste (brew install pngpaste) — fastest, cleanest."""
    try:
        r = subprocess.run(
            ["pngpaste", str(dest)],
            capture_output=True, timeout=3,
        )
        if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
            return True
    except FileNotFoundError:
        pass  # pngpaste not installed
    except Exception as e:
        logger.debug("pngpaste failed: %s", e)
    return False


def _macos_osascript(dest: Path) -> bool:
    """Use osascript to extract PNG data from clipboard (always available)."""
    if not _macos_has_image():
        return False

    # Extract as PNG
    script = (
        'try\n'
        '    set imgData to the clipboard as «class PNGf»\n'
        f'    set f to open for access POSIX file "{dest}" with write permission\n'
        '    write imgData to f\n'
        '    close access f\n'
        'on error\n'
        '    return "fail"\n'
        'end try\n'
    )
    try:
        r = subprocess.run(
            ["osascript", "-e", script],
            capture_output=True, text=True, timeout=5,
        )
        if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
            return True
    except Exception as e:
        logger.debug("osascript clipboard extract failed: %s", e)
    return False


# ── Linux ────────────────────────────────────────────────────────────────

def _is_wsl() -> bool:
    """Detect if running inside WSL (1 or 2)."""
    global _wsl_detected
    if _wsl_detected is not None:
        return _wsl_detected
    try:
        with open("/proc/version", "r") as f:
            _wsl_detected = "microsoft" in f.read().lower()
    except Exception:
        _wsl_detected = False
    return _wsl_detected


def _linux_save(dest: Path) -> bool:
    """Try clipboard backends in priority order: WSL → Wayland → X11."""
    if _is_wsl():
        if _wsl_save(dest):
            return True
        # Fall through — WSLg might have wl-paste or xclip working

    if os.environ.get("WAYLAND_DISPLAY"):
        if _wayland_save(dest):
            return True

    return _xclip_save(dest)


# ── WSL2 (powershell.exe) ────────────────────────────────────────────────

# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
_PS_CHECK_IMAGE = (
    "Add-Type -AssemblyName System.Windows.Forms;"
    "[System.Windows.Forms.Clipboard]::ContainsImage()"
)

_PS_EXTRACT_IMAGE = (
    "Add-Type -AssemblyName System.Windows.Forms;"
    "Add-Type -AssemblyName System.Drawing;"
    "$img = [System.Windows.Forms.Clipboard]::GetImage();"
    "if ($null -eq $img) { exit 1 };"
    "$ms = New-Object System.IO.MemoryStream;"
    "$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png);"
    "[System.Convert]::ToBase64String($ms.ToArray())"
)


def _wsl_has_image() -> bool:
    """Check if Windows clipboard has an image (via powershell.exe)."""
    try:
        r = subprocess.run(
            ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
             _PS_CHECK_IMAGE],
            capture_output=True, text=True, timeout=8,
        )
        return r.returncode == 0 and "True" in r.stdout
    except FileNotFoundError:
        logger.debug("powershell.exe not found — WSL clipboard unavailable")
    except Exception as e:
        logger.debug("WSL clipboard check failed: %s", e)
    return False


def _wsl_save(dest: Path) -> bool:
    """Extract clipboard image via powershell.exe → base64 → decode to PNG."""
    try:
        r = subprocess.run(
            ["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
             _PS_EXTRACT_IMAGE],
            capture_output=True, text=True, timeout=15,
        )
        if r.returncode != 0:
            return False

        b64_data = r.stdout.strip()
        if not b64_data:
            return False

        png_bytes = base64.b64decode(b64_data)
        dest.write_bytes(png_bytes)
        return dest.exists() and dest.stat().st_size > 0

    except FileNotFoundError:
        logger.debug("powershell.exe not found — WSL clipboard unavailable")
    except Exception as e:
        logger.debug("WSL clipboard extraction failed: %s", e)
        dest.unlink(missing_ok=True)
    return False


# ── Wayland (wl-paste) ──────────────────────────────────────────────────

def _wayland_has_image() -> bool:
    """Check if Wayland clipboard has image content."""
    try:
        r = subprocess.run(
            ["wl-paste", "--list-types"],
            capture_output=True, text=True, timeout=3,
        )
        return r.returncode == 0 and any(
            t.startswith("image/") for t in r.stdout.splitlines()
        )
    except FileNotFoundError:
        logger.debug("wl-paste not installed — Wayland clipboard unavailable")
    except Exception:
        pass
    return False


def _wayland_save(dest: Path) -> bool:
    """Use wl-paste to extract clipboard image (Wayland sessions)."""
    try:
        # Check available MIME types
        types_r = subprocess.run(
            ["wl-paste", "--list-types"],
            capture_output=True, text=True, timeout=3,
        )
        if types_r.returncode != 0:
            return False
        types = types_r.stdout.splitlines()

        # Prefer PNG, fall back to other image formats
        mime = None
        for preferred in ("image/png", "image/jpeg", "image/bmp",
                          "image/gif", "image/webp"):
            if preferred in types:
                mime = preferred
                break

        if not mime:
            return False

        # Extract the image data
        with open(dest, "wb") as f:
            subprocess.run(
                ["wl-paste", "--type", mime],
                stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
            )

        if not dest.exists() or dest.stat().st_size == 0:
            return False

        # BMP needs conversion to PNG (common in WSLg where only BMP
        # is bridged from Windows clipboard via RDP).
        if mime == "image/bmp":
            return _convert_to_png(dest)

        return True

    except FileNotFoundError:
        logger.debug("wl-paste not installed — Wayland clipboard unavailable")
    except Exception as e:
        logger.debug("wl-paste clipboard extraction failed: %s", e)
        dest.unlink(missing_ok=True)
    return False


def _convert_to_png(path: Path) -> bool:
    """Convert an image file to PNG in-place (requires Pillow or ImageMagick)."""
    # Try Pillow first (likely installed in the venv)
    try:
        from PIL import Image
        img = Image.open(path)
        img.save(path, "PNG")
        return True
    except ImportError:
        pass
    except Exception as e:
        logger.debug("Pillow BMP→PNG conversion failed: %s", e)

    # Fall back to ImageMagick convert
    try:
        tmp = path.with_suffix(".bmp")
        path.rename(tmp)
        r = subprocess.run(
            ["convert", str(tmp), "png:" + str(path)],
            capture_output=True, timeout=5,
        )
        tmp.unlink(missing_ok=True)
        if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
            return True
    except FileNotFoundError:
        logger.debug("ImageMagick not installed — cannot convert BMP to PNG")
    except Exception as e:
        logger.debug("ImageMagick BMP→PNG conversion failed: %s", e)

    # Can't convert — BMP is still usable as-is for most APIs
    return path.exists() and path.stat().st_size > 0


# ── X11 (xclip) ─────────────────────────────────────────────────────────

def _xclip_has_image() -> bool:
    """Check if X11 clipboard has image content."""
    try:
        r = subprocess.run(
            ["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
            capture_output=True, text=True, timeout=3,
        )
        return r.returncode == 0 and "image/png" in r.stdout
    except FileNotFoundError:
        pass
    except Exception:
        pass
    return False


def _xclip_save(dest: Path) -> bool:
    """Use xclip to extract clipboard image (X11 sessions)."""
    # Check if clipboard has image content
    try:
        targets = subprocess.run(
            ["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
            capture_output=True, text=True, timeout=3,
        )
        if "image/png" not in targets.stdout:
            return False
    except FileNotFoundError:
        logger.debug("xclip not installed — X11 clipboard image paste unavailable")
        return False
    except Exception:
        return False

    # Extract PNG data
    try:
        with open(dest, "wb") as f:
            subprocess.run(
                ["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
                stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
            )
        if dest.exists() and dest.stat().st_size > 0:
            return True
    except Exception as e:
        logger.debug("xclip image extraction failed: %s", e)
        dest.unlink(missing_ok=True)
    return False
@@ -1,9 +1,15 @@
"""Slash command definitions and autocomplete for the Hermes CLI.

Contains the COMMANDS dict and the SlashCommandCompleter class.
These are pure data/UI with no HermesCLI state dependency.
Contains the shared built-in ``COMMANDS`` dict and ``SlashCommandCompleter``.
The completer can optionally include dynamic skill slash commands supplied by the
interactive CLI.
"""

from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any

from prompt_toolkit.completion import Completer, Completion


@@ -12,6 +18,7 @@ COMMANDS = {
    "/tools": "List available tools",
    "/toolsets": "List available toolsets",
    "/model": "Show or change the current model",
    "/provider": "Show available providers and current provider",
    "/prompt": "View/set custom system prompt",
    "/personality": "Set a predefined personality",
    "/clear": "Clear screen and reset conversation (fresh start)",
@@ -27,25 +34,68 @@ COMMANDS = {
    "/platforms": "Show gateway/messaging platform status",
    "/verbose": "Cycle tool progress display: off → new → all → verbose",
    "/compress": "Manually compress conversation context (flush memories + summarize)",
    "/title": "Set a title for the current session (usage: /title My Session Name)",
    "/usage": "Show token usage for the current session",
    "/insights": "Show usage insights and analytics (last 30 days)",
    "/paste": "Check clipboard for an image and attach it",
    "/reload-mcp": "Reload MCP servers from config.yaml",
    "/quit": "Exit the CLI (also: /exit, /q)",
}


class SlashCommandCompleter(Completer):
    """Autocomplete for /commands in the input area."""
    """Autocomplete for built-in slash commands and optional skill commands."""

    def __init__(
        self,
        skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
    ) -> None:
        self._skill_commands_provider = skill_commands_provider

    def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
        if self._skill_commands_provider is None:
            return {}
        try:
            return self._skill_commands_provider() or {}
        except Exception:
            return {}

    @staticmethod
    def _completion_text(cmd_name: str, word: str) -> str:
        """Return replacement text for a completion.

        When the user has already typed the full command exactly (``/help``),
        returning ``help`` would be a no-op and prompt_toolkit suppresses the
        menu. Appending a trailing space keeps the dropdown visible and makes
        backspacing retrigger it naturally.
        """
        return f"{cmd_name} " if cmd_name == word else cmd_name

    def get_completions(self, document, complete_event):
        text = document.text_before_cursor
        if not text.startswith("/"):
            return

        word = text[1:]

        for cmd, desc in COMMANDS.items():
            cmd_name = cmd[1:]
            if cmd_name.startswith(word):
                yield Completion(
                    cmd_name,
                    self._completion_text(cmd_name, word),
                    start_position=-len(word),
                    display=cmd,
                    display_meta=desc,
                )

        for cmd, info in self._iter_skill_commands().items():
            cmd_name = cmd[1:]
            if cmd_name.startswith(word):
                description = str(info.get("description", "Skill command"))
                short_desc = description[:50] + ("..." if len(description) > 50 else "")
                yield Completion(
                    self._completion_text(cmd_name, word),
                    start_position=-len(word),
                    display=cmd,
                    display_meta=f"⚡ {short_desc}",
                )
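Wiring the completer into a prompt_toolkit session looks roughly like the sketch below; the skill-command provider is any zero-argument callable returning a mapping, as the constructor above expects, and the example command is hypothetical:

# Sketch — hooking SlashCommandCompleter into a prompt_toolkit session.
from prompt_toolkit import PromptSession

def my_skill_commands():
    # Hypothetical dynamic commands; shape matches _iter_skill_commands().
    return {"/summarize": {"description": "Summarize the current conversation"}}

session = PromptSession(completer=SlashCommandCompleter(my_skill_commands))
text = session.prompt("> ")  # typing "/" now offers built-in and skill commands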

@@ -71,7 +71,8 @@ DEFAULT_CONFIG = {
    "docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
    "singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
    "modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
    # Container resource limits (docker, singularity, modal — ignored for local/ssh)
    "daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
    # Container resource limits (docker, singularity, modal, daytona — ignored for local/ssh)
    "container_cpu": 1,
    "container_memory": 5120,  # MB (default 5GB)
    "container_disk": 51200,  # MB (default 50GB)
@@ -140,9 +141,13 @@ DEFAULT_CONFIG = {
    # (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
    "honcho": {},

    # IANA timezone (e.g. "Asia/Kolkata", "America/New_York").
    # Empty string means use server-local time.
    "timezone": "",

    # Permanently allowed dangerous command patterns (added via "always" approval)
    "command_allowlist": [],


    # Config schema version - bump this when adding new required fields
    "_config_version": 5,
}

@@ -151,6 +156,15 @@ DEFAULT_CONFIG = {
# Config Migration System
# =============================================================================

# Track which env vars were introduced in each config version.
# Migration only mentions vars new since the user's previous version.
ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
    3: ["FIRECRAWL_API_KEY", "BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "FAL_KEY"],
    4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
    5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
        "SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
}
|
||||
|
||||
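As a quick illustration of how this table is consumed (the loop mirrors the one in migrate_config further down; the version numbers here are just an example):

current_ver, latest_ver = 3, 5
new_var_names = set()
for ver in range(current_ver + 1, latest_ver + 1):
    new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))
# → only vars introduced in versions 4 and 5 are offered, e.g.
#   ELEVENLABS_API_KEY and the WHATSAPP_*/SLACK_* settings;
#   version-3 vars are assumed already handled.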
# Required environment variables with metadata for migration prompts.
# LLM provider is required but handled in the setup wizard's provider
# selection step (Nous Portal / OpenRouter / Custom endpoint), so this
@@ -169,6 +183,86 @@ OPTIONAL_ENV_VARS = {
        "category": "provider",
        "advanced": True,
    },
    "GLM_API_KEY": {
        "description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
        "prompt": "Z.AI / GLM API key",
        "url": "https://z.ai/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "ZAI_API_KEY": {
        "description": "Z.AI API key (alias for GLM_API_KEY)",
        "prompt": "Z.AI API key",
        "url": "https://z.ai/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "Z_AI_API_KEY": {
        "description": "Z.AI API key (alias for GLM_API_KEY)",
        "prompt": "Z.AI API key",
        "url": "https://z.ai/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "GLM_BASE_URL": {
        "description": "Z.AI / GLM base URL override",
        "prompt": "Z.AI / GLM base URL (leave empty for default)",
        "url": None,
        "password": False,
        "category": "provider",
        "advanced": True,
    },
    "KIMI_API_KEY": {
        "description": "Kimi / Moonshot API key",
        "prompt": "Kimi API key",
        "url": "https://platform.moonshot.cn/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "KIMI_BASE_URL": {
        "description": "Kimi / Moonshot base URL override",
        "prompt": "Kimi base URL (leave empty for default)",
        "url": None,
        "password": False,
        "category": "provider",
        "advanced": True,
    },
    "MINIMAX_API_KEY": {
        "description": "MiniMax API key (international)",
        "prompt": "MiniMax API key",
        "url": "https://www.minimax.io/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "MINIMAX_BASE_URL": {
        "description": "MiniMax base URL override",
        "prompt": "MiniMax base URL (leave empty for default)",
        "url": None,
        "password": False,
        "category": "provider",
        "advanced": True,
    },
    "MINIMAX_CN_API_KEY": {
        "description": "MiniMax API key (China endpoint)",
        "prompt": "MiniMax (China) API key",
        "url": "https://www.minimaxi.com/",
        "password": True,
        "category": "provider",
        "advanced": True,
    },
    "MINIMAX_CN_BASE_URL": {
        "description": "MiniMax (China) base URL override",
        "prompt": "MiniMax (China) base URL (leave empty for default)",
        "url": None,
        "password": False,
        "category": "provider",
        "advanced": True,
    },

    # ── Tool API keys ──
    "FIRECRAWL_API_KEY": {
@@ -179,8 +273,16 @@ OPTIONAL_ENV_VARS = {
        "password": True,
        "category": "tool",
    },
    "FIRECRAWL_API_URL": {
        "description": "Firecrawl API URL for self-hosted instances (optional)",
        "prompt": "Firecrawl API URL (leave empty for cloud)",
        "url": None,
        "password": False,
        "category": "tool",
        "advanced": True,
    },
    "BROWSERBASE_API_KEY": {
        "description": "Browserbase API key for browser automation",
        "description": "Browserbase API key for cloud browser (optional — local browser works without this)",
        "prompt": "Browserbase API key",
        "url": "https://browserbase.com/",
        "tools": ["browser_navigate", "browser_click"],
@@ -188,7 +290,7 @@ OPTIONAL_ENV_VARS = {
        "category": "tool",
    },
    "BROWSERBASE_PROJECT_ID": {
        "description": "Browserbase project ID",
        "description": "Browserbase project ID (optional — only needed for cloud browser)",
        "prompt": "Browserbase project ID",
        "url": "https://browserbase.com/",
        "tools": ["browser_navigate", "browser_click"],
@@ -476,6 +578,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
        if not quiet:
            print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}")

    # ── Version 4 → 5: add timezone field ──
    if current_ver < 5:
        config = load_config()
        if "timezone" not in config:
            old_tz = os.getenv("HERMES_TIMEZONE", "")
            if old_tz and old_tz.strip():
                config["timezone"] = old_tz.strip()
                results["config_added"].append(f"timezone={old_tz.strip()} (from HERMES_TIMEZONE)")
            else:
                config["timezone"] = ""
                results["config_added"].append("timezone= (empty, uses server-local)")
            save_config(config)
            if not quiet:
                tz_display = config["timezone"] or "(server-local)"
                print(f" ✓ Added timezone to config.yaml: {tz_display}")

    if current_ver < latest_ver and not quiet:
        print(f"Config version: {current_ver} → {latest_ver}")
@@ -516,34 +634,47 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
        if v["name"] not in required_names and not v.get("advanced")
    ]

    if interactive and missing_optional:
        print(" Would you like to configure any optional keys now?")
        try:
            answer = input(" Configure optional keys? [y/N]: ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            answer = "n"

        if answer in ("y", "yes"):
    # Only offer to configure env vars that are NEW since the user's previous version
    new_var_names = set()
    for ver in range(current_ver + 1, latest_ver + 1):
        new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))

    if new_var_names and interactive and not quiet:
        new_and_unset = [
            (name, OPTIONAL_ENV_VARS[name])
            for name in sorted(new_var_names)
            if not get_env_value(name) and name in OPTIONAL_ENV_VARS
        ]
        if new_and_unset:
            print(f"\n {len(new_and_unset)} new optional key(s) in this update:")
            for name, info in new_and_unset:
                print(f" • {name} — {info.get('description', '')}")
            print()
            for var in missing_optional:
                desc = var.get("description", "")
                if var.get("url"):
                    print(f" {desc}")
                    print(f" Get your key at: {var['url']}")
                else:
                    print(f" {desc}")

                if var.get("password"):
                    import getpass
                    value = getpass.getpass(f" {var['prompt']} (Enter to skip): ")
                else:
                    value = input(f" {var['prompt']} (Enter to skip): ").strip()

                if value:
                    save_env_value(var["name"], value)
                    results["env_added"].append(var["name"])
                    print(f" ✓ Saved {var['name']}")
            try:
                answer = input(" Configure new keys? [y/N]: ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                answer = "n"

            if answer in ("y", "yes"):
                print()
                for name, info in new_and_unset:
                    if info.get("url"):
                        print(f" {info.get('description', name)}")
                        print(f" Get your key at: {info['url']}")
                    else:
                        print(f" {info.get('description', name)}")
                    if info.get("password"):
                        import getpass
                        value = getpass.getpass(f" {info.get('prompt', name)} (Enter to skip): ")
                    else:
                        value = input(f" {info.get('prompt', name)} (Enter to skip): ").strip()
                    if value:
                        save_env_value(name, value)
                        results["env_added"].append(name)
                        print(f" ✓ Saved {name}")
                print()
            else:
                print(" Set later with: hermes config set KEY VALUE")

    # Check for missing config fields
    missing_config = get_missing_config_fields()
@@ -753,12 +884,25 @@ def show_config():
        print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
        modal_token = get_env_value('MODAL_TOKEN_ID')
        print(f" Modal token: {'configured' if modal_token else '(not set)'}")
    elif terminal.get('backend') == 'daytona':
        print(f" Daytona image: {terminal.get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
        daytona_key = get_env_value('DAYTONA_API_KEY')
        print(f" API key: {'configured' if daytona_key else '(not set)'}")
    elif terminal.get('backend') == 'ssh':
        ssh_host = get_env_value('TERMINAL_SSH_HOST')
        ssh_user = get_env_value('TERMINAL_SSH_USER')
        print(f" SSH host: {ssh_host or '(not set)'}")
        print(f" SSH user: {ssh_user or '(not set)'}")

    # Timezone
    print()
    print(color("◆ Timezone", Colors.CYAN, Colors.BOLD))
    tz = config.get('timezone', '')
    if tz:
        print(f" Timezone: {tz}")
    else:
        print(f" Timezone: {color('(server-local)', Colors.DIM)}")

    # Compression
    print()
    print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
@@ -820,15 +964,16 @@ def set_config_value(key: str, value: str):
    """Set a configuration value."""
    # Check if it's an API key (goes to .env)
    api_keys = [
        'OPENROUTER_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
        'FIRECRAWL_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
        'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
        'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
        'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
        'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
        'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
        'GITHUB_TOKEN', 'HONCHO_API_KEY',
        'GITHUB_TOKEN', 'HONCHO_API_KEY', 'NOUS_API_KEY', 'WANDB_API_KEY',
        'TINKER_API_KEY',
    ]

    if key.upper() in api_keys or key.upper().startswith('TERMINAL_SSH'):
    if key.upper() in api_keys or key.upper().endswith('_API_KEY') or key.upper().endswith('_TOKEN') or key.upper().startswith('TERMINAL_SSH'):
        save_env_value(key.upper(), value)
        print(f"✓ Set {key} in {get_env_path()}")
        return
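To make the new routing rule concrete — a small standalone illustration (the key names here are hypothetical; only the suffix/prefix checks come from the code above):

for key in ("MY_VENDOR_API_KEY", "CI_DEPLOY_TOKEN", "TERMINAL_SSH_PORT", "model"):
    to_env = (key.upper().endswith('_API_KEY')
              or key.upper().endswith('_TOKEN')
              or key.upper().startswith('TERMINAL_SSH'))
    print(key, "→ .env" if to_env else "→ config.yaml")
# MY_VENDOR_API_KEY → .env, CI_DEPLOY_TOKEN → .env,
# TERMINAL_SSH_PORT → .env, model → config.yaml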
@@ -878,8 +1023,10 @@ def set_config_value(key: str, value: str):
        "terminal.docker_image": "TERMINAL_DOCKER_IMAGE",
        "terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE",
        "terminal.modal_image": "TERMINAL_MODAL_IMAGE",
        "terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
        "terminal.cwd": "TERMINAL_CWD",
        "terminal.timeout": "TERMINAL_TIMEOUT",
        "terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
    }
    if key in _config_to_env_sync:
        save_env_value(_config_to_env_sync[key], str(value))
@@ -33,6 +33,26 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
from hermes_cli.colors import Colors, color
from hermes_constants import OPENROUTER_MODELS_URL


_PROVIDER_ENV_HINTS = (
    "OPENROUTER_API_KEY",
    "OPENAI_API_KEY",
    "ANTHROPIC_API_KEY",
    "OPENAI_BASE_URL",
    "GLM_API_KEY",
    "ZAI_API_KEY",
    "Z_AI_API_KEY",
    "KIMI_API_KEY",
    "MINIMAX_API_KEY",
    "MINIMAX_CN_API_KEY",
)


def _has_provider_env_config(content: str) -> bool:
    """Return True when ~/.hermes/.env contains provider auth/base URL settings."""
    return any(key in content for key in _PROVIDER_ENV_HINTS)


def check_ok(text: str, detail: str = ""):
    print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
@@ -132,8 +152,8 @@ def run_doctor(args):

    # Check for common issues
    content = env_path.read_text()
    if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
        check_ok("API key configured")
    if _has_provider_env_config(content):
        check_ok("API key or custom endpoint configured")
    else:
        check_warn("No API key found in ~/.hermes/.env")
        issues.append("Run 'hermes setup' to configure API keys")
@@ -355,6 +375,21 @@ def run_doctor(args):
        check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
        issues.append("Set TERMINAL_SSH_HOST in .env")

    # Daytona (if using daytona backend)
    if terminal_env == "daytona":
        daytona_key = os.getenv("DAYTONA_API_KEY")
        if daytona_key:
            check_ok("Daytona API key", "(configured)")
        else:
            check_fail("DAYTONA_API_KEY not set", "(required for TERMINAL_ENV=daytona)")
            issues.append("Set DAYTONA_API_KEY environment variable")
        try:
            from daytona import Daytona
            check_ok("daytona SDK", "(installed)")
        except ImportError:
            check_fail("daytona SDK not installed", "(pip install daytona)")
            issues.append("Install daytona SDK: pip install daytona")

    # Node.js + agent-browser (for browser automation tools)
    if shutil.which("node"):
        check_ok("Node.js")
@@ -453,7 +488,48 @@ def run_doctor(args):
            print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
        except Exception as e:
            print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(f'({e})', Colors.DIM)} ")

    # -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
    _apikey_providers = [
        ("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
        ("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
        ("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
        ("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
    ]
    for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
        _key = ""
        for _ev in _env_vars:
            _key = os.getenv(_ev, "")
            if _key:
                break
        if _key:
            _label = _pname.ljust(20)
            print(f" Checking {_pname} API...", end="", flush=True)
            try:
                import httpx
                _base = os.getenv(_base_env, "")
                # Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
                if not _base and _key.startswith("sk-kimi-"):
                    _base = "https://api.kimi.com/coding/v1"
                _url = (_base.rstrip("/") + "/models") if _base else _default_url
                _headers = {"Authorization": f"Bearer {_key}"}
                if "api.kimi.com" in _url.lower():
                    _headers["User-Agent"] = "KimiCLI/1.0"
                _resp = httpx.get(
                    _url,
                    headers=_headers,
                    timeout=10,
                )
                if _resp.status_code == 200:
                    print(f"\r {color('✓', Colors.GREEN)} {_label} ")
                elif _resp.status_code == 401:
                    print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
                    issues.append(f"Check {_env_vars[0]} in .env")
                else:
                    print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
            except Exception as _e:
                print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")

    # =========================================================================
    # Check: Submodules
    # =========================================================================
@@ -154,19 +154,33 @@ def get_hermes_cli_path() -> str:
# =============================================================================

def generate_systemd_unit() -> str:
    import shutil
    python_path = get_python_path()
    working_dir = str(PROJECT_ROOT)
    venv_dir = str(PROJECT_ROOT / "venv")
    venv_bin = str(PROJECT_ROOT / "venv" / "bin")
    node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")

    # Build a PATH that includes the venv, node_modules, and standard system dirs
    sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

    hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
    return f"""[Unit]
Description={SERVICE_DESCRIPTION}
After=network.target

[Service]
Type=simple
ExecStart={python_path} -m hermes_cli.main gateway run
ExecStart={python_path} -m hermes_cli.main gateway run --replace
ExecStop={hermes_cli} gateway stop
WorkingDirectory={working_dir}
Environment="PATH={sane_path}"
Environment="VIRTUAL_ENV={venv_dir}"
Restart=on-failure
RestartSec=10
KillMode=mixed
KillSignal=SIGTERM
TimeoutStopSec=15
StandardOutput=journal
StandardError=journal
@@ -377,8 +391,15 @@ def launchd_status(deep: bool = False):
# Gateway Runner
# =============================================================================

def run_gateway(verbose: bool = False):
    """Run the gateway in foreground."""
def run_gateway(verbose: bool = False, replace: bool = False):
    """Run the gateway in foreground.

    Args:
        verbose: Enable verbose logging output.
        replace: If True, kill any existing gateway instance before starting.
            This prevents systemd restart loops when the old process
            hasn't fully exited yet.
    """
    sys.path.insert(0, str(PROJECT_ROOT))

    from gateway.run import start_gateway
@@ -393,7 +414,7 @@ def run_gateway(verbose: bool = False):

    # Exit with code 1 if gateway fails to connect any platform,
    # so systemd Restart=on-failure will retry on transient errors
    success = asyncio.run(start_gateway())
    success = asyncio.run(start_gateway(replace=replace))
    if not success:
        sys.exit(1)
@@ -765,7 +786,8 @@ def gateway_command(args):
    # Default to run if no subcommand
    if subcmd is None or subcmd == "run":
        verbose = getattr(args, 'verbose', False)
        run_gateway(verbose)
        replace = getattr(args, 'replace', False)
        run_gateway(verbose, replace=replace)
        return

    if subcmd == "setup":
|
||||
# Check env vars (may be set by .env or shell).
|
||||
# OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.)
|
||||
# often don't require an API key.
|
||||
provider_env_vars = ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL")
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
|
||||
# Collect all provider env vars
|
||||
provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"}
|
||||
for pconfig in PROVIDER_REGISTRY.values():
|
||||
if pconfig.auth_type == "api_key":
|
||||
provider_env_vars.update(pconfig.api_key_env_vars)
|
||||
if any(os.getenv(v) for v in provider_env_vars):
|
||||
return True
|
||||
|
||||
@@ -114,16 +120,63 @@ def _resolve_last_cli_session() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
|
||||
"""Resolve a session name (title) or ID to a session ID.
|
||||
|
||||
- If it looks like a session ID (contains underscore + hex), try direct lookup first.
|
||||
- Otherwise, treat it as a title and use resolve_session_by_title (auto-latest).
|
||||
- Falls back to the other method if the first doesn't match.
|
||||
"""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB()
|
||||
|
||||
# Try as exact session ID first
|
||||
session = db.get_session(name_or_id)
|
||||
if session:
|
||||
db.close()
|
||||
return session["id"]
|
||||
|
||||
# Try as title (with auto-latest for lineage)
|
||||
session_id = db.resolve_session_by_title(name_or_id)
|
||||
db.close()
|
||||
return session_id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def cmd_chat(args):
|
||||
"""Run interactive chat CLI."""
|
||||
# Resolve --continue into --resume with the latest CLI session
|
||||
if getattr(args, "continue_last", False) and not getattr(args, "resume", None):
|
||||
last_id = _resolve_last_cli_session()
|
||||
if last_id:
|
||||
args.resume = last_id
|
||||
# Resolve --continue into --resume with the latest CLI session or by name
|
||||
continue_val = getattr(args, "continue_last", None)
|
||||
if continue_val and not getattr(args, "resume", None):
|
||||
if isinstance(continue_val, str):
|
||||
# -c "session name" — resolve by title or ID
|
||||
resolved = _resolve_session_by_name_or_id(continue_val)
|
||||
if resolved:
|
||||
args.resume = resolved
|
||||
else:
|
||||
print(f"No session found matching '{continue_val}'.")
|
||||
print("Use 'hermes sessions list' to see available sessions.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("No previous CLI session found to continue.")
|
||||
sys.exit(1)
|
||||
# -c with no argument — continue the most recent session
|
||||
last_id = _resolve_last_cli_session()
|
||||
if last_id:
|
||||
args.resume = last_id
|
||||
else:
|
||||
print("No previous CLI session found to continue.")
|
||||
sys.exit(1)
|
||||
|
||||
# Resolve --resume by title if it's not a direct session ID
|
||||
resume_val = getattr(args, "resume", None)
|
||||
if resume_val:
|
||||
resolved = _resolve_session_by_name_or_id(resume_val)
|
||||
if resolved:
|
||||
args.resume = resolved
|
||||
# If resolution fails, keep the original value — _init_agent will
|
||||
# report "Session not found" with the original input
|
||||
|
||||
# First-run guard: check if any provider is configured before launching
|
||||
if not _has_any_provider_configured():
|
||||
@@ -143,6 +196,13 @@ def cmd_chat(args):
|
||||
print("You can run 'hermes setup' at any time to configure.")
|
||||
sys.exit(1)
|
||||
|
||||
# Sync bundled skills on every CLI launch (fast -- skips unchanged skills)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
sync_skills(quiet=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Import and run the CLI
|
||||
from cli import main as cli_main
|
||||
|
||||
@@ -154,6 +214,7 @@ def cmd_chat(args):
|
||||
"verbose": args.verbose,
|
||||
"query": args.query,
|
||||
"resume": getattr(args, "resume", None),
|
||||
"worktree": getattr(args, "worktree", False),
|
||||
}
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
@@ -404,6 +465,10 @@ def cmd_model(args):
        "openrouter": "OpenRouter",
        "nous": "Nous Portal",
        "openai-codex": "OpenAI Codex",
        "zai": "Z.AI / GLM",
        "kimi-coding": "Kimi / Moonshot",
        "minimax": "MiniMax",
        "minimax-cn": "MiniMax (China)",
        "custom": "Custom endpoint",
    }
    active_label = provider_labels.get(active, active)
@@ -418,11 +483,16 @@ def cmd_model(args):
        ("openrouter", "OpenRouter (100+ models, pay-per-use)"),
        ("nous", "Nous Portal (Nous Research subscription)"),
        ("openai-codex", "OpenAI Codex"),
        ("zai", "Z.AI / GLM (Zhipu AI direct API)"),
        ("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
        ("minimax", "MiniMax (global direct API)"),
        ("minimax-cn", "MiniMax China (domestic direct API)"),
        ("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
    ]

    # Reorder so the active provider is at the top
    active_key = active if active in ("openrouter", "nous", "openai-codex") else "custom"
    known_keys = {k for k, _ in providers}
    active_key = active if active in known_keys else "custom"
    ordered = []
    for key, label in providers:
        if key == active_key:
@@ -447,6 +517,8 @@ def cmd_model(args):
        _model_flow_openai_codex(config, current_model)
    elif selected_provider == "custom":
        _model_flow_custom(config)
    elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
        _model_flow_api_key_provider(config, selected_provider, current_model)


def _prompt_provider_choice(choices):
@@ -716,6 +788,117 @@ def _model_flow_custom(config):
    print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")

# Curated model lists for direct API-key providers
_PROVIDER_MODELS = {
    "zai": [
        "glm-5",
        "glm-4.7",
        "glm-4.5",
        "glm-4.5-flash",
    ],
    "kimi-coding": [
        "kimi-k2.5",
        "kimi-k2-thinking",
        "kimi-k2-turbo-preview",
        "kimi-k2-0905-preview",
    ],
    "minimax": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
    "minimax-cn": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
}


def _model_flow_api_key_provider(config, provider_id, current_model=""):
    """Generic flow for API-key providers (z.ai, Kimi, MiniMax)."""
    from hermes_cli.auth import (
        PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
        _update_config_for_provider, deactivate_provider,
    )
    from hermes_cli.config import get_env_value, save_env_value, load_config, save_config

    pconfig = PROVIDER_REGISTRY[provider_id]
    key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else ""
    base_url_env = pconfig.base_url_env_var or ""

    # Check / prompt for API key
    existing_key = ""
    for ev in pconfig.api_key_env_vars:
        existing_key = get_env_value(ev) or os.getenv(ev, "")
        if existing_key:
            break

    if not existing_key:
        print(f"No {pconfig.name} API key configured.")
        if key_env:
            try:
                new_key = input(f"{key_env} (or Enter to cancel): ").strip()
            except (KeyboardInterrupt, EOFError):
                print()
                return
            if not new_key:
                print("Cancelled.")
                return
            save_env_value(key_env, new_key)
            print("API key saved.")
            print()
    else:
        print(f" {pconfig.name} API key: {existing_key[:8]}... ✓")
        print()

    # Optional base URL override
    current_base = ""
    if base_url_env:
        current_base = get_env_value(base_url_env) or os.getenv(base_url_env, "")
    effective_base = current_base or pconfig.inference_base_url

    try:
        override = input(f"Base URL [{effective_base}]: ").strip()
    except (KeyboardInterrupt, EOFError):
        print()
        override = ""
    if override and base_url_env:
        save_env_value(base_url_env, override)
        effective_base = override

    # Model selection
    model_list = _PROVIDER_MODELS.get(provider_id, [])
    if model_list:
        selected = _prompt_model_selection(model_list, current_model=current_model)
    else:
        try:
            selected = input("Model name: ").strip()
        except (KeyboardInterrupt, EOFError):
            selected = None

    if selected:
        # Clear custom endpoint if set (avoid confusion)
        if get_env_value("OPENAI_BASE_URL"):
            save_env_value("OPENAI_BASE_URL", "")
            save_env_value("OPENAI_API_KEY", "")

        _save_model_choice(selected)

        # Update config with provider and base URL
        cfg = load_config()
        model = cfg.get("model")
        if isinstance(model, dict):
            model["provider"] = provider_id
            model["base_url"] = effective_base
        save_config(cfg)
        deactivate_provider()

        print(f"Default model set to: {selected} (via {pconfig.name})")
    else:
        print("No change.")

def cmd_login(args):
    """Authenticate Hermes CLI with a provider."""
    from hermes_cli.auth import login_command
@@ -851,11 +1034,17 @@ def _update_via_zip(args):
    # Sync skills
    try:
        from tools.skills_sync import sync_skills
        print("→ Checking for new bundled skills...")
        print("→ Syncing bundled skills...")
        result = sync_skills(quiet=True)
        if result["copied"]:
            print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
        else:
            print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
        if result.get("updated"):
            print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
        if result.get("user_modified"):
            print(f" ~ {len(result['user_modified'])} user-modified (kept)")
        if result.get("cleaned"):
            print(f" − {len(result['cleaned'])} removed from manifest")
        if not result["copied"] and not result.get("updated"):
            print(" ✓ Skills are up to date")
    except Exception:
        pass
@@ -961,15 +1150,21 @@ def cmd_update(args):
    print()
    print("✓ Code updated!")

    # Sync any new bundled skills (manifest-based -- won't overwrite or re-add deleted skills)
    # Sync bundled skills (copies new, updates changed, respects user deletions)
    try:
        from tools.skills_sync import sync_skills
        print()
        print("→ Checking for new bundled skills...")
        print("→ Syncing bundled skills...")
        result = sync_skills(quiet=True)
        if result["copied"]:
            print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
        else:
            print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
        if result.get("updated"):
            print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
        if result.get("user_modified"):
            print(f" ~ {len(result['user_modified'])} user-modified (kept)")
        if result.get("cleaned"):
            print(f" − {len(result['cleaned'])} removed from manifest")
        if not result["copied"] and not result.get("updated"):
            print(" ✓ Skills are up to date")
    except Exception as e:
        logger.debug("Skills sync during update failed: %s", e)
@@ -1061,8 +1256,9 @@ def main():
Examples:
  hermes                          Start interactive chat
  hermes chat -q "Hello"          Single query mode
  hermes --continue               Resume the most recent session
  hermes --resume <session_id>    Resume a specific session
  hermes -c                       Resume the most recent session
  hermes -c "my project"          Resume a session by name (latest in lineage)
  hermes --resume <session_id>    Resume a specific session by ID
  hermes setup                    Run setup wizard
  hermes logout                   Clear stored authentication
  hermes model                    Select default model
@@ -1070,8 +1266,10 @@ Examples:
  hermes config edit              Edit config in $EDITOR
  hermes config set model gpt-4   Set a config value
  hermes gateway                  Run messaging gateway
  hermes -w                       Start in isolated git worktree
  hermes gateway install          Install as system service
  hermes sessions list            List past sessions
  hermes sessions rename ID T     Rename/title a session
  hermes update                   Update to latest version

For more help on a command:
@@ -1086,16 +1284,24 @@ For more help on a command:
    )
    parser.add_argument(
        "--resume", "-r",
        metavar="SESSION_ID",
        metavar="SESSION",
        default=None,
        help="Resume a previous session by ID (shortcut for: hermes chat --resume ID)"
        help="Resume a previous session by ID or title"
    )
    parser.add_argument(
        "--continue", "-c",
        dest="continue_last",
        nargs="?",
        const=True,
        default=None,
        metavar="SESSION_NAME",
        help="Resume a session by name, or the most recent if no name given"
    )
    parser.add_argument(
        "--worktree", "-w",
        action="store_true",
        default=False,
        help="Resume the most recent CLI session"
        help="Run in an isolated git worktree (for parallel agents)"
    )

    subparsers = parser.add_subparsers(dest="command", help="Command to run")
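A quick illustration of the nargs="?" / const=True pattern used for --continue above (a standalone sketch, not part of the diff):

import argparse

p = argparse.ArgumentParser()
p.add_argument("--continue", "-c", dest="continue_last",
               nargs="?", const=True, default=None)

print(p.parse_args([]).continue_last)                    # None  — flag absent
print(p.parse_args(["-c"]).continue_last)                # True  — resume most recent
print(p.parse_args(["-c", "my project"]).continue_last)  # 'my project' — resume by name

This is why cmd_chat checks isinstance(continue_val, str): a string means "resume by name", while True means "resume the latest session".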
@@ -1122,7 +1328,7 @@ For more help on a command:
    )
    chat_parser.add_argument(
        "--provider",
        choices=["auto", "openrouter", "nous", "openai-codex"],
        choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"],
        default=None,
        help="Inference provider (default: auto)"
    )
@@ -1139,9 +1345,17 @@ For more help on a command:
    chat_parser.add_argument(
        "--continue", "-c",
        dest="continue_last",
        nargs="?",
        const=True,
        default=None,
        metavar="SESSION_NAME",
        help="Resume a session by name, or the most recent if no name given"
    )
    chat_parser.add_argument(
        "--worktree", "-w",
        action="store_true",
        default=False,
        help="Resume the most recent CLI session"
        help="Run in an isolated git worktree (for parallel agents on the same repo)"
    )
    chat_parser.set_defaults(func=cmd_chat)

@@ -1168,6 +1382,8 @@ For more help on a command:
    # gateway run (default)
    gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
    gateway_run.add_argument("-v", "--verbose", action="store_true")
    gateway_run.add_argument("--replace", action="store_true",
                             help="Replace any existing gateway instance (useful for systemd)")

    # gateway start
    gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
@@ -1200,7 +1416,15 @@ For more help on a command:
    setup_parser = subparsers.add_parser(
        "setup",
        help="Interactive setup wizard",
        description="Configure Hermes Agent with an interactive wizard"
        description="Configure Hermes Agent with an interactive wizard. "
                    "Run a specific section: hermes setup model|terminal|gateway|tools|agent"
    )
    setup_parser.add_argument(
        "section",
        nargs="?",
        choices=["model", "terminal", "gateway", "tools", "agent"],
        default=None,
        help="Run a specific setup section instead of the full wizard"
    )
    setup_parser.add_argument(
        "--non-interactive",
@@ -1424,9 +1648,16 @@ For more help on a command:
    )
    skills_subparsers = skills_parser.add_subparsers(dest="skills_action")

    skills_browse = skills_subparsers.add_parser("browse", help="Browse all available skills (paginated)")
    skills_browse.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
    skills_browse.add_argument("--size", type=int, default=20, help="Results per page (default: 20)")
    skills_browse.add_argument("--source", default="all",
                               choices=["all", "official", "github", "clawhub", "lobehub"],
                               help="Filter by source (default: all)")

    skills_search = skills_subparsers.add_parser("search", help="Search skill registries")
    skills_search.add_argument("query", help="Search query")
    skills_search.add_argument("--source", default="all", choices=["all", "github", "clawhub", "lobehub"])
    skills_search.add_argument("--source", default="all", choices=["all", "official", "github", "clawhub", "lobehub"])
    skills_search.add_argument("--limit", type=int, default=10, help="Max results")

    skills_install = skills_subparsers.add_parser("install", help="Install a skill")
@@ -1493,7 +1724,7 @@ For more help on a command:
    # =========================================================================
    sessions_parser = subparsers.add_parser(
        "sessions",
        help="Manage session history (list, export, prune, delete)",
        help="Manage session history (list, rename, export, prune, delete)",
        description="View and manage the SQLite session store"
    )
    sessions_subparsers = sessions_parser.add_subparsers(dest="sessions_action")
@@ -1518,6 +1749,10 @@ For more help on a command:

    sessions_stats = sessions_subparsers.add_parser("stats", help="Show session store statistics")

    sessions_rename = sessions_subparsers.add_parser("rename", help="Set or change a session's title")
    sessions_rename.add_argument("session_id", help="Session ID to rename")
    sessions_rename.add_argument("title", nargs="+", help="New title for the session")

    def cmd_sessions(args):
        import json as _json
        try:
@@ -1530,18 +1765,51 @@ For more help on a command:
        action = args.sessions_action

        if action == "list":
            sessions = db.search_sessions(source=args.source, limit=args.limit)
            sessions = db.list_sessions_rich(source=args.source, limit=args.limit)
            if not sessions:
                print("No sessions found.")
                return
            print(f"{'ID':<30} {'Source':<12} {'Model':<30} {'Messages':>8} {'Started'}")
            print("─" * 100)
            from datetime import datetime
            import time as _time

            def _relative_time(ts):
                """Format a timestamp as relative time (e.g., '2h ago', 'yesterday')."""
                if not ts:
                    return "?"
                delta = _time.time() - ts
                if delta < 60:
                    return "just now"
                elif delta < 3600:
                    mins = int(delta / 60)
                    return f"{mins}m ago"
                elif delta < 86400:
                    hours = int(delta / 3600)
                    return f"{hours}h ago"
                elif delta < 172800:
                    return "yesterday"
                elif delta < 604800:
                    days = int(delta / 86400)
                    return f"{days}d ago"
                else:
                    return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")

            has_titles = any(s.get("title") for s in sessions)
            if has_titles:
                print(f"{'Title':<22} {'Preview':<40} {'Last Active':<13} {'ID'}")
                print("─" * 100)
            else:
                print(f"{'Preview':<50} {'Last Active':<13} {'Src':<6} {'ID'}")
                print("─" * 90)
            for s in sessions:
                started = datetime.fromtimestamp(s["started_at"]).strftime("%Y-%m-%d %H:%M") if s["started_at"] else "?"
                model = (s.get("model") or "?")[:28]
                ended = " (ended)" if s.get("ended_at") else ""
                print(f"{s['id']:<30} {s['source']:<12} {model:<30} {s['message_count']:>8} {started}{ended}")
                last_active = _relative_time(s.get("last_active"))
                preview = s.get("preview", "")[:38] if has_titles else s.get("preview", "")[:48]
                if has_titles:
                    title = (s.get("title") or "—")[:20]
                    sid = s["id"][:20]
                    print(f"{title:<22} {preview:<40} {last_active:<13} {sid}")
                else:
                    sid = s["id"][:20]
                    print(f"{preview:<50} {last_active:<13} {s['source']:<6} {sid}")

        elif action == "export":
            if args.session_id:
@@ -1581,6 +1849,16 @@ For more help on a command:
            count = db.prune_sessions(older_than_days=days, source=args.source)
            print(f"Pruned {count} session(s).")

        elif action == "rename":
            title = " ".join(args.title)
            try:
                if db.set_session_title(args.session_id, title):
                    print(f"Session '{args.session_id}' renamed to: {title}")
                else:
                    print(f"Session '{args.session_id}' not found.")
            except ValueError as e:
                print(f"Error: {e}")

        elif action == "stats":
            total = db.session_count()
            msgs = db.message_count()
@@ -1603,6 +1881,32 @@ For more help on a command:

    sessions_parser.set_defaults(func=cmd_sessions)

    # =========================================================================
    # insights command
    # =========================================================================
    insights_parser = subparsers.add_parser(
        "insights",
        help="Show usage insights and analytics",
        description="Analyze session history to show token usage, costs, tool patterns, and activity trends"
    )
    insights_parser.add_argument("--days", type=int, default=30, help="Number of days to analyze (default: 30)")
    insights_parser.add_argument("--source", help="Filter by platform (cli, telegram, discord, etc.)")

    def cmd_insights(args):
        try:
            from hermes_state import SessionDB
            from agent.insights import InsightsEngine

            db = SessionDB()
            engine = InsightsEngine(db)
            report = engine.generate(days=args.days, source=args.source)
            print(engine.format_terminal(report))
            db.close()
        except Exception as e:
            print(f"Error generating insights: {e}")

    insights_parser.set_defaults(func=cmd_insights)

    # =========================================================================
    # version command
    # =========================================================================
@@ -1660,6 +1964,8 @@ For more help on a command:
        args.provider = None
        args.toolsets = None
        args.verbose = False
        if not hasattr(args, "worktree"):
            args.worktree = False
        cmd_chat(args)
        return
@@ -1671,7 +1977,9 @@ For more help on a command:
        args.toolsets = None
        args.verbose = False
        args.resume = None
        args.continue_last = False
        args.continue_last = None
        if not hasattr(args, "worktree"):
            args.worktree = False
        cmd_chat(args)
        return

@@ -1,27 +1,85 @@
"""
Canonical list of OpenRouter models offered in CLI and setup wizards.
Canonical model catalogs and lightweight validation helpers.

Add, remove, or reorder entries here — both `hermes setup` and
`hermes` provider-selection will pick up the change automatically.
"""

from __future__ import annotations

import json
import urllib.request
import urllib.error
from difflib import get_close_matches
from typing import Any, Optional

# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
    ("anthropic/claude-opus-4.6", "recommended"),
    ("anthropic/claude-sonnet-4.5", ""),
    ("anthropic/claude-opus-4.5", ""),
    ("openai/gpt-5.2", ""),
    ("openai/gpt-5.4-pro", ""),
    ("openai/gpt-5.4", ""),
    ("openai/gpt-5.3-codex", ""),
    ("google/gemini-3-pro-preview", ""),
    ("google/gemini-3-flash-preview", ""),
    ("z-ai/glm-4.7", ""),
    ("qwen/qwen3.5-plus-02-15", ""),
    ("qwen/qwen3.5-35b-a3b", ""),
    ("stepfun/step-3.5-flash", ""),
    ("z-ai/glm-5", ""),
    ("moonshotai/kimi-k2.5", ""),
    ("minimax/minimax-m2.1", ""),
    ("minimax/minimax-m2.5", ""),
]

_PROVIDER_MODELS: dict[str, list[str]] = {
    "zai": [
        "glm-5",
        "glm-4.7",
        "glm-4.5",
        "glm-4.5-flash",
    ],
    "kimi-coding": [
        "kimi-k2.5",
        "kimi-k2-thinking",
        "kimi-k2-turbo-preview",
        "kimi-k2-0905-preview",
    ],
    "minimax": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
    "minimax-cn": [
        "MiniMax-M2.5",
        "MiniMax-M2.5-highspeed",
        "MiniMax-M2.1",
    ],
}

_PROVIDER_LABELS = {
    "openrouter": "OpenRouter",
    "openai-codex": "OpenAI Codex",
    "nous": "Nous Portal",
    "zai": "Z.AI / GLM",
    "kimi-coding": "Kimi / Moonshot",
    "minimax": "MiniMax",
    "minimax-cn": "MiniMax (China)",
    "custom": "custom endpoint",
}

_PROVIDER_ALIASES = {
    "glm": "zai",
    "z-ai": "zai",
    "z.ai": "zai",
    "zhipu": "zai",
    "kimi": "kimi-coding",
    "moonshot": "kimi-coding",
    "minimax-china": "minimax-cn",
    "minimax_cn": "minimax-cn",
}


def model_ids() -> list[str]:
    """Return just the model-id strings (convenience helper)."""
    """Return just the OpenRouter model-id strings."""
    return [mid for mid, _ in OPENROUTER_MODELS]

@@ -31,3 +89,231 @@ def menu_labels() -> list[str]:
    for mid, desc in OPENROUTER_MODELS:
        labels.append(f"{mid} ({desc})" if desc else mid)
    return labels


# All provider IDs and aliases that are valid for the provider:model syntax.
_KNOWN_PROVIDER_NAMES: set[str] = (
    set(_PROVIDER_LABELS.keys())
    | set(_PROVIDER_ALIASES.keys())
    | {"openrouter", "custom"}
)


def list_available_providers() -> list[dict[str, str]]:
    """Return info about all providers the user could use with ``provider:model``.

    Each dict has ``id``, ``label``, and ``aliases``.
    Checks which providers have valid credentials configured.
    """
    # Canonical providers in display order
    _PROVIDER_ORDER = [
        "openrouter", "nous", "openai-codex",
        "zai", "kimi-coding", "minimax", "minimax-cn",
    ]
    # Build reverse alias map
    aliases_for: dict[str, list[str]] = {}
    for alias, canonical in _PROVIDER_ALIASES.items():
        aliases_for.setdefault(canonical, []).append(alias)

    result = []
    for pid in _PROVIDER_ORDER:
        label = _PROVIDER_LABELS.get(pid, pid)
        alias_list = aliases_for.get(pid, [])
        # Check if this provider has credentials available
        has_creds = False
        try:
            from hermes_cli.runtime_provider import resolve_runtime_provider
            runtime = resolve_runtime_provider(requested=pid)
            has_creds = bool(runtime.get("api_key"))
        except Exception:
            pass
        result.append({
            "id": pid,
            "label": label,
            "aliases": alias_list,
            "authenticated": has_creds,
        })
    return result


def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
    """Parse ``/model`` input into ``(provider, model)``.

    Supports ``provider:model`` syntax to switch providers at runtime::

        openrouter:anthropic/claude-sonnet-4.5 → ("openrouter", "anthropic/claude-sonnet-4.5")
        nous:hermes-3                          → ("nous", "hermes-3")
        anthropic/claude-sonnet-4.5            → (current_provider, "anthropic/claude-sonnet-4.5")
        gpt-5.4                                → (current_provider, "gpt-5.4")

    The colon is only treated as a provider delimiter if the left side is a
    recognized provider name or alias. This avoids misinterpreting model names
    that happen to contain colons (e.g. ``anthropic/claude-3.5-sonnet:beta``).

    Returns ``(provider, model)`` where *provider* is either the explicit
    provider from the input or *current_provider* if none was specified.
    """
    stripped = raw.strip()
    colon = stripped.find(":")
    if colon > 0:
        provider_part = stripped[:colon].strip().lower()
        model_part = stripped[colon + 1:].strip()
        if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
            return (normalize_provider(provider_part), model_part)
    return (current_provider, stripped)


def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
    """Return ``(model_id, description)`` tuples for a provider's curated list."""
    normalized = normalize_provider(provider)
    if normalized == "openrouter":
        return list(OPENROUTER_MODELS)
    models = _PROVIDER_MODELS.get(normalized, [])
    return [(m, "") for m in models]


def normalize_provider(provider: Optional[str]) -> str:
    """Normalize provider aliases to Hermes' canonical provider ids.

    Note: ``"auto"`` passes through unchanged — use
    ``hermes_cli.auth.resolve_provider()`` to resolve it to a concrete
    provider based on credentials and environment.
    """
    normalized = (provider or "openrouter").strip().lower()
    return _PROVIDER_ALIASES.get(normalized, normalized)


def provider_model_ids(provider: Optional[str]) -> list[str]:
    """Return the best known model catalog for a provider."""
    normalized = normalize_provider(provider)
    if normalized == "openrouter":
        return model_ids()
    if normalized == "openai-codex":
        from hermes_cli.codex_models import get_codex_model_ids

        return get_codex_model_ids()
    return list(_PROVIDER_MODELS.get(normalized, []))


def fetch_api_models(
    api_key: Optional[str],
    base_url: Optional[str],
    timeout: float = 5.0,
) -> Optional[list[str]]:
    """Fetch the list of available model IDs from the provider's ``/models`` endpoint.

    Returns a list of model ID strings, or ``None`` if the endpoint could not
    be reached (network error, timeout, auth failure, etc.).
    """
    if not base_url:
        return None

    url = base_url.rstrip("/") + "/models"
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    req = urllib.request.Request(url, headers=headers)
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data = json.loads(resp.read().decode())
            # Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
            return [m.get("id", "") for m in data.get("data", [])]
    except Exception:
        return None

def validate_requested_model(
    model_name: str,
    provider: Optional[str],
    *,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
) -> dict[str, Any]:
    """
    Validate a ``/model`` value for the active provider.

    Performs format checks first, then probes the live API to confirm
    the model actually exists.

    Returns a dict with:
    - accepted: whether the CLI should switch to the requested model now
    - persist: whether it is safe to save to config
    - recognized: whether it matched a known provider catalog
    - message: optional warning / guidance for the user
    """
    requested = (model_name or "").strip()
    normalized = normalize_provider(provider)
    if normalized == "openrouter" and base_url and "openrouter.ai" not in base_url:
        normalized = "custom"

    if not requested:
        return {
            "accepted": False,
            "persist": False,
            "recognized": False,
            "message": "Model name cannot be empty.",
        }

    if any(ch.isspace() for ch in requested):
        return {
            "accepted": False,
            "persist": False,
            "recognized": False,
            "message": "Model names cannot contain spaces.",
        }

    # Probe the live API to check if the model actually exists
    api_models = fetch_api_models(api_key, base_url)

    if api_models is not None:
        if requested in set(api_models):
            # API confirmed the model exists
            return {
                "accepted": True,
                "persist": True,
                "recognized": True,
                "message": None,
            }
        else:
            # API responded but model is not listed
            suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
            suggestion_text = ""
            if suggestions:
                suggestion_text = "\n Did you mean: " + ", ".join(f"`{s}`" for s in suggestions)

            return {
                "accepted": False,
                "persist": False,
                "recognized": False,
                "message": (
                    f"Error: `{requested}` is not a valid model for this provider."
                    f"{suggestion_text}"
                ),
            }

    # api_models is None — couldn't reach API, fall back to catalog check
    provider_label = _PROVIDER_LABELS.get(normalized, normalized)
    known_models = provider_model_ids(normalized)

    if requested in known_models:
        return {
            "accepted": True,
            "persist": True,
            "recognized": True,
            "message": None,
        }

    # Can't validate — accept for session only
    suggestion = get_close_matches(requested, known_models, n=1, cutoff=0.6)
    suggestion_text = f" Did you mean `{suggestion[0]}`?" if suggestion else ""
    return {
        "accepted": True,
        "persist": False,
        "recognized": False,
        "message": (
            f"Could not validate `{requested}` against the live {provider_label} API. "
            "Using it for this session only; config unchanged."
            f"{suggestion_text}"
        ),
    }

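A short usage sketch of the two helpers above (the return value shown assumes the live API is unreachable — no api_key/base_url given — so validation falls back to the curated catalog):

provider, model = parse_model_input("zai:glm-5", current_provider="openrouter")
# → ("zai", "glm-5") — "zai" is a known provider name, so the colon splits

verdict = validate_requested_model(model, provider)
# glm-5 is in the curated zai catalog, so the offline path accepts it:
# {"accepted": True, "persist": True, "recognized": True, "message": None}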
@@ -7,10 +7,12 @@ from typing import Any, Dict, Optional

from hermes_cli.auth import (
    AuthError,
    PROVIDER_REGISTRY,
    format_auth_error,
    resolve_provider,
    resolve_nous_runtime_credentials,
    resolve_codex_runtime_credentials,
    resolve_api_key_provider_credentials,
)
from hermes_cli.config import load_config
from hermes_constants import OPENROUTER_BASE_URL
@@ -72,12 +74,26 @@ def _resolve_openrouter_runtime(
        or OPENROUTER_BASE_URL
    ).rstrip("/")

    api_key = (
        explicit_api_key
        or os.getenv("OPENROUTER_API_KEY")
        or os.getenv("OPENAI_API_KEY")
        or ""
    )
    # Choose API key based on whether the resolved base_url targets OpenRouter.
    # When hitting OpenRouter, prefer OPENROUTER_API_KEY (issue #289).
    # When hitting a custom endpoint (e.g. Z.ai, local LLM), prefer
    # OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
    # provider (issues #420, #560).
    _is_openrouter_url = "openrouter.ai" in base_url
    if _is_openrouter_url:
        api_key = (
            explicit_api_key
            or os.getenv("OPENROUTER_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or ""
        )
    else:
        api_key = (
            explicit_api_key
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("OPENROUTER_API_KEY")
            or ""
        )

    source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
@@ -132,6 +148,19 @@ def resolve_runtime_provider(
        "requested_provider": requested_provider,
    }

    # API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
    pconfig = PROVIDER_REGISTRY.get(provider)
    if pconfig and pconfig.auth_type == "api_key":
        creds = resolve_api_key_provider_credentials(provider)
        return {
            "provider": provider,
            "api_mode": "chat_completions",
            "base_url": creds.get("base_url", "").rstrip("/"),
            "api_key": creds.get("api_key", ""),
            "source": creds.get("source", "env"),
            "requested_provider": requested_provider,
        }

    runtime = _resolve_openrouter_runtime(
        requested_provider=requested_provider,
        explicit_api_key=explicit_api_key,
hermes_cli/setup.py (1702 lines changed) — file diff suppressed because it is too large
@@ -57,8 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
        table.add_column("Trust", style="dim")
        table.add_column("Identifier", style="bold cyan")
        for r in exact:
            trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
            table.add_row(r.source, f"[{trust_style}]{r.trust_level}[/]", r.identifier)
            trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
            trust_label = "official" if r.source == "official" else r.trust_level
            table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
        c.print(table)
        c.print("[bold]Use the full identifier to install a specific one.[/]\n")
        return ""
@@ -99,12 +100,13 @@ def do_search(query: str, source: str = "all", limit: int = 10,
    table.add_column("Identifier", style="dim")

    for r in results:
        trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
        trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
        trust_label = "official" if r.source == "official" else r.trust_level
        table.add_row(
            r.name,
            r.description[:60] + ("..." if len(r.description) > 60 else ""),
            r.source,
            f"[{trust_style}]{r.trust_level}[/]",
            f"[{trust_style}]{trust_label}[/]",
            r.identifier,
        )

@@ -113,6 +115,130 @@ def do_search(query: str, source: str = "all", limit: int = 10,
            "hermes skills install <identifier> to install[/]\n")

def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
              console: Optional[Console] = None) -> None:
    """Browse all available skills across registries, paginated.

    Official skills are always shown first, regardless of source filter.
    """
    from tools.skills_hub import (
        GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
    )

    # Clamp page_size to safe range
    page_size = max(1, min(page_size, 100))

    c = console or _console

    auth = GitHubAuth()
    sources = create_source_router(auth)

    # Collect results from all (or filtered) sources
    # Use empty query to get everything; per-source limits prevent overload
    _TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
    _PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
                         "claude-marketplace": 50, "lobehub": 50}

    all_results: list = []
    source_counts: dict = {}

    for src in sources:
        sid = src.source_id()
        if source != "all" and sid != source and sid != "official":
            # Always include official source for the "first" placement
            continue
        try:
            limit = _PER_SOURCE_LIMIT.get(sid, 50)
            results = src.search("", limit=limit)
            source_counts[sid] = len(results)
            all_results.extend(results)
        except Exception:
            continue

    if not all_results:
        c.print("[dim]No skills found in the Skills Hub.[/]\n")
        return

    # Deduplicate by name, preferring higher trust
    seen: dict = {}
    for r in all_results:
        rank = _TRUST_RANK.get(r.trust_level, 0)
        if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
            seen[r.name] = r
    deduped = list(seen.values())

    # Sort: official first, then by trust level (desc), then alphabetically
    deduped.sort(key=lambda r: (
        -_TRUST_RANK.get(r.trust_level, 0),
        r.source != "official",
        r.name.lower(),
    ))

    # Paginate
    total = len(deduped)
    total_pages = max(1, (total + page_size - 1) // page_size)
    page = max(1, min(page, total_pages))
    start = (page - 1) * page_size
    end = min(start + page_size, total)
    page_items = deduped[start:end]

    # Count official vs other
    official_count = sum(1 for r in deduped if r.source == "official")

    # Build header
    source_label = f"— {source}" if source != "all" else "— all sources"
    c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
            f" [dim]({total} skills, page {page}/{total_pages})[/]")
    if official_count > 0 and page == 1:
        c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
    c.print()

    # Build table
    table = Table(show_header=True, header_style="bold")
    table.add_column("#", style="dim", width=4, justify="right")
    table.add_column("Name", style="bold cyan", max_width=25)
    table.add_column("Description", max_width=50)
    table.add_column("Source", style="dim", width=12)
    table.add_column("Trust", width=10)

    for i, r in enumerate(page_items, start=start + 1):
        trust_style = {"builtin": "bright_cyan", "trusted": "green",
                       "community": "yellow"}.get(r.trust_level, "dim")
        trust_label = "★ official" if r.source == "official" else r.trust_level

        desc = r.description[:50]
        if len(r.description) > 50:
            desc += "..."

        table.add_row(
            str(i),
            r.name,
            desc,
            r.source,
            f"[{trust_style}]{trust_label}[/]",
        )

    c.print(table)

    # Navigation hints
    nav_parts = []
    if page > 1:
        nav_parts.append(f"[cyan]--page {page - 1}[/] ← prev")
    if page < total_pages:
        nav_parts.append(f"[cyan]--page {page + 1}[/] → next")

    if nav_parts:
        c.print(f" {' | '.join(nav_parts)}")

    # Source summary
    if source == "all" and source_counts:
        parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
        c.print(f" [dim]Sources: {', '.join(parts)}[/]")

    c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
            "hermes skills install <identifier> to install[/]\n")


def do_install(identifier: str, category: str = "", force: bool = False,
               console: Optional[Console] = None) -> None:
    """Fetch, quarantine, scan, confirm, and install a skill."""
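do_browse's paging arithmetic is worth isolating: clamp the page size, ceiling-divide to count pages, clamp the requested page, then slice. A minimal sketch, with illustrative values:

def paginate(total: int, page: int, page_size: int):
    page_size = max(1, min(page_size, 100))                      # clamp to safe range
    total_pages = max(1, (total + page_size - 1) // page_size)   # ceiling division
    page = max(1, min(page, total_pages))                        # clamp requested page
    start = (page - 1) * page_size
    end = min(start + page_size, total)
    return page, total_pages, start, end

# paginate(103, 6, 20) -> (6, 6, 100, 103): the last page holds the 3 leftover items.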
@@ -147,6 +273,12 @@ def do_install(identifier: str, category: str = "", force: bool = False,
        c.print(f"[bold red]Error:[/] Could not fetch '{identifier}' from any source.\n")
        return

    # Auto-detect category for official skills (e.g. "official/autonomous-ai-agents/blackbox")
    if bundle.source == "official" and not category:
        id_parts = bundle.identifier.split("/")  # ["official", "category", "skill"]
        if len(id_parts) >= 3:
            category = id_parts[1]

    # Check if already installed
    lock = HubLockFile()
    existing = lock.get_installed(bundle.name)
@@ -177,18 +309,28 @@
              f"{len(result.findings)}_findings")
        return

    # Confirm with user — always show risk warning regardless of source
    # Confirm with user — show appropriate warning based on source
    if not force:
        c.print()
        c.print(Panel(
            "[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
            "External skills can contain instructions that influence agent behavior,\n"
            "shell commands, and scripts. Even after automated scanning, you should\n"
            "review the installed files before use.\n\n"
            f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
            title="Disclaimer",
            border_style="yellow",
        ))
        if bundle.source == "official":
            c.print(Panel(
                "[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
                "It ships with hermes-agent but is not activated by default.\n"
                "Installing will copy it to your skills directory where the agent can use it.\n\n"
                f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
                title="Official Skill",
                border_style="bright_cyan",
            ))
        else:
            c.print(Panel(
                "[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
                "External skills can contain instructions that influence agent behavior,\n"
                "shell commands, and scripts. Even after automated scanning, you should\n"
                "review the installed files before use.\n\n"
                f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
                title="Disclaimer",
                border_style="yellow",
            ))
        c.print(f"[bold]Install '{bundle.name}'?[/]")
        try:
            answer = input("Confirm [y/N]: ").strip().lower()
@@ -237,13 +379,14 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
            break

    c.print()
    trust_style = {"trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
    trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
    trust_label = "official" if meta.source == "official" else meta.trust_level

    info_lines = [
        f"[bold]Name:[/] {meta.name}",
        f"[bold]Description:[/] {meta.description}",
        f"[bold]Source:[/] {meta.source}",
        f"[bold]Trust:[/] [{trust_style}]{meta.trust_level}[/]",
        f"[bold]Trust:[/] [{trust_style}]{trust_label}[/]",
        f"[bold]Identifier:[/] {meta.identifier}",
    ]
    if meta.tags:
@@ -297,8 +440,9 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No
        if source_filter == "builtin" and hub_entry:
            continue

        trust_style = {"builtin": "blue", "trusted": "green", "community": "yellow"}.get(trust, "dim")
        table.add_row(name, category, source_display, f"[{trust_style}]{trust}[/]")
        trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(trust, "dim")
        trust_label = "official" if source_display == "official" else trust
        table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")

    c.print(table)
    c.print(f"[dim]{len(hub_installed)} hub-installed, "
@@ -658,7 +802,9 @@ def skills_command(args) -> None:
    """Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
    action = getattr(args, "skills_action", None)

    if action == "search":
    if action == "browse":
        do_browse(page=args.page, page_size=args.size, source=args.source)
    elif action == "search":
        do_search(args.query, source=args.source, limit=args.limit)
    elif action == "install":
        do_install(args.identifier, category=args.category, force=args.force)
@@ -692,7 +838,7 @@ def skills_command(args) -> None:
            return
        do_tap(tap_action, repo=repo)
    else:
        _console.print("Usage: hermes skills [search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
        _console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
        _console.print("Run 'hermes skills <command> --help' for details.\n")

@@ -732,7 +878,32 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
    action = parts[0].lower()
    args = parts[1:]

    if action == "search":
    if action == "browse":
        page = 1
        page_size = 20
        source = "all"
        i = 0
        while i < len(args):
            if args[i] == "--page" and i + 1 < len(args):
                try:
                    page = int(args[i + 1])
                except ValueError:
                    pass
                i += 2
            elif args[i] == "--size" and i + 1 < len(args):
                try:
                    page_size = int(args[i + 1])
                except ValueError:
                    pass
                i += 2
            elif args[i] == "--source" and i + 1 < len(args):
                source = args[i + 1]
                i += 2
            else:
                i += 1
        do_browse(page=page, page_size=page_size, source=source, console=c)

    elif action == "search":
        if not args:
            c.print("[bold red]Usage:[/] /skills search <query> [--source github] [--limit N]\n")
            return
@@ -838,6 +1009,7 @@ def _print_skills_help(console: Console) -> None:
    """Print help for the /skills slash command."""
    console.print(Panel(
        "[bold]Skills Hub Commands:[/]\n\n"
        " [cyan]browse[/] [--source official]    Browse all available skills (paginated)\n"
        " [cyan]search[/] <query>               Search registries for skills\n"
        " [cyan]install[/] <identifier>         Install a skill (with security scan)\n"
        " [cyan]inspect[/] <identifier>         Preview a skill without installing\n"

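The category auto-detection in do_install above relies on the three-part shape of official identifiers. A small worked example, using the identifier from the diff's own comment:

identifier = "official/autonomous-ai-agents/blackbox"
parts = identifier.split("/")    # ["official", "autonomous-ai-agents", "blackbox"]
category = parts[1] if len(parts) >= 3 else ""
# category == "autonomous-ai-agents", so files land under
# ~/.hermes/skills/autonomous-ai-agents/blackbox/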
@@ -79,8 +79,12 @@ def show_status(args):
        "OpenRouter": "OPENROUTER_API_KEY",
        "Anthropic": "ANTHROPIC_API_KEY",
        "OpenAI": "OPENAI_API_KEY",
        "Z.AI/GLM": "GLM_API_KEY",
        "Kimi": "KIMI_API_KEY",
        "MiniMax": "MINIMAX_API_KEY",
        "MiniMax-CN": "MINIMAX_CN_API_KEY",
        "Firecrawl": "FIRECRAWL_API_KEY",
        "Browserbase": "BROWSERBASE_API_KEY",
        "Browserbase": "BROWSERBASE_API_KEY",  # Optional — local browser works without this
        "FAL": "FAL_KEY",
        "Tinker": "TINKER_API_KEY",
        "WandB": "WANDB_API_KEY",
@@ -128,7 +132,7 @@ def show_status(args):
        f" {'OpenAI Codex':<12} {check_mark(codex_logged_in)} "
        f"{'logged in' if codex_logged_in else 'not logged in (run: hermes model)'}"
    )
    codex_auth_file = codex_status.get("auth_file")
    codex_auth_file = codex_status.get("auth_store")
    if codex_auth_file:
        print(f" Auth file: {codex_auth_file}")
    codex_last_refresh = _format_iso_timestamp(codex_status.get("last_refresh"))
@@ -137,6 +141,28 @@ def show_status(args):
    if codex_status.get("error") and not codex_logged_in:
        print(f" Error: {codex_status.get('error')}")

    # =========================================================================
    # API-Key Providers
    # =========================================================================
    print()
    print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))

    apikey_providers = {
        "Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
        "Kimi / Moonshot": ("KIMI_API_KEY",),
        "MiniMax": ("MINIMAX_API_KEY",),
        "MiniMax (China)": ("MINIMAX_CN_API_KEY",),
    }
    for pname, env_vars in apikey_providers.items():
        key_val = ""
        for ev in env_vars:
            key_val = get_env_value(ev) or ""
            if key_val:
                break
        configured = bool(key_val)
        label = "configured" if configured else "not configured (run: hermes model)"
        print(f" {pname:<16} {check_mark(configured)} {label}")

    # =========================================================================
    # Terminal Configuration
    # =========================================================================
@@ -163,6 +189,9 @@ def show_status(args):
    elif terminal_env == "docker":
        docker_image = os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim")
        print(f" Docker Image: {docker_image}")
    elif terminal_env == "daytona":
        daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
        print(f" Daytona Image: {daytona_image}")

    sudo_password = os.getenv("SUDO_PASSWORD", "")
    print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")

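The Z.AI/GLM row above accepts any of three env var aliases; the status loop simply takes the first non-empty one. A minimal standalone sketch (using os.getenv rather than the CLI's get_env_value, which also reads the Hermes config):

import os

def first_configured(*env_vars: str) -> str:
    # Return the first non-empty value among several env var aliases.
    for var in env_vars:
        val = os.getenv(var, "")
        if val:
            return val
    return ""

glm_key = first_configured("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY")
configured = bool(glm_key)   # drives the ✓/✗ check mark in the status output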
@@ -1,7 +1,10 @@
"""
Interactive tool configuration for Hermes Agent.
Unified tool configuration for Hermes Agent.

`hermes tools` and `hermes setup tools` both enter this module.
Select a platform → toggle toolsets on/off → for newly enabled tools
that need API keys, run through provider-aware configuration.

`hermes tools` — select a platform, then toggle toolsets on/off via checklist.
Saves per-platform tool configuration to ~/.hermes/config.yaml under
the `platform_toolsets` key.
"""
@@ -12,9 +15,63 @@ from typing import Dict, List, Set

import os

from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
from hermes_cli.config import (
    load_config, save_config, get_env_value, save_env_value,
    get_hermes_home,
)
from hermes_cli.colors import Colors, color

PROJECT_ROOT = Path(__file__).parent.parent.resolve()


# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────

def _print_info(text: str):
    print(color(f" {text}", Colors.DIM))

def _print_success(text: str):
    print(color(f"✓ {text}", Colors.GREEN))

def _print_warning(text: str):
    print(color(f"⚠ {text}", Colors.YELLOW))

def _print_error(text: str):
    print(color(f"✗ {text}", Colors.RED))

def _prompt(question: str, default: str = None, password: bool = False) -> str:
    if default:
        display = f"{question} [{default}]: "
    else:
        display = f"{question}: "
    try:
        if password:
            import getpass
            value = getpass.getpass(color(display, Colors.YELLOW))
        else:
            value = input(color(display, Colors.YELLOW))
        return value.strip() or default or ""
    except (KeyboardInterrupt, EOFError):
        print()
        return default or ""

def _prompt_yes_no(question: str, default: bool = True) -> bool:
    default_str = "Y/n" if default else "y/N"
    while True:
        try:
            value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
        except (KeyboardInterrupt, EOFError):
            print()
            return default
        if not value:
            return default
        if value in ('y', 'yes'):
            return True
        if value in ('n', 'no'):
            return False


# ─── Toolset Registry ─────────────────────────────────────────────────────────

# Toolsets shown in the configurator, grouped for display.
# Each entry: (toolset_name, label, description)
# These map to keys in toolsets.py TOOLSETS dict.
@@ -49,6 +106,187 @@ PLATFORMS = {
}


# ─── Tool Categories (provider-aware configuration) ──────────────────────────
# Maps toolset keys to their provider options. When a toolset is newly enabled,
# we use this to show provider selection and prompt for the right API keys.
# Toolsets not in this map either need no config or use the simple fallback.

TOOL_CATEGORIES = {
    "tts": {
        "name": "Text-to-Speech",
        "icon": "🔊",
        "providers": [
            {
                "name": "Microsoft Edge TTS",
                "tag": "Free - no API key needed",
                "env_vars": [],
                "tts_provider": "edge",
            },
            {
                "name": "OpenAI TTS",
                "tag": "Premium - high quality voices",
                "env_vars": [
                    {"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
                ],
                "tts_provider": "openai",
            },
            {
                "name": "ElevenLabs",
                "tag": "Premium - most natural voices",
                "env_vars": [
                    {"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
                ],
                "tts_provider": "elevenlabs",
            },
        ],
    },
    "web": {
        "name": "Web Search & Extract",
        "icon": "🔍",
        "providers": [
            {
                "name": "Firecrawl Cloud",
                "tag": "Recommended - hosted service",
                "env_vars": [
                    {"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
                ],
            },
            {
                "name": "Firecrawl Self-Hosted",
                "tag": "Free - run your own instance",
                "env_vars": [
                    {"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
                ],
            },
        ],
    },
    "image_gen": {
        "name": "Image Generation",
        "icon": "🎨",
        "providers": [
            {
                "name": "FAL.ai",
                "tag": "FLUX 2 Pro with auto-upscaling",
                "env_vars": [
                    {"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
                ],
            },
        ],
    },
    "browser": {
        "name": "Browser Automation",
        "icon": "🌐",
        "providers": [
            {
                "name": "Local Browser",
                "tag": "Free headless Chromium (no API key needed)",
                "env_vars": [],
                "post_setup": "browserbase",  # Same npm install for agent-browser
            },
            {
                "name": "Browserbase",
                "tag": "Cloud browser with stealth & proxies",
                "env_vars": [
                    {"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
                    {"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
                ],
                "post_setup": "browserbase",
            },
        ],
    },
    "homeassistant": {
        "name": "Smart Home",
        "icon": "🏠",
        "providers": [
            {
                "name": "Home Assistant",
                "tag": "REST API integration",
                "env_vars": [
                    {"key": "HASS_TOKEN", "prompt": "Home Assistant Long-Lived Access Token"},
                    {"key": "HASS_URL", "prompt": "Home Assistant URL", "default": "http://homeassistant.local:8123"},
                ],
            },
        ],
    },
    "rl": {
        "name": "RL Training",
        "icon": "🧪",
        "requires_python": (3, 11),
        "providers": [
            {
                "name": "Tinker / Atropos",
                "tag": "RL training platform",
                "env_vars": [
                    {"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
                    {"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
                ],
                "post_setup": "rl_training",
            },
        ],
    },
}

# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
# Used as a fallback for tools like vision/moa that just need an API key.
TOOLSET_ENV_REQUIREMENTS = {
    "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
    "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
}


# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────

def _run_post_setup(post_setup_key: str):
    """Run post-setup hooks for tools that need extra installation steps."""
    import shutil
    if post_setup_key == "browserbase":
        node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
        if not node_modules.exists() and shutil.which("npm"):
            _print_info(" Installing Node.js dependencies for browser tools...")
            import subprocess
            result = subprocess.run(
                ["npm", "install", "--silent"],
                capture_output=True, text=True, cwd=str(PROJECT_ROOT)
            )
            if result.returncode == 0:
                _print_success(" Node.js dependencies installed")
            else:
                _print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
        elif not node_modules.exists():
            _print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")

    elif post_setup_key == "rl_training":
        try:
            __import__("tinker_atropos")
        except ImportError:
            tinker_dir = PROJECT_ROOT / "tinker-atropos"
            if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
                _print_info(" Installing tinker-atropos submodule...")
                import subprocess
                uv_bin = shutil.which("uv")
                if uv_bin:
                    result = subprocess.run(
                        [uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
                        capture_output=True, text=True
                    )
                else:
                    result = subprocess.run(
                        [sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
                        capture_output=True, text=True
                    )
                if result.returncode == 0:
                    _print_success(" tinker-atropos installed")
                else:
                    _print_warning(" tinker-atropos install failed - run manually:")
                    _print_info(' uv pip install -e "./tinker-atropos"')
            else:
                _print_warning(" tinker-atropos submodule not found - run:")
                _print_info(" git submodule update --init --recursive")
                _print_info(' uv pip install -e "./tinker-atropos"')


# ─── Platform / Toolset Helpers ───────────────────────────────────────────────

def _get_enabled_platforms() -> List[str]:
    """Return platform keys that are configured (have tokens or are CLI)."""
    enabled = ["cli"]
@@ -70,7 +308,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
    platform_toolsets = config.get("platform_toolsets", {})
    toolset_names = platform_toolsets.get(platform)

    if not toolset_names or not isinstance(toolset_names, list):
    if toolset_names is None or not isinstance(toolset_names, list):
        default_ts = PLATFORMS[platform]["default_toolset"]
        toolset_names = [default_ts]

@@ -97,61 +335,117 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
    save_config(config)


def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
    """Single-select menu (arrow keys)."""
    print(color(question, Colors.YELLOW))

    try:
        from simple_term_menu import TerminalMenu
        menu = TerminalMenu(
            [f" {c}" for c in choices],
            cursor_index=default,
            menu_cursor="→ ",
            menu_cursor_style=("fg_green", "bold"),
            menu_highlight_style=("fg_green",),
            cycle_cursor=True,
            clear_screen=False,
        )
        idx = menu.show()
        if idx is None:
            sys.exit(0)
        print()
        return idx
    except (ImportError, NotImplementedError):
        for i, c in enumerate(choices):
            marker = "●" if i == default else "○"
            style = Colors.GREEN if i == default else ""
            print(color(f" {marker} {c}", style) if style else f" {marker} {c}")
        while True:
            try:
                val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
                if not val:
                    return default
                idx = int(val) - 1
                if 0 <= idx < len(choices):
                    return idx
            except (ValueError, KeyboardInterrupt, EOFError):
                print()
                sys.exit(0)


def _toolset_has_keys(ts_key: str) -> bool:
    """Check if a toolset's required API keys are configured."""
    # Check TOOL_CATEGORIES first (provider-aware)
    cat = TOOL_CATEGORIES.get(ts_key)
    if cat:
        for provider in cat["providers"]:
            env_vars = provider.get("env_vars", [])
            if not env_vars:
                return True  # Free provider (e.g., Edge TTS)
            if all(get_env_value(v["key"]) for v in env_vars):
                return True
        return False

    # Fallback to simple requirements
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return True
    return all(get_env_value(var) for var, _ in requirements)


# ─── Menu Helpers ─────────────────────────────────────────────────────────────

def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
    """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
    rendering bugs in tmux, iTerm, and other non-standard terminals."""

    # Curses-based single-select — works in tmux, iTerm, and standard terminals
    try:
        import curses
        result_holder = [default]

        def _curses_menu(stdscr):
            curses.curs_set(0)
            if curses.has_colors():
                curses.start_color()
                curses.use_default_colors()
                curses.init_pair(1, curses.COLOR_GREEN, -1)
                curses.init_pair(2, curses.COLOR_YELLOW, -1)
            cursor = default

            while True:
                stdscr.clear()
                max_y, max_x = stdscr.getmaxyx()
                try:
                    stdscr.addnstr(0, 0, question, max_x - 1,
                                   curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
                except curses.error:
                    pass

                for i, c in enumerate(choices):
                    y = i + 2
                    if y >= max_y - 1:
                        break
                    arrow = "→" if i == cursor else " "
                    line = f" {arrow} {c}"
                    attr = curses.A_NORMAL
                    if i == cursor:
                        attr = curses.A_BOLD
                        if curses.has_colors():
                            attr |= curses.color_pair(1)
                    try:
                        stdscr.addnstr(y, 0, line, max_x - 1, attr)
                    except curses.error:
                        pass

                stdscr.refresh()
                key = stdscr.getch()

                if key in (curses.KEY_UP, ord('k')):
                    cursor = (cursor - 1) % len(choices)
                elif key in (curses.KEY_DOWN, ord('j')):
                    cursor = (cursor + 1) % len(choices)
                elif key in (curses.KEY_ENTER, 10, 13):
                    result_holder[0] = cursor
                    return
                elif key in (27, ord('q')):
                    return

        curses.wrapper(_curses_menu)
        return result_holder[0]

    except Exception:
        pass

    # Fallback: numbered input (Windows without curses, etc.)
    print(color(question, Colors.YELLOW))
    for i, c in enumerate(choices):
        marker = "●" if i == default else "○"
        style = Colors.GREEN if i == default else ""
        print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
    while True:
        try:
            val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
            if not val:
                return default
            idx = int(val) - 1
            if 0 <= idx < len(choices):
                return idx
        except (ValueError, KeyboardInterrupt, EOFError):
            print()
            return default


def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
    """Multi-select checklist of toolsets. Returns set of selected toolset keys."""
    import platform as _platform

    labels = []
    for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
        suffix = ""
        if not _toolset_has_keys(ts_key) and TOOLSET_ENV_REQUIREMENTS.get(ts_key):
            suffix = " ⚠ no API key"
        if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
            suffix = " [no API key]"
        labels.append(f"{ts_label} ({ts_desc}){suffix}")

    pre_selected_indices = [
@@ -159,48 +453,8 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str
        if ts_key in enabled
    ]

    # simple_term_menu multi-select has rendering bugs on macOS terminals,
    # so we use a curses-based fallback there.
    use_term_menu = _platform.system() != "Darwin"

    if use_term_menu:
        try:
            from simple_term_menu import TerminalMenu

            print(color(f"Tools for {platform_label}", Colors.YELLOW))
            print(color(" SPACE to toggle, ENTER to confirm.", Colors.DIM))
            print()

            menu_items = [f" {label}" for label in labels]
            menu = TerminalMenu(
                menu_items,
                multi_select=True,
                show_multi_select_hint=False,
                multi_select_cursor="[✓] ",
                multi_select_select_on_accept=False,
                multi_select_empty_ok=True,
                preselected_entries=pre_selected_indices if pre_selected_indices else None,
                menu_cursor="→ ",
                menu_cursor_style=("fg_green", "bold"),
                menu_highlight_style=("fg_green",),
                cycle_cursor=True,
                clear_screen=False,
                clear_menu_on_exit=False,
            )

            menu.show()

            if menu.chosen_menu_entries is None:
                return enabled

            selected_indices = list(menu.chosen_menu_indices or [])
            return {CONFIGURABLE_TOOLSETS[i][0] for i in selected_indices}

        except (ImportError, NotImplementedError):
            pass  # fall through to curses/numbered fallback

    # Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
    # Used on macOS (where simple_term_menu ghosts) and as a fallback.
    # simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
    try:
        import curses
        selected = set(pre_selected_indices)
@@ -302,77 +556,294 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str
        return {CONFIGURABLE_TOOLSETS[i][0] for i in selected}


# Map toolset keys to the env vars they require and where to get them
TOOLSET_ENV_REQUIREMENTS = {
    "web": [("FIRECRAWL_API_KEY", "https://firecrawl.dev/")],
    "browser": [("BROWSERBASE_API_KEY", "https://browserbase.com/"),
                ("BROWSERBASE_PROJECT_ID", None)],
    "vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
    "image_gen": [("FAL_KEY", "https://fal.ai/")],
    "moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
    "tts": [],  # Edge TTS is free, no key needed
    "rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
           ("WANDB_API_KEY", "https://wandb.ai/authorize")],
    "homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
                      ("HASS_URL", None)],
}
# ─── Provider-Aware Configuration ────────────────────────────────────────────

def _configure_toolset(ts_key: str, config: dict):
    """Configure a toolset - provider selection + API keys.

    Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
    env var prompts for toolsets not in TOOL_CATEGORIES.
    """
    cat = TOOL_CATEGORIES.get(ts_key)

    if cat:
        _configure_tool_category(ts_key, cat, config)
    else:
        # Simple fallback for vision, moa, etc.
        _configure_simple_requirements(ts_key)


def _check_and_prompt_requirements(newly_enabled: Set[str]):
    """Check if newly enabled toolsets have missing API keys and offer to set them up."""
    for ts_key in sorted(newly_enabled):
        requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
        if not requirements:
            continue
def _configure_tool_category(ts_key: str, cat: dict, config: dict):
    """Configure a tool category with provider selection."""
    icon = cat.get("icon", "")
    name = cat["name"]
    providers = cat["providers"]

        missing = [(var, url) for var, url in requirements if not get_env_value(var)]
        if not missing:
            continue

        ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
        print()
        print(color(f" ⚠ {ts_label} requires configuration:", Colors.YELLOW))

        for var, url in missing:
            if url:
                print(color(f" {var}", Colors.CYAN) + color(f" ({url})", Colors.DIM))
            else:
                print(color(f" {var}", Colors.CYAN))

        print()
        try:
            response = input(color(" Set up now? [Y/n] ", Colors.YELLOW)).strip().lower()
        except (KeyboardInterrupt, EOFError):
    # Check Python version requirement
    if cat.get("requires_python"):
        req = cat["requires_python"]
        if sys.version_info < req:
            print()
            continue
            _print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
            _print_info(" Upgrade Python and reinstall to enable this tool.")
            return

        if response in ("", "y", "yes"):
            for var, url in missing:
                if url:
                    print(color(f" Get key at: {url}", Colors.DIM))
                try:
                    import getpass
                    value = getpass.getpass(color(f" {var}: ", Colors.YELLOW))
                except (KeyboardInterrupt, EOFError):
                    print()
                    break
                if value.strip():
                    save_env_value(var, value.strip())
                    print(color(f" ✓ Saved", Colors.GREEN))
    if len(providers) == 1:
        # Single provider - configure directly
        provider = providers[0]
        print()
        print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
        if provider.get("tag"):
            _print_info(f" {provider['tag']}")
        _configure_provider(provider, config)
    else:
        # Multiple providers - let user choose
        print()
        print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
        print()

        # Plain text labels only (no ANSI codes in menu items)
        provider_choices = []
        for p in providers:
            tag = f" ({p['tag']})" if p.get("tag") else ""
            configured = ""
            env_vars = p.get("env_vars", [])
            if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
                if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                    configured = " [active]"
                elif not env_vars:
                    configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
                else:
                    print(color(f" Skipped", Colors.DIM))
                    configured = " [configured]"
            provider_choices.append(f"{p['name']}{tag}{configured}")

        # Detect current provider as default
        default_idx = 0
        for i, p in enumerate(providers):
            if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                default_idx = i
                break
            env_vars = p.get("env_vars", [])
            if env_vars and all(get_env_value(v["key"]) for v in env_vars):
                default_idx = i
                break

        provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
        _configure_provider(providers[provider_idx], config)


def _configure_provider(provider: dict, config: dict):
    """Configure a single provider - prompt for API keys and set config."""
    env_vars = provider.get("env_vars", [])

    # Set TTS provider in config if applicable
    if provider.get("tts_provider"):
        config.setdefault("tts", {})["provider"] = provider["tts_provider"]

    if not env_vars:
        _print_success(f" {provider['name']} - no configuration needed!")
        return

    # Prompt for each required env var
    all_configured = True
    for var in env_vars:
        existing = get_env_value(var["key"])
        if existing:
            _print_success(f" {var['key']}: already configured")
            # Don't ask to update - this is a new enable flow.
            # Reconfigure is handled separately.
        else:
            print(color(" Skipped — configure later with 'hermes setup'", Colors.DIM))
            url = var.get("url", "")
            if url:
                _print_info(f" Get yours at: {url}")

            default_val = var.get("default", "")
            if default_val:
                value = _prompt(f" {var.get('prompt', var['key'])}", default_val)
            else:
                value = _prompt(f" {var.get('prompt', var['key'])}", password=True)

            if value:
                save_env_value(var["key"], value)
                _print_success(f" Saved")
            else:
                _print_warning(f" Skipped")
                all_configured = False

    # Run post-setup hooks if needed
    if provider.get("post_setup") and all_configured:
        _run_post_setup(provider["post_setup"])

    if all_configured:
        _print_success(f" {provider['name']} configured!")


def tools_command(args):
    """Entry point for `hermes tools`."""
def _configure_simple_requirements(ts_key: str):
    """Simple fallback for toolsets that just need env vars (no provider selection)."""
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return

    missing = [(var, url) for var, url in requirements if not get_env_value(var)]
    if not missing:
        return

    ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
    print()
    print(color(f" {ts_label} requires configuration:", Colors.YELLOW))

    for var, url in missing:
        if url:
            _print_info(f" Get key at: {url}")
        value = _prompt(f" {var}", password=True)
        if value and value.strip():
            save_env_value(var, value.strip())
            _print_success(f" Saved")
        else:
            _print_warning(f" Skipped")


def _reconfigure_tool(config: dict):
    """Let user reconfigure an existing tool's provider or API key."""
    # Build list of configurable tools that are currently set up
    configurable = []
    for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
        cat = TOOL_CATEGORIES.get(ts_key)
        reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
        if cat or reqs:
            if _toolset_has_keys(ts_key):
                configurable.append((ts_key, ts_label))

    if not configurable:
        _print_info("No configured tools to reconfigure.")
        return

    choices = [label for _, label in configurable]
    choices.append("Cancel")

    idx = _prompt_choice(" Which tool would you like to reconfigure?", choices, len(choices) - 1)

    if idx >= len(configurable):
        return  # Cancel

    ts_key, ts_label = configurable[idx]
    cat = TOOL_CATEGORIES.get(ts_key)

    if cat:
        _configure_tool_category_for_reconfig(ts_key, cat, config)
    else:
        _reconfigure_simple_requirements(ts_key)

    save_config(config)


def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
    """Reconfigure a tool category - provider selection + API key update."""
    icon = cat.get("icon", "")
    name = cat["name"]
    providers = cat["providers"]

    if len(providers) == 1:
        provider = providers[0]
        print()
        print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
        _reconfigure_provider(provider, config)
    else:
        print()
        print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
        print()

        provider_choices = []
        for p in providers:
            tag = f" ({p['tag']})" if p.get("tag") else ""
            configured = ""
            env_vars = p.get("env_vars", [])
            if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
                if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                    configured = " [active]"
                elif not env_vars:
                    configured = ""
                else:
                    configured = " [configured]"
            provider_choices.append(f"{p['name']}{tag}{configured}")

        default_idx = 0
        for i, p in enumerate(providers):
            if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
                default_idx = i
                break
            env_vars = p.get("env_vars", [])
            if env_vars and all(get_env_value(v["key"]) for v in env_vars):
                default_idx = i
                break

        provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
        _reconfigure_provider(providers[provider_idx], config)


def _reconfigure_provider(provider: dict, config: dict):
    """Reconfigure a provider - update API keys."""
    env_vars = provider.get("env_vars", [])

    if provider.get("tts_provider"):
        config.setdefault("tts", {})["provider"] = provider["tts_provider"]
        _print_success(f" TTS provider set to: {provider['tts_provider']}")

    if not env_vars:
        _print_success(f" {provider['name']} - no configuration needed!")
        return

    for var in env_vars:
        existing = get_env_value(var["key"])
        if existing:
            _print_info(f" {var['key']}: configured ({existing[:8]}...)")
        url = var.get("url", "")
        if url:
            _print_info(f" Get yours at: {url}")
        default_val = var.get("default", "")
        value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
        if value and value.strip():
            save_env_value(var["key"], value.strip())
            _print_success(f" Updated")
        else:
            _print_info(f" Kept current")


def _reconfigure_simple_requirements(ts_key: str):
    """Reconfigure simple env var requirements."""
    requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
    if not requirements:
        return

    ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
    print()
    print(color(f" {ts_label}:", Colors.CYAN))

    for var, url in requirements:
        existing = get_env_value(var)
        if existing:
            _print_info(f" {var}: configured ({existing[:8]}...)")
        if url:
            _print_info(f" Get key at: {url}")
        value = _prompt(f" {var} (Enter to keep current)", password=True)
        if value and value.strip():
            save_env_value(var, value.strip())
            _print_success(f" Updated")
        else:
            _print_info(f" Kept current")


# ─── Main Entry Point ─────────────────────────────────────────────────────────

def tools_command(args=None):
    """Entry point for `hermes tools` and `hermes setup tools`."""
    config = load_config()
    enabled_platforms = _get_enabled_platforms()

    print()
    print(color("⚕ Hermes Tool Configuration", Colors.CYAN, Colors.BOLD))
    print(color(" Enable or disable tools per platform.", Colors.DIM))
    print(color(" Tools that need API keys will be configured when enabled.", Colors.DIM))
    print()

    # Build platform choices
@@ -380,22 +851,28 @@ def tools_command(args):
    platform_keys = []
    for pkey in enabled_platforms:
        pinfo = PLATFORMS[pkey]
        # Count currently enabled toolsets
        current = _get_platform_tools(config, pkey)
        count = len(current)
        total = len(CONFIGURABLE_TOOLSETS)
        platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
        platform_keys.append(pkey)

    platform_choices.append("Done — save and exit")
    platform_choices.append("Reconfigure an existing tool's provider or API key")
    platform_choices.append("Done")

    while True:
        idx = _prompt_choice("Select a platform to configure:", platform_choices, default=0)
        idx = _prompt_choice("Select an option:", platform_choices, default=0)

        # "Done" selected
        if idx == len(platform_keys):
        if idx == len(platform_keys) + 1:
            break

        # "Reconfigure" selected
        if idx == len(platform_keys):
            _reconfigure_tool(config)
            print()
            continue

        pkey = platform_keys[idx]
        pinfo = PLATFORMS[pkey]

@@ -418,11 +895,15 @@ def tools_command(args):
            label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
            print(color(f" - {label}", Colors.RED))

        # Prompt for missing API keys on newly enabled toolsets
        # Configure newly enabled toolsets that need API keys
        if added:
            _check_and_prompt_requirements(added)
            for ts_key in sorted(added):
                if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
                    if not _toolset_has_keys(ts_key):
                        _configure_toolset(ts_key, config)

        _save_platform_tools(config, pkey, new_enabled)
        save_config(config)
        print(color(f" ✓ Saved {pinfo['label']} configuration", Colors.GREEN))
    else:
        print(color(f" No changes to {pinfo['label']}", Colors.DIM))

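To see the shape TOOL_CATEGORIES expects, here is a hypothetical extra entry; the "weather" toolset and WEATHER_API_KEY do not exist in the repo and are invented purely for illustration:

TOOL_CATEGORIES["weather"] = {
    "name": "Weather Lookup",        # hypothetical tool, not in hermes-agent
    "icon": "🌦",
    "providers": [
        {
            "name": "Example Weather API",
            "tag": "Cloud service",
            "env_vars": [
                {"key": "WEATHER_API_KEY", "prompt": "Weather API key"},
            ],
        },
    ],
}

# Once WEATHER_API_KEY is saved, _toolset_has_keys("weather") returns True,
# and _configure_toolset("weather", config) walks the same provider prompts.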
233	hermes_state.py
@@ -24,7 +24,7 @@ from typing import Dict, Any, List, Optional

DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"

SCHEMA_VERSION = 2
SCHEMA_VERSION = 4

SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version (
@@ -46,6 +46,7 @@ CREATE TABLE IF NOT EXISTS sessions (
    tool_call_count INTEGER DEFAULT 0,
    input_tokens INTEGER DEFAULT 0,
    output_tokens INTEGER DEFAULT 0,
    title TEXT,
    FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);

@@ -133,7 +134,33 @@ class SessionDB:
            except sqlite3.OperationalError:
                pass  # Column already exists
            cursor.execute("UPDATE schema_version SET version = 2")
        if current_version < 3:
            # v3: add title column to sessions
            try:
                cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
            except sqlite3.OperationalError:
                pass  # Column already exists
            cursor.execute("UPDATE schema_version SET version = 3")
        if current_version < 4:
            # v4: add unique index on title (NULLs allowed, only non-NULL must be unique)
            try:
                cursor.execute(
                    "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
                    "ON sessions(title) WHERE title IS NOT NULL"
                )
            except sqlite3.OperationalError:
                pass  # Index already exists
            cursor.execute("UPDATE schema_version SET version = 4")

        # Unique title index — always ensure it exists (safe to run after migrations
        # since the title column is guaranteed to exist at this point)
        try:
            cursor.execute(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
                "ON sessions(title) WHERE title IS NOT NULL"
            )
        except sqlite3.OperationalError:
            pass  # Index already exists

        # FTS5 setup (separate because CREATE VIRTUAL TABLE can't be in executescript with IF NOT EXISTS reliably)
        try:
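The v4 migration leans on SQLite partial indexes: the WHERE clause exempts NULL titles, so any number of untitled sessions can coexist while non-NULL titles stay unique. A self-contained demonstration:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sessions (id TEXT PRIMARY KEY, title TEXT)")
conn.execute(
    "CREATE UNIQUE INDEX idx_sessions_title_unique "
    "ON sessions(title) WHERE title IS NOT NULL"
)
conn.execute("INSERT INTO sessions VALUES ('a', NULL)")
conn.execute("INSERT INTO sessions VALUES ('b', NULL)")      # fine: NULLs are exempt
conn.execute("INSERT INTO sessions VALUES ('c', 'daily')")
try:
    conn.execute("INSERT INTO sessions VALUES ('d', 'daily')")
except sqlite3.IntegrityError as exc:
    print(exc)   # UNIQUE constraint failed: sessions.title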
@@ -219,6 +246,210 @@ class SessionDB:
        row = cursor.fetchone()
        return dict(row) if row else None

    # Maximum length for session titles
    MAX_TITLE_LENGTH = 100

    @staticmethod
    def sanitize_title(title: Optional[str]) -> Optional[str]:
        """Validate and sanitize a session title.

        - Strips leading/trailing whitespace
        - Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic
          Unicode control chars (zero-width, RTL/LTR overrides, etc.)
        - Collapses internal whitespace runs to single spaces
        - Normalizes empty/whitespace-only strings to None
        - Enforces MAX_TITLE_LENGTH

        Returns the cleaned title string or None.
        Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning.
        """
        if not title:
            return None

        import re

        # Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
        # whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
        # normalized to spaces by the whitespace collapsing step below
        cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)

        # Remove problematic Unicode control characters:
        # - Zero-width chars (U+200B-U+200F, U+FEFF)
        # - Directional overrides (U+202A-U+202E, U+2066-U+2069)
        # - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
        cleaned = re.sub(
            r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
            '', cleaned,
        )

        # Collapse internal whitespace runs and strip
        cleaned = re.sub(r'\s+', ' ', cleaned).strip()

        if not cleaned:
            return None

        if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
            raise ValueError(
                f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
            )

        return cleaned

    def set_session_title(self, session_id: str, title: str) -> bool:
        """Set or update a session's title.

        Returns True if session was found and title was set.
        Raises ValueError if title is already in use by another session,
        or if the title fails validation (too long, invalid characters).
        Empty/whitespace-only strings are normalized to None (clearing the title).
        """
        title = self.sanitize_title(title)
        if title:
            # Check uniqueness (allow the same session to keep its own title)
            cursor = self._conn.execute(
                "SELECT id FROM sessions WHERE title = ? AND id != ?",
                (title, session_id),
            )
            conflict = cursor.fetchone()
            if conflict:
                raise ValueError(
                    f"Title '{title}' is already in use by session {conflict['id']}"
                )
        cursor = self._conn.execute(
            "UPDATE sessions SET title = ? WHERE id = ?",
            (title, session_id),
        )
        self._conn.commit()
        return cursor.rowcount > 0

    def get_session_title(self, session_id: str) -> Optional[str]:
        """Get the title for a session, or None."""
        cursor = self._conn.execute(
            "SELECT title FROM sessions WHERE id = ?", (session_id,)
        )
        row = cursor.fetchone()
        return row["title"] if row else None

    def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
        """Look up a session by exact title. Returns session dict or None."""
        cursor = self._conn.execute(
            "SELECT * FROM sessions WHERE title = ?", (title,)
        )
        row = cursor.fetchone()
        return dict(row) if row else None

    def resolve_session_by_title(self, title: str) -> Optional[str]:
        """Resolve a title to a session ID, preferring the latest in a lineage.

        If the exact title exists, returns that session's ID.
        If not, searches for "title #N" variants and returns the latest one.
        If the exact title exists AND numbered variants exist, returns the
        latest numbered variant (the most recent continuation).
        """
        # First try exact match
        exact = self.get_session_by_title(title)

        # Also search for numbered variants: "title #2", "title #3", etc.
        # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
        escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        cursor = self._conn.execute(
            "SELECT id, title, started_at FROM sessions "
            "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
            (f"{escaped} #%",),
        )
        numbered = cursor.fetchall()

        if numbered:
            # Return the most recent numbered variant
            return numbered[0]["id"]
        elif exact:
            return exact["id"]
        return None

    def get_next_title_in_lineage(self, base_title: str) -> str:
        """Generate the next title in a lineage (e.g., "my session" → "my session #2").

        Strips any existing " #N" suffix to find the base name, then finds
        the highest existing number and increments.
        """
        import re
        # Strip existing #N suffix to find the true base
        match = re.match(r'^(.*?) #(\d+)$', base_title)
        if match:
            base = match.group(1)
        else:
            base = base_title

        # Find all existing numbered variants
        # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
        escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        cursor = self._conn.execute(
            "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
            (base, f"{escaped} #%"),
        )
        existing = [row["title"] for row in cursor.fetchall()]

        if not existing:
            return base  # No conflict, use the base name as-is

        # Find the highest number
        max_num = 1  # The unnumbered original counts as #1
        for t in existing:
            m = re.match(r'^.* #(\d+)$', t)
            if m:
                max_num = max(max_num, int(m.group(1)))

        return f"{base} #{max_num + 1}"

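Putting the two lineage methods together, assuming a database whose titles are "my session", "my session #2", and "my session #3" (with #3 the most recently started):

# get_next_title_in_lineage("my session")    -> "my session #4"
# get_next_title_in_lineage("my session #2") -> "my session #4"   (suffix stripped first)
# resolve_session_by_title("my session")     -> id of "my session #3" (latest continuation)
# get_next_title_in_lineage("brand new")     -> "brand new"        (no conflict, kept as-is)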
    def list_sessions_rich(
        self,
        source: str = None,
        limit: int = 20,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List sessions with preview (first user message) and last active timestamp.

        Returns dicts with keys: id, source, model, title, started_at, ended_at,
        message_count, preview (first 60 chars of first user message),
        last_active (timestamp of last message).

        Uses a single query with correlated subqueries instead of N+2 queries.
        """
        source_clause = "WHERE s.source = ?" if source else ""
        query = f"""
            SELECT s.*,
                   COALESCE(
                       (SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
                        FROM messages m
                        WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
                        ORDER BY m.timestamp, m.id LIMIT 1),
                       ''
                   ) AS _preview_raw,
                   COALESCE(
                       (SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
                       s.started_at
                   ) AS last_active
            FROM sessions s
            {source_clause}
            ORDER BY s.started_at DESC
            LIMIT ? OFFSET ?
        """
        params = (source, limit, offset) if source else (limit, offset)
        cursor = self._conn.execute(query, params)
        sessions = []
        for row in cursor.fetchall():
            s = dict(row)
            # Build the preview from the raw substring
            raw = s.pop("_preview_raw", "").strip()
            if raw:
                text = raw[:60]
                s["preview"] = text + ("..." if len(raw) > 60 else "")
            else:
                s["preview"] = ""
            sessions.append(s)

        return sessions

    # =========================================================================
    # Message storage
    # =========================================================================

hermes_time.py (new file, 119 lines)
@@ -0,0 +1,119 @@
"""
Timezone-aware clock for Hermes.

Provides a single ``now()`` helper that returns a timezone-aware datetime
based on the user's configured IANA timezone (e.g. ``Asia/Kolkata``).

Resolution order:
1. ``HERMES_TIMEZONE`` environment variable
2. ``timezone`` key in ``~/.hermes/config.yaml``
3. Falls back to the server's local time (``datetime.now().astimezone()``)

Invalid timezone values log a warning and fall back safely — Hermes never
crashes due to a bad timezone string.
"""

import logging
import os
from datetime import datetime, timezone as _tz
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

try:
    from zoneinfo import ZoneInfo
except ImportError:
    # Python 3.8 fallback (shouldn't be needed — Hermes requires 3.9+)
    from backports.zoneinfo import ZoneInfo  # type: ignore[no-redef]

# Cached state — resolved once, reused on every call.
# Call reset_cache() to force re-resolution (e.g. after config changes).
_cached_tz: Optional[ZoneInfo] = None
_cached_tz_name: Optional[str] = None
_cache_resolved: bool = False


def _resolve_timezone_name() -> str:
    """Read the configured IANA timezone string (or empty string).

    This does file I/O when falling through to config.yaml, so callers
    should cache the result rather than calling on every ``now()``.
    """
    # 1. Environment variable (highest priority — set by Supervisor, etc.)
    tz_env = os.getenv("HERMES_TIMEZONE", "").strip()
    if tz_env:
        return tz_env

    # 2. config.yaml ``timezone`` key
    try:
        import yaml
        hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
        config_path = hermes_home / "config.yaml"
        if config_path.exists():
            with open(config_path) as f:
                cfg = yaml.safe_load(f) or {}
            tz_cfg = cfg.get("timezone", "")
            if isinstance(tz_cfg, str) and tz_cfg.strip():
                return tz_cfg.strip()
    except Exception:
        pass

    return ""


def _get_zoneinfo(name: str) -> Optional[ZoneInfo]:
    """Validate and return a ZoneInfo, or None if invalid."""
    if not name:
        return None
    try:
        return ZoneInfo(name)
    except Exception as exc:
        logger.warning(
            "Invalid timezone '%s': %s. Falling back to server local time.",
            name, exc,
        )
        return None


def get_timezone() -> Optional[ZoneInfo]:
    """Return the user's configured ZoneInfo, or None (meaning server-local).

    Resolved once and cached. Call ``reset_cache()`` after config changes.
    """
    global _cached_tz, _cached_tz_name, _cache_resolved
    if not _cache_resolved:
        _cached_tz_name = _resolve_timezone_name()
        _cached_tz = _get_zoneinfo(_cached_tz_name)
        _cache_resolved = True
    return _cached_tz


def get_timezone_name() -> str:
    """Return the IANA name of the configured timezone, or empty string."""
    global _cached_tz_name, _cache_resolved
    if not _cache_resolved:
        get_timezone()  # populates cache
    return _cached_tz_name or ""


def now() -> datetime:
    """
    Return the current time as a timezone-aware datetime.

    If a valid timezone is configured, returns wall-clock time in that zone.
    Otherwise returns the server's local time (via ``astimezone()``).
    """
    tz = get_timezone()
    if tz is not None:
        return datetime.now(tz)
    # No timezone configured — use server-local (still tz-aware)
    return datetime.now().astimezone()


def reset_cache() -> None:
    """Clear the cached timezone. Used by tests and after config changes."""
    global _cached_tz, _cached_tz_name, _cache_resolved
    _cached_tz = None
    _cached_tz_name = None
    _cache_resolved = False
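
A quick usage sketch of the module above (the timezone value is illustrative):

```python
import os
import hermes_time

os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata"   # highest-priority source
hermes_time.reset_cache()                        # drop any previously cached zone

print(hermes_time.now())                 # timezone-aware datetime in Asia/Kolkata
print(hermes_time.get_timezone_name())   # "Asia/Kolkata"
```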

@@ -149,7 +149,7 @@ class MiniSWERunner:

    def __init__(
        self,
        model: str = "anthropic/claude-sonnet-4-20250514",
        model: str = "anthropic/claude-sonnet-4.6",
        base_url: str = None,
        api_key: str = None,
        env_type: str = "local",
@@ -200,13 +200,7 @@
        else:
            client_kwargs["base_url"] = "https://openrouter.ai/api/v1"

        if base_url and "api.anthropic.com" in base_url.strip().lower():
            raise ValueError(
                "Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
                "Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
                "To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
                "or any OpenAI-compatible proxy that wraps the Anthropic API."
            )

        # Handle API key - OpenRouter is the primary provider
        if api_key:
@@ -225,6 +225,18 @@ def get_tool_definitions(
    # Ask the registry for schemas (only returns tools whose check_fn passes)
    filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)

    # Rebuild execute_code schema to only list sandbox tools that are actually
    # enabled. Without this, the model sees "web_search is available in
    # execute_code" even when the user disabled the web toolset (#560-discord).
    if "execute_code" in tools_to_include:
        from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
        sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
        dynamic_schema = build_execute_code_schema(sandbox_enabled)
        for i, td in enumerate(filtered_tools):
            if td.get("function", {}).get("name") == "execute_code":
                filtered_tools[i] = {"type": "function", "function": dynamic_schema}
                break

    if not quiet_mode:
        if filtered_tools:
            tool_names = [t["function"]["name"] for t in filtered_tools]
optional-skills/DESCRIPTION.md (new file, 24 lines)
@@ -0,0 +1,24 @@
# Optional Skills

Official skills maintained by Nous Research that are **not activated by default**.

These skills ship with the hermes-agent repository but are not copied to
`~/.hermes/skills/` during setup. They are discoverable via the Skills Hub:

```bash
hermes skills browse                    # browse all skills, official shown first
hermes skills browse --source official  # browse only official optional skills
hermes skills search <query>            # finds optional skills labeled "official"
hermes skills install <identifier>      # copies to ~/.hermes/skills/ and activates
```

## Why optional?

Some skills are useful but not broadly needed by every user:

- **Niche integrations** — specific paid services, specialized tools
- **Experimental features** — promising but not yet proven
- **Heavyweight dependencies** — require significant setup (API keys, installs)

By keeping them optional, we keep the default skill set lean while still
providing curated, tested, official skills for users who want them.
optional-skills/autonomous-ai-agents/DESCRIPTION.md (new file, 2 lines)
@@ -0,0 +1,2 @@
Optional autonomous AI agent integrations — external coding agent CLIs
that can be delegated to for independent coding tasks.
optional-skills/autonomous-ai-agents/blackbox/SKILL.md (new file, 143 lines)
@@ -0,0 +1,143 @@
---
name: blackbox
description: Delegate coding tasks to Blackbox AI CLI agent. Multi-model agent with built-in judge that runs tasks through multiple LLMs and picks the best result. Requires the blackbox CLI and a Blackbox AI API key.
version: 1.0.0
author: Hermes Agent (Nous Research)
license: MIT
metadata:
  hermes:
    tags: [Coding-Agent, Blackbox, Multi-Agent, Judge, Multi-Model]
    related_skills: [claude-code, codex, hermes-agent]
---

# Blackbox CLI

Delegate coding tasks to [Blackbox AI](https://www.blackbox.ai/) via the Hermes terminal. Blackbox is a multi-model coding agent CLI that dispatches tasks to multiple LLMs (Claude, Codex, Gemini, Blackbox Pro) and uses a judge to select the best implementation.

The CLI is [open-source](https://github.com/blackboxaicode/cli) (GPL-3.0, TypeScript, forked from Gemini CLI) and supports interactive sessions, non-interactive one-shots, checkpointing, MCP, and vision model switching.

## Prerequisites

- Node.js 20+ installed
- Blackbox CLI installed: `npm install -g @blackboxai/cli`
- Or install from source:
  ```
  git clone https://github.com/blackboxaicode/cli.git
  cd cli && npm install && npm install -g .
  ```
- API key from [app.blackbox.ai/dashboard](https://app.blackbox.ai/dashboard)
- Configured: run `blackbox configure` and enter your API key
- Use `pty=true` in terminal calls — Blackbox CLI is an interactive terminal app

## One-Shot Tasks

```
terminal(command="blackbox --prompt 'Add JWT authentication with refresh tokens to the Express API'", workdir="/path/to/project", pty=true)
```

For quick scratch work:
```
terminal(command="cd $(mktemp -d) && git init && blackbox --prompt 'Build a REST API for todos with SQLite'", pty=true)
```

## Background Mode (Long Tasks)

For tasks that take minutes, use background mode so you can monitor progress:

```
# Start in background with PTY
terminal(command="blackbox --prompt 'Refactor the auth module to use OAuth 2.0'", workdir="~/project", background=true, pty=true)
# Returns session_id

# Monitor progress
process(action="poll", session_id="<id>")
process(action="log", session_id="<id>")

# Send input if Blackbox asks a question
process(action="submit", session_id="<id>", data="yes")

# Kill if needed
process(action="kill", session_id="<id>")
```

## Checkpoints & Resume

Blackbox CLI has built-in checkpoint support for pausing and resuming tasks:

```
# After a task completes, Blackbox shows a checkpoint tag
# Resume with a follow-up task:
terminal(command="blackbox --resume-checkpoint 'task-abc123-2026-03-06' --prompt 'Now add rate limiting to the endpoints'", workdir="~/project", pty=true)
```

## Session Commands

During an interactive session, use these commands:

| Command | Effect |
|---------|--------|
| `/compress` | Shrink conversation history to save tokens |
| `/clear` | Wipe history and start fresh |
| `/stats` | View current token usage |
| `Ctrl+C` | Cancel current operation |

## PR Reviews

Clone to a temp directory to avoid modifying the working tree:

```
terminal(command="REVIEW=$(mktemp -d) && git clone https://github.com/user/repo.git $REVIEW && cd $REVIEW && gh pr checkout 42 && blackbox --prompt 'Review this PR against main. Check for bugs, security issues, and code quality.'", pty=true)
```

## Parallel Work

Spawn multiple Blackbox instances for independent tasks:

```
terminal(command="blackbox --prompt 'Fix the login bug'", workdir="/tmp/issue-1", background=true, pty=true)
terminal(command="blackbox --prompt 'Add unit tests for auth'", workdir="/tmp/issue-2", background=true, pty=true)

# Monitor all
process(action="list")
```

## Multi-Model Mode

Blackbox's unique feature is running the same task through multiple models and judging the results. Configure which models to use via `blackbox configure` — select multiple providers to enable the Chairman/judge workflow where the CLI evaluates outputs from different models and picks the best one.

## Key Flags

| Flag | Effect |
|------|--------|
| `--prompt "task"` | Non-interactive one-shot execution |
| `--resume-checkpoint "tag"` | Resume from a saved checkpoint |
| `--yolo` | Auto-approve all actions and model switches |
| `blackbox session` | Start interactive chat session |
| `blackbox configure` | Change settings, providers, models |
| `blackbox info` | Display system information |

## Vision Support

Blackbox automatically detects images in input and can switch to multimodal analysis. VLM modes:
- `"once"` — Switch model for current query only
- `"session"` — Switch for entire session
- `"persist"` — Stay on current model (no switch)

## Token Limits

Control token usage via `.blackboxcli/settings.json`:
```json
{
  "sessionTokenLimit": 32000
}
```

## Rules

1. **Always use `pty=true`** — Blackbox CLI is an interactive terminal app and will hang without a PTY
2. **Use `workdir`** — keep the agent focused on the right directory
3. **Background for long tasks** — use `background=true` and monitor with the `process` tool
4. **Don't interfere** — monitor with `poll`/`log`; don't kill sessions just because they're slow
5. **Report results** — after completion, check what changed and summarize for the user
6. **Credits cost money** — Blackbox uses a credit-based system; multi-model mode consumes credits faster
7. **Check prerequisites** — verify the `blackbox` CLI is installed before attempting delegation
optional-skills/research/qmd/SKILL.md (new file, 441 lines)
@@ -0,0 +1,441 @@
---
name: qmd
description: Search personal knowledge bases, notes, docs, and meeting transcripts locally using qmd — a hybrid retrieval engine with BM25, vector search, and LLM reranking. Supports CLI and MCP integration.
version: 1.0.0
author: Hermes Agent + Teknium
license: MIT
platforms: [macos, linux]
metadata:
  hermes:
    tags: [Search, Knowledge-Base, RAG, Notes, MCP, Local-AI]
    related_skills: [obsidian, native-mcp, arxiv]
---

# QMD — Query Markup Documents

Local, on-device search engine for personal knowledge bases. Indexes markdown
notes, meeting transcripts, documentation, and any text-based files, then
provides hybrid search combining keyword matching, semantic understanding, and
LLM-powered reranking — all running locally with no cloud dependencies.

Created by [Tobi Lütke](https://github.com/tobi/qmd). MIT licensed.

## When to Use

- User asks to search their notes, docs, knowledge base, or meeting transcripts
- User wants to find something across a large collection of markdown/text files
- User wants semantic search ("find notes about X concept") not just keyword grep
- User has already set up qmd collections and wants to query them
- User asks to set up a local knowledge base or document search system
- Keywords: "search my notes", "find in my docs", "knowledge base", "qmd"

## Prerequisites

### Node.js >= 22 (required)

```bash
# Check version
node --version  # must be >= 22

# macOS — install or upgrade via Homebrew
brew install node@22

# Linux — use NodeSource or nvm
curl -fsSL https://deb.nodesource.com/setup_22.x | sudo -E bash -
sudo apt-get install -y nodejs
# or with nvm:
nvm install 22 && nvm use 22
```

### SQLite with Extension Support (macOS only)

macOS system SQLite lacks extension loading. Install via Homebrew:

```bash
brew install sqlite
```

### Install qmd

```bash
npm install -g @tobilu/qmd
# or with Bun:
bun install -g @tobilu/qmd
```

First run auto-downloads 3 local GGUF models (~2GB total):

| Model | Purpose | Size |
|-------|---------|------|
| embeddinggemma-300M-Q8_0 | Vector embeddings | ~300MB |
| qwen3-reranker-0.6b-q8_0 | Result reranking | ~640MB |
| qmd-query-expansion-1.7B | Query expansion | ~1.1GB |

### Verify Installation

```bash
qmd --version
qmd status
```

## Quick Reference

| Command | What It Does | Speed |
|---------|-------------|-------|
| `qmd search "query"` | BM25 keyword search (no models) | ~0.2s |
| `qmd vsearch "query"` | Semantic vector search (1 model) | ~3s |
| `qmd query "query"` | Hybrid + reranking (all 3 models) | ~2-3s warm, ~19s cold |
| `qmd get <docid>` | Retrieve full document content | instant |
| `qmd multi-get "glob"` | Retrieve multiple files | instant |
| `qmd collection add <path> --name <n>` | Add a directory as a collection | instant |
| `qmd context add <path> "description"` | Add context metadata to improve retrieval | instant |
| `qmd embed` | Generate/update vector embeddings | varies |
| `qmd status` | Show index health and collection info | instant |
| `qmd mcp` | Start MCP server (stdio) | persistent |
| `qmd mcp --http --daemon` | Start MCP server (HTTP, warm models) | persistent |

## Setup Workflow

### 1. Add Collections

Point qmd at directories containing your documents:

```bash
# Add a notes directory
qmd collection add ~/notes --name notes

# Add project docs
qmd collection add ~/projects/myproject/docs --name project-docs

# Add meeting transcripts
qmd collection add ~/meetings --name meetings

# List all collections
qmd collection list
```

### 2. Add Context Descriptions

Context metadata helps the search engine understand what each collection
contains. This significantly improves retrieval quality:

```bash
qmd context add qmd://notes "Personal notes, ideas, and journal entries"
qmd context add qmd://project-docs "Technical documentation for the main project"
qmd context add qmd://meetings "Meeting transcripts and action items from team syncs"
```

### 3. Generate Embeddings

```bash
qmd embed
```

This processes all documents in all collections and generates vector
embeddings. Re-run after adding new documents or collections.

### 4. Verify

```bash
qmd status  # shows index health, collection stats, model info
```

## Search Patterns

### Fast Keyword Search (BM25)

Best for: exact terms, code identifiers, names, known phrases.
No models loaded — near-instant results.

```bash
qmd search "authentication middleware"
qmd search "handleError async"
```

### Semantic Vector Search

Best for: natural language questions, conceptual queries.
Loads embedding model (~3s first query).

```bash
qmd vsearch "how does the rate limiter handle burst traffic"
qmd vsearch "ideas for improving onboarding flow"
```

### Hybrid Search with Reranking (Best Quality)

Best for: important queries where quality matters most.
Uses all 3 models — query expansion, parallel BM25+vector, reranking.

```bash
qmd query "what decisions were made about the database migration"
```

### Structured Multi-Mode Queries

Combine different search types in a single query for precision:

```bash
# BM25 for exact term + vector for concept
qmd query $'lex: rate limiter\nvec: how does throttling work under load'

# With query expansion
qmd query $'expand: database migration plan\nlex: "schema change"'
```

### Query Syntax (lex/BM25 mode)

| Syntax | Effect | Example |
|--------|--------|---------|
| `term` | Prefix match | `perf` matches "performance" |
| `"phrase"` | Exact phrase | `"rate limiter"` |
| `-term` | Exclude term | `performance -sports` |

### HyDE (Hypothetical Document Embeddings)

For complex topics, write what you expect the answer to look like:

```bash
qmd query $'hyde: The migration plan involves three phases. First, we add the new columns without dropping the old ones. Then we backfill data. Finally we cut over and remove legacy columns.'
```

### Scoping to Collections

```bash
qmd search "query" --collection notes
qmd query "query" --collection project-docs
```

### Output Formats

```bash
qmd search "query" --json              # JSON output (best for parsing)
qmd search "query" --limit 5           # Limit results
qmd get "#abc123"                      # Get by document ID
qmd get "path/to/file.md"              # Get by file path
qmd get "file.md:50" -l 100            # Get specific line range
qmd multi-get "journals/*.md" --json   # Batch retrieve by glob
```

## MCP Integration (Recommended)

qmd exposes an MCP server that provides search tools directly to
Hermes Agent via the native MCP client. This is the preferred
integration — once configured, the agent gets qmd tools automatically
without needing to load this skill.

### Option A: Stdio Mode (Simple)

Add to `~/.hermes/config.yaml`:

```yaml
mcp_servers:
  qmd:
    command: "qmd"
    args: ["mcp"]
    timeout: 30
    connect_timeout: 45
```

This registers tools: `mcp_qmd_search`, `mcp_qmd_vsearch`,
`mcp_qmd_deep_search`, `mcp_qmd_get`, `mcp_qmd_status`.

**Tradeoff:** Models load on first search call (~19s cold start),
then stay warm for the session. Acceptable for occasional use.

### Option B: HTTP Daemon Mode (Fast, Recommended for Heavy Use)

Start the qmd daemon separately — it keeps models warm in memory:

```bash
# Start daemon (persists across agent restarts)
qmd mcp --http --daemon

# Runs on http://localhost:8181 by default
```

Then configure Hermes Agent to connect via HTTP:

```yaml
mcp_servers:
  qmd:
    url: "http://localhost:8181/mcp"
    timeout: 30
```

**Tradeoff:** Uses ~2GB RAM while running, but every query is fast
(~2-3s). Best for users who search frequently.

### Keeping the Daemon Running

#### macOS (launchd)

```bash
cat > ~/Library/LaunchAgents/com.qmd.daemon.plist << 'EOF'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
  "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>com.qmd.daemon</string>
    <key>ProgramArguments</key>
    <array>
        <string>qmd</string>
        <string>mcp</string>
        <string>--http</string>
        <string>--daemon</string>
    </array>
    <key>RunAtLoad</key>
    <true/>
    <key>KeepAlive</key>
    <true/>
    <key>StandardOutPath</key>
    <string>/tmp/qmd-daemon.log</string>
    <key>StandardErrorPath</key>
    <string>/tmp/qmd-daemon.log</string>
</dict>
</plist>
EOF

launchctl load ~/Library/LaunchAgents/com.qmd.daemon.plist
```

#### Linux (systemd user service)

```bash
mkdir -p ~/.config/systemd/user

cat > ~/.config/systemd/user/qmd-daemon.service << 'EOF'
[Unit]
Description=QMD MCP Daemon
After=network.target

[Service]
ExecStart=qmd mcp --http --daemon
Restart=on-failure
RestartSec=10
Environment=PATH=/usr/local/bin:/usr/bin:/bin

[Install]
WantedBy=default.target
EOF

systemctl --user daemon-reload
systemctl --user enable --now qmd-daemon
systemctl --user status qmd-daemon
```

### MCP Tools Reference

Once connected, these tools are available as `mcp_qmd_*`:

| MCP Tool | Maps To | Description |
|----------|---------|-------------|
| `mcp_qmd_search` | `qmd search` | BM25 keyword search |
| `mcp_qmd_vsearch` | `qmd vsearch` | Semantic vector search |
| `mcp_qmd_deep_search` | `qmd query` | Hybrid search + reranking |
| `mcp_qmd_get` | `qmd get` | Retrieve document by ID or path |
| `mcp_qmd_status` | `qmd status` | Index health and stats |

The MCP tools accept structured JSON queries for multi-mode search:

```json
{
  "searches": [
    {"type": "lex", "query": "authentication middleware"},
    {"type": "vec", "query": "how user login is verified"}
  ],
  "collections": ["project-docs"],
  "limit": 10
}
```

## CLI Usage (Without MCP)

When MCP is not configured, use qmd directly via terminal:

```
terminal(command="qmd query 'what was decided about the API redesign' --json", timeout=30)
```

For setup and management tasks, always use terminal:

```
terminal(command="qmd collection add ~/Documents/notes --name notes")
terminal(command="qmd context add qmd://notes 'Personal research notes and ideas'")
terminal(command="qmd embed")
terminal(command="qmd status")
```

## How the Search Pipeline Works

Understanding the internals helps choose the right search mode (a short
sketch of the fusion math follows the list):

1. **Query Expansion** — A fine-tuned 1.7B model generates 2 alternative
   queries. The original gets 2x weight in fusion.
2. **Parallel Retrieval** — BM25 (SQLite FTS5) and vector search run
   simultaneously across all query variants.
3. **RRF Fusion** — Reciprocal Rank Fusion (k=60) merges results.
   Top-rank bonus: #1 gets +0.05, #2-3 get +0.02.
4. **LLM Reranking** — qwen3-reranker scores top 30 candidates (0.0-1.0).
5. **Position-Aware Blending** — Ranks 1-3: 75% retrieval / 25% reranker.
   Ranks 4-10: 60/40. Ranks 11+: 40/60 (trusts reranker more for long tail).
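
A minimal Python sketch of that fusion and blending math (illustrative only —
the function names and structure are assumptions, not qmd's actual code):

```python
def rrf_fuse(rankings, weights, k=60):
    """Reciprocal Rank Fusion over several ranked doc-id lists.

    rankings: list of ranked doc-id lists (one per query variant/mode)
    weights:  per-ranking weight (2.0 for the original query, 1.0 for expansions)
    """
    scores = {}
    for ranking, weight in zip(rankings, weights):
        for rank, doc in enumerate(ranking, start=1):
            scores[doc] = scores.get(doc, 0.0) + weight / (k + rank)
            if rank == 1:
                scores[doc] += 0.05   # top-rank bonus
            elif rank <= 3:
                scores[doc] += 0.02
    return sorted(scores, key=scores.get, reverse=True)


def blend(fused_rank, retrieval_score, reranker_score):
    """Position-aware blend of fused retrieval score and reranker score."""
    if fused_rank <= 3:
        weight = 0.75          # trust retrieval near the top
    elif fused_rank <= 10:
        weight = 0.60
    else:
        weight = 0.40          # trust the reranker for the long tail
    return weight * retrieval_score + (1 - weight) * reranker_score
```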

**Smart Chunking:** Documents are split at natural break points (headings,
code blocks, blank lines) targeting ~900 tokens with 15% overlap. Code
blocks are never split mid-block.
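
A rough illustration of the idea (a simplified sketch, not qmd's splitter;
the 4-characters-per-token estimate is an assumption):

```python
def chunk_markdown(text, target_tokens=900, overlap=0.15):
    """Split markdown at blank-line boundaries into ~target_tokens chunks.

    Simplified: real smart chunking also respects headings and never
    splits fenced code blocks mid-block.
    """
    def _est_tokens(s):
        return len(s) // 4  # crude chars -> tokens estimate (assumption)

    blocks = text.split("\n\n")  # natural break points
    chunks, current = [], []
    for block in blocks:
        current.append(block)
        if _est_tokens("\n\n".join(current)) >= target_tokens:
            chunks.append("\n\n".join(current))
            # Carry ~15% of the chunk back over as overlap context
            keep = max(1, int(len(current) * overlap))
            current = current[-keep:]
    if current:
        chunks.append("\n\n".join(current))
    return chunks
```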

## Best Practices

1. **Always add context descriptions** — `qmd context add` dramatically
   improves retrieval accuracy. Describe what each collection contains.
2. **Re-embed after adding documents** — `qmd embed` must be re-run when
   new files are added to collections.
3. **Use `qmd search` for speed** — when you need fast keyword lookup
   (code identifiers, exact names), BM25 is instant and needs no models.
4. **Use `qmd query` for quality** — when the question is conceptual or
   the user needs the best possible results, use hybrid search.
5. **Prefer MCP integration** — once configured, the agent gets native
   tools without needing to load this skill each time.
6. **Daemon mode for frequent users** — if the user searches their
   knowledge base regularly, recommend the HTTP daemon setup.
7. **First query in structured search gets 2x weight** — put the most
   important/certain query first when combining lex and vec.

## Troubleshooting

### "Models downloading on first run"
Normal — qmd auto-downloads ~2GB of GGUF models on first use.
This is a one-time operation.

### Cold start latency (~19s)
This happens when models aren't loaded in memory. Solutions:
- Use HTTP daemon mode (`qmd mcp --http --daemon`) to keep models warm
- Use `qmd search` (BM25 only) when models aren't needed
- MCP stdio mode loads models on the first search, then stays warm for the session

### macOS: "unable to load extension"
Install Homebrew SQLite: `brew install sqlite`
Then ensure it's on PATH before system SQLite.

### "No collections found"
Run `qmd collection add <path> --name <name>` to add directories,
then `qmd embed` to index them.

### Embedding model override (CJK/multilingual)
Set the `QMD_EMBED_MODEL` environment variable for non-English content:
```bash
export QMD_EMBED_MODEL="your-multilingual-model"
```

## Data Storage

- **Index & vectors:** `~/.cache/qmd/index.sqlite`
- **Models:** Auto-downloaded to local cache on first run
- **No cloud dependencies** — everything runs locally

## References

- [GitHub: tobi/qmd](https://github.com/tobi/qmd)
- [QMD Changelog](https://github.com/tobi/qmd/blob/main/CHANGELOG.md)
@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
[project]
name = "hermes-agent"
version = "0.1.0"
description = "AI agent with advanced tool-calling and toolsets"
description = "The self-improving AI agent — creates skills from experience, improves them during use, and runs anywhere"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [{ name = "Nous Research" }]
license = { text = "MIT" }
dependencies = [
@@ -39,6 +39,7 @@ dependencies = [

[project.optional-dependencies]
modal = ["swe-rex[modal]>=1.4.0"]
daytona = ["daytona>=0.148.0"]
dev = ["pytest", "pytest-asyncio"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
@@ -49,8 +50,10 @@ pty = ["ptyprocess>=0.7.0"]
honcho = ["honcho-ai>=2.0.1"]
mcp = ["mcp>=1.2.0"]
homeassistant = ["aiohttp>=3.9.0"]
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git"]
all = [
    "hermes-agent[modal]",
    "hermes-agent[daytona]",
    "hermes-agent[messaging]",
    "hermes-agent[cron]",
    "hermes-agent[cli]",
run_agent.py (282 lines changed)
@@ -82,6 +82,8 @@ from agent.prompt_builder import (
from agent.model_metadata import (
    fetch_model_metadata, get_model_context_length,
    estimate_tokens_rough, estimate_messages_tokens_rough,
    get_next_probe_tier, parse_context_limit_from_error,
    save_context_length,
)
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
@@ -97,6 +99,46 @@ from agent.trajectory import (
)


class IterationBudget:
    """Thread-safe shared iteration counter for parent and child agents.

    Tracks total LLM-call iterations consumed across a parent agent and all
    its subagents. A single ``IterationBudget`` is created by the parent
    and passed to every child so they share the same cap.

    ``execute_code`` (programmatic tool calling) iterations are refunded via
    :meth:`refund` so they don't eat into the budget.
    """

    def __init__(self, max_total: int):
        self.max_total = max_total
        self._used = 0
        self._lock = threading.Lock()

    def consume(self) -> bool:
        """Try to consume one iteration. Returns True if allowed."""
        with self._lock:
            if self._used >= self.max_total:
                return False
            self._used += 1
            return True

    def refund(self) -> None:
        """Give back one iteration (e.g. for execute_code turns)."""
        with self._lock:
            if self._used > 0:
                self._used -= 1

    @property
    def used(self) -> int:
        return self._used

    @property
    def remaining(self) -> int:
        with self._lock:
            return max(0, self.max_total - self._used)
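
# Sharing pattern in brief: the parent constructs one IterationBudget and
# passes iteration_budget=budget to every subagent it spawns; each LLM turn
# first calls budget.consume(), and turns that only called execute_code get
# budget.refund() afterwards, so programmatic tool calls stay free.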


class AIAgent:
    """
    AI Agent with tool calling capabilities.

@@ -112,7 +154,7 @@ class AIAgent:
        provider: str = None,
        api_mode: str = None,
        model: str = "anthropic/claude-opus-4.6",  # OpenRouter format
        max_iterations: int = 60,  # Default tool-calling iterations
        max_iterations: int = 90,  # Default tool-calling iterations (shared with subagents)
        tool_delay: float = 1.0,
        enabled_toolsets: List[str] = None,
        disabled_toolsets: List[str] = None,
@@ -140,6 +182,7 @@ class AIAgent:
        skip_memory: bool = False,
        session_db=None,
        honcho_session_key: str = None,
        iteration_budget: "IterationBudget" = None,
    ):
        """
        Initialize the AI Agent.
@@ -150,7 +193,7 @@ class AIAgent:
            provider (str): Provider identifier (optional; used for telemetry/routing hints)
            api_mode (str): API mode override: "chat_completions" or "codex_responses"
            model (str): Model name to use (default: "anthropic/claude-opus-4.6")
            max_iterations (int): Maximum number of tool calling iterations (default: 60)
            max_iterations (int): Maximum number of tool calling iterations (default: 90)
            tool_delay (float): Delay between tool calls in seconds (default: 1.0)
            enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
            disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
@@ -170,7 +213,7 @@ class AIAgent:
                Provided by the platform layer (CLI or gateway). If None, the clarify tool returns an error.
            max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
            reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
                If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
                If None, defaults to {"enabled": True, "effort": "medium"} for OpenRouter. Set to disable/customize reasoning.
            prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
                Useful for injecting a few-shot example or priming the model's response style.
                Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
@@ -184,6 +227,9 @@ class AIAgent:
        """
        self.model = model
        self.max_iterations = max_iterations
        # Shared iteration budget — parent creates, children inherit.
        # Consumed by every LLM turn across parent + all subagents.
        self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
        self.tool_delay = tool_delay
        self.save_trajectories = save_trajectories
        self.verbose_logging = verbose_logging
@@ -207,13 +253,7 @@ class AIAgent:
            self.provider = "openai-codex"
        else:
            self.api_mode = "chat_completions"
        if base_url and "api.anthropic.com" in base_url.strip().lower():
            raise ValueError(
                "Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
                "Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
                "To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
                "or any OpenAI-compatible proxy that wraps the Anthropic API."
            )

        self.tool_progress_callback = tool_progress_callback
        self.clarify_callback = clarify_callback
        self.step_callback = step_callback
@@ -241,7 +281,7 @@ class AIAgent:

        # Model response configuration
        self.max_tokens = max_tokens  # None = use model default
        self.reasoning_config = reasoning_config  # None = use default (xhigh for OpenRouter)
        self.reasoning_config = reasoning_config  # None = use default (medium for OpenRouter)
        self.prefill_messages = prefill_messages or []  # Prefilled conversation turns

        # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
@@ -343,6 +383,12 @@ class AIAgent:
                "X-OpenRouter-Title": "Hermes Agent",
                "X-OpenRouter-Categories": "productivity,cli-agent",
            }
        elif "api.kimi.com" in effective_base.lower():
            # Kimi Code API requires a recognized coding-agent User-Agent
            # (see https://github.com/MoonshotAI/kimi-cli)
            client_kwargs["default_headers"] = {
                "User-Agent": "KimiCLI/1.0",
            }

        self._client_kwargs = client_kwargs  # stored for rebuilding after interrupt
        try:
@@ -536,6 +582,7 @@ class AIAgent:
            summary_target_tokens=500,
            summary_model_override=compression_summary_model,
            quiet_mode=self.quiet_mode,
            base_url=self.base_url,
        )
        self.compression_enabled = compression_enabled
        self._user_turn_count = 0
@@ -1360,7 +1407,8 @@ class AIAgent:
        if context_files_prompt:
            prompt_parts.append(context_files_prompt)

        now = datetime.now()
        from hermes_time import now as _hermes_now
        now = _hermes_now()
        prompt_parts.append(
            f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
        )
@@ -2015,6 +2063,49 @@ class AIAgent:

        return True

    def _try_refresh_nous_client_credentials(self, *, force: bool = True) -> bool:
        if self.api_mode != "chat_completions" or self.provider != "nous":
            return False

        try:
            from hermes_cli.auth import resolve_nous_runtime_credentials

            creds = resolve_nous_runtime_credentials(
                min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
                timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
                force_mint=force,
            )
        except Exception as exc:
            logger.debug("Nous credential refresh failed: %s", exc)
            return False

        api_key = creds.get("api_key")
        base_url = creds.get("base_url")
        if not isinstance(api_key, str) or not api_key.strip():
            return False
        if not isinstance(base_url, str) or not base_url.strip():
            return False

        self.api_key = api_key.strip()
        self.base_url = base_url.strip().rstrip("/")
        self._client_kwargs["api_key"] = self.api_key
        self._client_kwargs["base_url"] = self.base_url
        # Nous requests should not inherit OpenRouter-only attribution headers.
        self._client_kwargs.pop("default_headers", None)

        try:
            self.client.close()
        except Exception:
            pass

        try:
            self.client = OpenAI(**self._client_kwargs)
        except Exception as exc:
            logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
            return False

        return True

    def _interruptible_api_call(self, api_kwargs: dict):
        """
        Run the API call in a background thread so the main conversation loop
@@ -2066,8 +2157,8 @@ class AIAgent:
        if not instructions:
            instructions = DEFAULT_AGENT_IDENTITY

        # Resolve reasoning effort: config > default (xhigh)
        reasoning_effort = "xhigh"
        # Resolve reasoning effort: config > default (medium)
        reasoning_effort = "medium"
        reasoning_enabled = True
        if self.reasoning_config and isinstance(self.reasoning_config, dict):
            if self.reasoning_config.get("enabled") is False:
@@ -2133,7 +2224,7 @@ class AIAgent:
            else:
                extra_body["reasoning"] = {
                    "enabled": True,
                    "effort": "xhigh"
                    "effort": "medium"
                }

        # Nous Portal product attribution
@@ -2393,6 +2484,8 @@ class AIAgent:

        if self._session_db:
            try:
                # Propagate title to the new session with auto-numbering
                old_title = self._session_db.get_session_title(self.session_id)
                self._session_db.end_session(self.session_id, "compression")
                old_session_id = self.session_id
                self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
@@ -2402,6 +2495,13 @@ class AIAgent:
                    model=self.model,
                    parent_session_id=old_session_id,
                )
                # Auto-number the title for the continuation session
                if old_title:
                    try:
                        new_title = self._session_db.get_next_title_in_lineage(old_title)
                        self._session_db.set_session_title(self.session_id, new_title)
                    except Exception as e:
                        logger.debug("Could not propagate title on compression: %s", e)
                self._session_db.update_system_prompt(self.session_id, new_system_prompt)
            except Exception as e:
                logger.debug("Session DB compression split failed: %s", e)
@@ -2528,7 +2628,6 @@ class AIAgent:
                context=function_args.get("context"),
                toolsets=function_args.get("toolsets"),
                tasks=tasks_arg,
                model=function_args.get("model"),
                max_iterations=function_args.get("max_iterations"),
                parent_agent=self,
            )
@@ -2677,7 +2776,7 @@ class AIAgent:
            else:
                summary_extra_body["reasoning"] = {
                    "enabled": True,
                    "effort": "xhigh"
                    "effort": "medium"
                }
            if _is_nous:
                summary_extra_body["tags"] = ["product=hermes-agent"]
@@ -2740,7 +2839,7 @@ class AIAgent:
            "messages": api_messages,
        }
        if self.max_tokens is not None:
            summary_kwargs["max_tokens"] = self.max_tokens
            summary_kwargs.update(self._max_tokens_param(self.max_tokens))
        if summary_extra_body:
            summary_kwargs["extra_body"] = summary_extra_body

@@ -2789,13 +2888,15 @@ class AIAgent:
        # Generate unique task_id if not provided to isolate VMs between concurrent tasks
        effective_task_id = task_id or str(uuid.uuid4())

        # Reset retry counters at the start of each conversation to prevent state leakage
        # Reset retry counters and iteration budget at the start of each turn
        # so subagent usage from a previous turn doesn't eat into the next one.
        self._invalid_tool_retries = 0
        self._invalid_json_retries = 0
        self._empty_content_retries = 0
        self._last_content_with_tools = None
        self._turns_since_memory = 0
        self._iters_since_skill = 0
        self.iteration_budget = IterationBudget(self.max_iterations)

        # Initialize conversation (copy to avoid mutating the caller's list)
        messages = list(conversation_history) if conversation_history else []
@@ -2927,7 +3028,7 @@ class AIAgent:
        # Clear any stale interrupt state at start
        self.clear_interrupt()

        while api_call_count < self.max_iterations:
        while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
            # Check for interrupt request (e.g., user sent new message)
            if self._interrupt_requested:
                interrupted = True
@@ -2936,6 +3037,10 @@ class AIAgent:
                break

            api_call_count += 1
            if not self.iteration_budget.consume():
                if not self.quiet_mode:
                    print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)")
                break

            # Fire step_callback for gateway hooks (agent:step event)
            if self.step_callback is not None:
@@ -3012,6 +3117,13 @@ class AIAgent:
            if self._use_prompt_caching:
                api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)

            # Safety net: strip orphaned tool results / add stubs for missing
            # results before sending to the API. The compressor handles this
            # during compression, but orphans can also sneak in from session
            # loading or manual message manipulation.
            if hasattr(self, 'context_compressor') and self.context_compressor:
                api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)

            # Calculate approximate request size for logging
            total_chars = sum(len(str(msg)) for msg in api_messages)
            approx_tokens = total_chars // 4  # Rough estimate: 4 chars per token
@@ -3040,9 +3152,13 @@ class AIAgent:
            api_start_time = time.time()
            retry_count = 0
            max_retries = 6  # Increased to allow longer backoff periods
            compression_attempts = 0
            max_compression_attempts = 3
            codex_auth_retry_attempted = False
            nous_auth_retry_attempted = False

            finish_reason = "stop"
            response = None  # Guard against UnboundLocalError if all retries fail

            while retry_count < max_retries:
                try:
@@ -3236,6 +3352,13 @@ class AIAgent:
                        }
                        self.context_compressor.update_from_response(usage_dict)

                        # Cache discovered context length after successful call
                        if self.context_compressor._context_probed:
                            ctx = self.context_compressor.context_length
                            save_context_length(self.model, self.base_url, ctx)
                            print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
                            self.context_compressor._context_probed = False

                        self.session_prompt_tokens += prompt_tokens
                        self.session_completion_tokens += completion_tokens
                        self.session_total_tokens += total_tokens
@@ -3283,6 +3406,16 @@ class AIAgent:
                        if self._try_refresh_codex_client_credentials(force=True):
                            print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
                            continue
                    if (
                        self.api_mode == "chat_completions"
                        and self.provider == "nous"
                        and status_code == 401
                        and not nous_auth_retry_attempted
                    ):
                        nous_auth_retry_attempted = True
                        if self._try_refresh_nous_client_credentials(force=True):
                            print(f"{self.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
                            continue

                    retry_count += 1
                    elapsed_time = time.time() - api_start_time
@@ -3321,7 +3454,19 @@ class AIAgent:
                    )

                    if is_payload_too_large:
                        print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
                        compression_attempts += 1
                        if compression_attempts > max_compression_attempts:
                            print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.")
                            logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.")
                            self._persist_session(messages, conversation_history)
                            return {
                                "messages": messages,
                                "completed": False,
                                "api_calls": api_call_count,
                                "error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
                                "partial": True
                            }
                        print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")

                        original_len = len(messages)
                        messages, active_system_prompt = self._compress_context(
@@ -3330,6 +3475,7 @@ class AIAgent:

                        if len(messages) < original_len:
                            print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
                            time.sleep(2)  # Brief pause between compression retries
                            continue  # Retry with compressed messages
                        else:
                            print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
@@ -3355,18 +3501,52 @@ class AIAgent:
                    ])

                    if is_context_length_error:
                        print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...")
                        compressor = self.context_compressor
                        old_ctx = compressor.context_length

                        # Try to parse the actual limit from the error message
                        parsed_limit = parse_context_limit_from_error(error_msg)
                        if parsed_limit and parsed_limit < old_ctx:
                            new_ctx = parsed_limit
                            print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})")
                        else:
                            # Step down to the next probe tier
                            new_ctx = get_next_probe_tier(old_ctx)

                        if new_ctx and new_ctx < old_ctx:
                            compressor.context_length = new_ctx
                            compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
                            compressor._context_probed = True
                            print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens")
                        else:
                            print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")

                        compression_attempts += 1
                        if compression_attempts > max_compression_attempts:
                            print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.")
                            logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
                            self._persist_session(messages, conversation_history)
                            return {
                                "messages": messages,
                                "completed": False,
                                "api_calls": api_call_count,
                                "error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
                                "partial": True
                            }
                        print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")

                        original_len = len(messages)
                        messages, active_system_prompt = self._compress_context(
                            messages, system_message, approx_tokens=approx_tokens
                        )

                        if len(messages) < original_len:
                            print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
                            continue  # Retry with compressed messages
                        if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
                            if len(messages) < original_len:
                                print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
                            time.sleep(2)  # Brief pause between compression retries
                            continue  # Retry with compressed messages or new tier
                        else:
                            # Can't compress further
                            # Can't compress further and already at minimum tier
                            print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
                            print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
                            logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")
@@ -3442,6 +3622,14 @@ class AIAgent:
            if interrupted:
                break

            # Guard: if all retries exhausted without a successful response
            # (e.g. repeated context-length errors that exhausted retry_count),
            # the `response` variable is still None. Break out cleanly.
            if response is None:
                print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
                self._persist_session(messages, conversation_history)
                break

            try:
                if self.api_mode == "codex_responses":
                    assistant_message, finish_reason = self._normalize_codex_response(response)
@@ -3658,6 +3846,13 @@ class AIAgent:
                    self._log_msg_to_db(assistant_msg)

                self._execute_tool_calls(assistant_message, messages, effective_task_id)

                # Refund the iteration if the ONLY tool(s) called were
                # execute_code (programmatic tool calling). These are
                # cheap RPC-style calls that shouldn't eat the budget.
                _tc_names = {tc.function.name for tc in assistant_message.tool_calls}
                if _tc_names == {"execute_code"}:
                    self.iteration_budget.refund()

                if self.compression_enabled and self.context_compressor.should_compress():
                    messages, active_system_prompt = self._compress_context(
@@ -3678,13 +3873,33 @@ class AIAgent:

                # Check if response only has think block with no actual content after it
                if not self._has_content_after_think_block(final_response):
                    # Track retries for empty-after-think responses
                    # If the previous turn already delivered real content alongside
                    # tool calls (e.g. "You're welcome!" + memory save), the model
                    # has nothing more to say. Use the earlier content immediately
                    # instead of wasting API calls on retries that won't help.
                    fallback = getattr(self, '_last_content_with_tools', None)
                    if fallback:
                        logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
                        self._last_content_with_tools = None
                        self._empty_content_retries = 0
                        for i in range(len(messages) - 1, -1, -1):
                            msg = messages[i]
                            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                                tool_names = []
                                for tc in msg["tool_calls"]:
                                    fn = tc.get("function", {})
                                    tool_names.append(fn.get("name", "unknown"))
                                msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
                                break
                        final_response = self._strip_think_blocks(fallback).strip()
                        break

                    # No fallback available — this is a genuine empty response.
                    # Retry in case the model just had a bad generation.
                    if not hasattr(self, '_empty_content_retries'):
                        self._empty_content_retries = 0
                    self._empty_content_retries += 1

                    # Show the reasoning/thinking content so the user can see
                    # what the model was thinking even though content is empty
                    reasoning_text = self._extract_reasoning(assistant_message)
                    print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
                    if reasoning_text:
@@ -3840,7 +4055,12 @@ class AIAgent:
                final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
                break

        if api_call_count >= self.max_iterations and final_response is None:
        if final_response is None and (
            api_call_count >= self.max_iterations
            or self.iteration_budget.remaining <= 0
        ):
            if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
                print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)")
            final_response = self._handle_max_iterations(messages, api_call_count)

        # Determine if conversation completed successfully
@@ -3911,7 +4131,7 @@ def main(

    Args:
        query (str): Natural language query for the agent. Defaults to Python 3.13 example.
        model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4-20250514.
        model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4.6.
        api_key (str): API key for authentication. Uses OPENROUTER_API_KEY env var if not provided.
        base_url (str): Base URL for the model API. Defaults to https://openrouter.ai/api/v1
        max_turns (int): Maximum number of API call iterations. Defaults to 10.
@@ -829,6 +829,33 @@ install_node_deps() {
        log_warn "npm install failed (browser tools may not work)"
    }
    log_success "Node.js dependencies installed"

    # Install Playwright browser + system dependencies.
    # Playwright's install-deps only supports apt/dnf/zypper natively.
    # For Arch/Manjaro we install the system libs via pacman first.
    log_info "Installing browser engine (Playwright Chromium)..."
    case "$DISTRO" in
        arch|manjaro)
            if command -v pacman &> /dev/null; then
                log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
                if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
                    sudo pacman -S --noconfirm --needed \
                        nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
                elif [ "$(id -u)" -eq 0 ]; then
                    pacman -S --noconfirm --needed \
                        nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
                else
                    log_warn "Cannot install browser deps without sudo. Run manually:"
                    log_warn "  sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
                fi
            fi
            cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
            ;;
        *)
            cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
            ;;
    esac
    log_success "Browser engine installed"
fi

# Install WhatsApp bridge dependencies
skills/apple/DESCRIPTION.md (new file, 3 lines)
@@ -0,0 +1,3 @@
---
description: Apple/macOS-specific skills — iMessage, Reminders, Notes, FindMy, and macOS automation. These skills only load on macOS systems.
---
skills/apple/apple-notes/SKILL.md (new file, 88 lines)
@@ -0,0 +1,88 @@
---
name: apple-notes
description: Manage Apple Notes via the memo CLI on macOS (create, view, search, edit).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [Notes, Apple, macOS, note-taking]
    related_skills: [obsidian]
---

# Apple Notes

Use `memo` to manage Apple Notes directly from the terminal. Notes sync across all Apple devices via iCloud.

## Prerequisites

- **macOS** with Notes.app
- Install: `brew tap antoniorodr/memo && brew install antoniorodr/memo/memo`
- Grant Automation access to Notes.app when prompted (System Settings → Privacy → Automation)

## When to Use

- User asks to create, view, or search Apple Notes
- Saving information to Notes.app for cross-device access
- Organizing notes into folders
- Exporting notes to Markdown/HTML

## When NOT to Use

- Obsidian vault management → use the `obsidian` skill
- Bear Notes → separate app (not supported here)
- Quick agent-only notes → use the `memory` tool instead

## Quick Reference

### View Notes

```bash
memo notes                    # List all notes
memo notes -f "Folder Name"   # Filter by folder
memo notes -s "query"         # Search notes (fuzzy)
```

### Create Notes

```bash
memo notes -a                 # Interactive editor
memo notes -a "Note Title"    # Quick add with title
```

### Edit Notes

```bash
memo notes -e                 # Interactive selection to edit
```

### Delete Notes

```bash
memo notes -d                 # Interactive selection to delete
```

### Move Notes

```bash
memo notes -m                 # Move note to folder (interactive)
```

### Export Notes

```bash
memo notes -ex                # Export to HTML/Markdown
```

## Limitations

- Cannot edit notes containing images or attachments
- Interactive prompts require terminal access (use pty=true if needed)
- macOS only — requires Apple Notes.app

## Rules

1. Prefer Apple Notes when the user wants cross-device sync (iPhone/iPad/Mac)
2. Use the `memory` tool for agent-internal notes that don't need to sync
3. Use the `obsidian` skill for Markdown-native knowledge management
skills/apple/apple-reminders/SKILL.md (new file, 96 lines)
@@ -0,0 +1,96 @@
---
name: apple-reminders
description: Manage Apple Reminders via remindctl CLI (list, add, complete, delete).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [Reminders, tasks, todo, macOS, Apple]
---

# Apple Reminders

Use `remindctl` to manage Apple Reminders directly from the terminal. Tasks sync across all Apple devices via iCloud.

## Prerequisites

- **macOS** with Reminders.app
- Install: `brew install steipete/tap/remindctl`
- Grant Reminders permission when prompted
- Check: `remindctl status` / Request: `remindctl authorize`

## When to Use

- User mentions "reminder" or "Reminders app"
- Creating personal to-dos with due dates that sync to iOS
- Managing Apple Reminders lists
- User wants tasks to appear on their iPhone/iPad

## When NOT to Use

- Scheduling agent alerts → use the cronjob tool instead
- Calendar events → use Apple Calendar or Google Calendar
- Project task management → use GitHub Issues, Notion, etc.
- If the user says "remind me" but means an agent alert → clarify first

## Quick Reference

### View Reminders

```bash
remindctl                     # Today's reminders
remindctl today               # Today
remindctl tomorrow            # Tomorrow
remindctl week                # This week
remindctl overdue             # Past due
remindctl all                 # Everything
remindctl 2026-01-04          # Specific date
```

### Manage Lists

```bash
remindctl list                     # List all lists
remindctl list Work                # Show specific list
remindctl list Projects --create   # Create list
remindctl list Work --delete       # Delete list
```

### Create Reminders

```bash
remindctl add "Buy milk"
remindctl add --title "Call mom" --list Personal --due tomorrow
remindctl add --title "Meeting prep" --due "2026-02-15 09:00"
```

### Complete / Delete

```bash
remindctl complete 1 2 3      # Complete by ID
remindctl delete 4A83 --force # Delete by ID
```

### Output Formats

```bash
remindctl today --json        # JSON for scripting
remindctl today --plain       # TSV format
remindctl today --quiet       # Counts only
```

## Date Formats

Accepted by `--due` and date filters:

- `today`, `tomorrow`, `yesterday`
- `YYYY-MM-DD`
- `YYYY-MM-DD HH:mm`
- ISO 8601 (`2026-01-04T12:34:56Z`)

## Rules

1. When the user says "remind me", clarify: Apple Reminders (syncs to phone) vs an agent cronjob alert
2. Always confirm reminder content and due date before creating
3. Use `--json` for programmatic parsing (see the sketch below)
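A minimal sketch for rule 3, assuming `jq` is installed; the `list` field name is an assumption, so inspect the raw `remindctl today --json` output first:

```bash
# Count today's reminders per list (field names assumed; verify against real output)
remindctl today --json | jq 'group_by(.list) | map({list: .[0].list, count: length})'
```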
skills/apple/findmy/SKILL.md (new file, 131 lines)
@@ -0,0 +1,131 @@
---
name: findmy
description: Track Apple devices and AirTags via FindMy.app on macOS using AppleScript and screen capture.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [FindMy, AirTag, location, tracking, macOS, Apple]
---

# Find My (Apple)

Track Apple devices and AirTags via the FindMy.app on macOS. Since Apple doesn't
provide a CLI for FindMy, this skill uses AppleScript to open the app and
screen capture to read device locations.

## Prerequisites

- **macOS** with Find My app and iCloud signed in
- Devices/AirTags already registered in Find My
- Screen Recording permission for terminal (System Settings → Privacy → Screen Recording)
- **Optional but recommended**: Install `peekaboo` for better UI automation:
  `brew install steipete/tap/peekaboo`

## When to Use

- User asks "where is my [device/cat/keys/bag]?"
- Tracking AirTag locations
- Checking device locations (iPhone, iPad, Mac, AirPods)
- Monitoring pet or item movement over time (AirTag patrol routes)

## Method 1: AppleScript + Screenshot (Basic)

### Open FindMy and Navigate

```bash
# Open Find My app
osascript -e 'tell application "FindMy" to activate'

# Wait for it to load
sleep 3

# Take a screenshot of the Find My window
screencapture -w -o /tmp/findmy.png
```

Then use `vision_analyze` to read the screenshot:

```
vision_analyze(image_url="/tmp/findmy.png", question="What devices/items are shown and what are their locations?")
```

### Switch Between Tabs

```bash
# Switch to Devices tab
osascript -e '
tell application "System Events"
    tell process "FindMy"
        click button "Devices" of toolbar 1 of window 1
    end tell
end tell'

# Switch to Items tab (AirTags)
osascript -e '
tell application "System Events"
    tell process "FindMy"
        click button "Items" of toolbar 1 of window 1
    end tell
end tell'
```

## Method 2: Peekaboo UI Automation (Recommended)

If `peekaboo` is installed, use it for more reliable UI interaction:

```bash
# Open Find My
osascript -e 'tell application "FindMy" to activate'
sleep 3

# Capture and annotate the UI
peekaboo see --app "FindMy" --annotate --path /tmp/findmy-ui.png

# Click on a specific device/item by element ID
peekaboo click --on B3 --app "FindMy"

# Capture the detail view
peekaboo image --app "FindMy" --path /tmp/findmy-detail.png
```

Then analyze with vision:

```
vision_analyze(image_url="/tmp/findmy-detail.png", question="What is the location shown for this device/item? Include address and coordinates if visible.")
```

## Workflow: Track AirTag Location Over Time

For monitoring an AirTag (e.g., tracking a cat's patrol route):

```bash
# 1. Open FindMy to the Items tab
osascript -e 'tell application "FindMy" to activate'
sleep 3

# 2. Click on the AirTag item (stay on the page — the AirTag only updates while the page is open)

# 3. Periodically capture the location
while true; do
    screencapture -w -o /tmp/findmy-$(date +%H%M%S).png
    sleep 300  # Every 5 minutes
done
```

Analyze each screenshot with vision to extract coordinates, then compile a route.
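For unattended tracking (rule 3 under Rules below), the loop above can be reduced to a single-shot script that a cronjob invokes on a schedule. A sketch, assuming Screen Recording permission is already granted; note it uses a silent full-screen capture because the interactive `-w` flag requires a user click:

```bash
#!/usr/bin/env bash
# capture-findmy.sh - one capture per invocation; schedule with cron, e.g.:
#   */5 * * * * /path/to/capture-findmy.sh
osascript -e 'tell application "FindMy" to activate'
sleep 3
# -x suppresses the shutter sound; full-screen capture avoids the interactive -w prompt
screencapture -x "/tmp/findmy-$(date +%Y%m%d-%H%M%S).png"
```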
## Limitations

- FindMy has **no CLI or API** — must use UI automation
- AirTags only update location while the FindMy page is actively displayed
- Location accuracy depends on nearby Apple devices in the FindMy network
- Screen Recording permission required for screenshots
- AppleScript UI automation may break across macOS versions

## Rules

1. Keep the FindMy app in the foreground when tracking AirTags (updates stop when minimized)
2. Use `vision_analyze` to read screenshot content — don't try to parse pixels
3. For ongoing tracking, use a cronjob to periodically capture and log locations
4. Respect privacy — only track devices/items the user owns
skills/apple/imessage/SKILL.md (new file, 100 lines)
@@ -0,0 +1,100 @@
---
name: imessage
description: Send and receive iMessages/SMS via the imsg CLI on macOS.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
  hermes:
    tags: [iMessage, SMS, messaging, macOS, Apple]
---

# iMessage

Use `imsg` to read and send iMessage/SMS via macOS Messages.app.

## Prerequisites

- **macOS** with Messages.app signed in
- Install: `brew install steipete/tap/imsg`
- Grant Full Disk Access for terminal (System Settings → Privacy → Full Disk Access)
- Grant Automation permission for Messages.app when prompted

## When to Use

- User asks to send an iMessage or text message
- Reading iMessage conversation history
- Checking recent Messages.app chats
- Sending to phone numbers or Apple IDs

## When NOT to Use

- Telegram/Discord/Slack/WhatsApp messages → use the appropriate gateway channel
- Group chat management (adding/removing members) → not supported
- Bulk/mass messaging → always confirm with user first

## Quick Reference

### List Chats

```bash
imsg chats --limit 10 --json
```

### View History

```bash
# By chat ID
imsg history --chat-id 1 --limit 20 --json

# With attachments info
imsg history --chat-id 1 --limit 20 --attachments --json
```

### Send Messages

```bash
# Text only
imsg send --to "+14155551212" --text "Hello!"

# With attachment
imsg send --to "+14155551212" --text "Check this out" --file /path/to/image.jpg

# Force iMessage or SMS
imsg send --to "+14155551212" --text "Hi" --service imessage
imsg send --to "+14155551212" --text "Hi" --service sms
```

### Watch for New Messages

```bash
imsg watch --chat-id 1 --attachments
```

## Service Options

- `--service imessage` — Force iMessage (requires recipient has iMessage)
- `--service sms` — Force SMS (green bubble)
- `--service auto` — Let Messages.app decide (default)

## Rules

1. **Always confirm recipient and message content** before sending
2. **Never send to unknown numbers** without explicit user approval
3. **Verify file paths** exist before attaching
4. **Don't spam** — rate-limit yourself

## Example Workflow

User: "Text mom that I'll be late"

```bash
# 1. Find mom's chat
imsg chats --limit 20 --json | jq '.[] | select(.displayName | contains("Mom"))'

# 2. Confirm with user: "Found Mom at +1555123456. Send 'I'll be late' via iMessage?"

# 3. Send after confirmation
imsg send --to "+1555123456" --text "I'll be late"
```
skills/market-data/polymarket/SKILL.md (new file, 76 lines)
@@ -0,0 +1,76 @@
---
name: polymarket
description: Query Polymarket prediction market data — search markets, get prices, orderbooks, and price history. Read-only via public REST APIs, no API key needed.
version: 1.0.0
author: Hermes Agent + Teknium
tags: [polymarket, prediction-markets, market-data, trading]
---

# Polymarket — Prediction Market Data

Query prediction market data from Polymarket using their public REST APIs.
All endpoints are read-only and require zero authentication.

See `references/api-endpoints.md` for the full endpoint reference with curl examples.

## When to Use

- User asks about prediction markets, betting odds, or event probabilities
- User wants to know "what are the odds of X happening?"
- User asks about Polymarket specifically
- User wants market prices, orderbook data, or price history
- User asks to monitor or track prediction market movements

## Key Concepts

- **Events** contain one or more **Markets** (1:many relationship)
- **Markets** are binary outcomes with Yes/No prices between 0.00 and 1.00
- Prices ARE probabilities: a price of 0.65 means the market assigns a 65% probability
- `outcomePrices` field: JSON-encoded array like `["0.80", "0.20"]`
- `clobTokenIds` field: JSON-encoded array of two token IDs [Yes, No] for price/book queries
- `conditionId` field: hex string used for price history queries
- Volume is in USDC (US dollars)

## Three Public APIs

1. **Gamma API** at `gamma-api.polymarket.com` — Discovery, search, browsing
2. **CLOB API** at `clob.polymarket.com` — Real-time prices, orderbooks, history
3. **Data API** at `data-api.polymarket.com` — Trades, open interest

## Typical Workflow

When a user asks about prediction market odds (a command-line sketch follows this list):

1. **Search** using the Gamma API public-search endpoint with their query
2. **Parse** the response — extract events and their nested markets
3. **Present** market question, current prices as percentages, and volume
4. **Deep dive** if asked — use clobTokenIds for orderbook, conditionId for history
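Steps 1-3 as a minimal sketch, assuming `curl` and `jq` are available; `fromjson` decodes the double-encoded `outcomePrices` field described below:

```bash
# Search, then print each market's question, Yes price, and volume
curl -s 'https://gamma-api.polymarket.com/public-search?q=bitcoin' |
  jq -r '.events[].markets[]
         | "\(.question) - Yes: \((.outcomePrices | fromjson)[0]) (vol \(.volume))"'
```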
## Presenting Results

Format prices as percentages for readability:

- outcomePrices `["0.652", "0.348"]` becomes "Yes: 65.2%, No: 34.8%"
- Always show the market question and probability
- Include volume when available

Example: `"Will X happen?" — 65.2% Yes ($1.2M volume)`

## Parsing Double-Encoded Fields

The Gamma API returns `outcomePrices`, `outcomes`, and `clobTokenIds` as JSON strings
inside JSON responses (double-encoded). When processing with Python, parse them with
`json.loads(market['outcomePrices'])` to get the actual array.
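The shell equivalent for jq pipelines, as a sketch (the slug is a placeholder); jq's `fromjson` does the second decode:

```bash
# Extract the two CLOB token IDs from a market object
curl -s 'https://gamma-api.polymarket.com/markets?slug=SOME-MARKET-SLUG' |
  jq -r '.[0].clobTokenIds | fromjson | .[]'
```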
## Rate Limits

Generous — unlikely to hit for normal usage:

- Gamma: 4,000 requests per 10 seconds (general)
- CLOB: 9,000 requests per 10 seconds (general)
- Data: 1,000 requests per 10 seconds (general)

## Limitations

- This skill is read-only — it does not support placing trades
- Trading requires wallet-based crypto authentication (EIP-712 signatures)
- Some new markets may have empty price history
- Geographic restrictions apply to trading but read-only data is globally accessible
skills/market-data/polymarket/references/api-endpoints.md (new file, 220 lines)
@@ -0,0 +1,220 @@
# Polymarket API Endpoints Reference

All endpoints are public REST (GET), return JSON, and need no authentication.

## Gamma API — gamma-api.polymarket.com

### Search Markets

```
GET /public-search?q=QUERY
```

Response structure:

```json
{
  "events": [
    {
      "id": "12345",
      "title": "Event title",
      "slug": "event-slug",
      "volume": 1234567.89,
      "markets": [
        {
          "question": "Will X happen?",
          "outcomePrices": "[\"0.65\", \"0.35\"]",
          "outcomes": "[\"Yes\", \"No\"]",
          "clobTokenIds": "[\"TOKEN_YES\", \"TOKEN_NO\"]",
          "conditionId": "0xabc...",
          "volume": 500000
        }
      ]
    }
  ],
  "pagination": {"hasMore": true, "totalResults": 100}
}
```
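Example (URL-encode the query):

```bash
curl -s 'https://gamma-api.polymarket.com/public-search?q=bitcoin%20etf' | jq '.events[0].title'
```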
### List Events

```
GET /events?limit=N&active=true&closed=false&order=volume&ascending=false
```

Parameters:

- `limit` — max results (default varies)
- `offset` — pagination offset
- `active` — true/false
- `closed` — true/false
- `order` — sort field: `volume`, `createdAt`, `updatedAt`
- `ascending` — true/false
- `tag` — filter by tag slug
- `slug` — get specific event by slug

Response: array of event objects. Each event includes a `markets` array.

Event fields: `id`, `title`, `slug`, `description`, `volume`, `liquidity`,
`openInterest`, `active`, `closed`, `category`, `startDate`, `endDate`,
`markets` (array of market objects).

### List Markets

```
GET /markets?limit=N&active=true&closed=false&order=volume&ascending=false
```

Same filter parameters as events, plus:

- `slug` — get specific market by slug

Market fields: `id`, `question`, `conditionId`, `slug`, `description`,
`outcomes`, `outcomePrices`, `volume`, `liquidity`, `active`, `closed`,
`marketType`, `clobTokenIds`, `endDate`, `category`, `createdAt`.

Important: `outcomePrices`, `outcomes`, and `clobTokenIds` are JSON strings
(double-encoded). Parse with `json.loads()` in Python.

### List Tags

```
GET /tags
```

Returns an array of tag objects: `id`, `label`, `slug`.
Use the `slug` value when filtering events/markets by tag.

---

## CLOB API — clob.polymarket.com

All CLOB price endpoints use `token_id` from the market's `clobTokenIds` field.
Index 0 = Yes outcome, index 1 = No outcome.

### Current Price

```
GET /price?token_id=TOKEN_ID&side=buy
```

Response: `{"price": "0.650"}`

The `side` parameter: `buy` or `sell`.
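Example (`TOKEN_ID` is a placeholder taken from a market's `clobTokenIds`):

```bash
curl -s 'https://clob.polymarket.com/price?token_id=TOKEN_ID&side=buy' | jq -r '.price'
```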
### Midpoint Price

```
GET /midpoint?token_id=TOKEN_ID
```

Response: `{"mid": "0.645"}`

### Spread

```
GET /spread?token_id=TOKEN_ID
```

Response: `{"spread": "0.02"}`

### Orderbook

```
GET /book?token_id=TOKEN_ID
```

Response:

```json
{
  "market": "condition_id",
  "asset_id": "token_id",
  "bids": [{"price": "0.64", "size": "500"}, ...],
  "asks": [{"price": "0.66", "size": "300"}, ...],
  "min_order_size": "5",
  "tick_size": "0.01",
  "last_trade_price": "0.65"
}
```

Bids and asks are sorted by price. Size is in shares (USDC-denominated).

### Price History

```
GET /prices-history?market=CONDITION_ID&interval=INTERVAL&fidelity=N
```

Parameters:

- `market` — the conditionId (hex string with 0x prefix)
- `interval` — time range: `all`, `1d`, `1w`, `1m`, `3m`, `6m`, `1y`
- `fidelity` — number of data points to return

Response:

```json
{
  "history": [
    {"t": 1709000000, "p": "0.55"},
    {"t": 1709100000, "p": "0.58"}
  ]
}
```

`t` is a Unix timestamp, `p` is the price (probability).

Note: Very new markets may return empty history.
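Example (`CONDITION_ID` is a placeholder):

```bash
# One month of history, up to ~60 points
curl -s 'https://clob.polymarket.com/prices-history?market=CONDITION_ID&interval=1m&fidelity=60' |
  jq '.history | length'
```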
### CLOB Markets List

```
GET /markets?limit=N
```

Response:

```json
{
  "data": [
    {
      "condition_id": "0xabc...",
      "question": "Will X?",
      "tokens": [
        {"token_id": "123...", "outcome": "Yes", "price": 0.65},
        {"token_id": "456...", "outcome": "No", "price": 0.35}
      ],
      "active": true,
      "closed": false
    }
  ],
  "next_cursor": "cursor_string",
  "limit": 100,
  "count": 1000
}
```

---

## Data API — data-api.polymarket.com

### Recent Trades

```
GET /trades?limit=N
GET /trades?market=CONDITION_ID&limit=N
```

Trade fields: `side` (BUY/SELL), `size`, `price`, `timestamp`,
`title`, `slug`, `outcome`, `transactionHash`, `conditionId`.

### Open Interest

```
GET /oi?market=CONDITION_ID
```

---

## Field Cross-Reference

To go from a Gamma market to CLOB data (a pipeline sketch follows these steps):

1. Get the market from Gamma: it has `clobTokenIds` and `conditionId`
2. Parse `clobTokenIds` (JSON string): `["YES_TOKEN", "NO_TOKEN"]`
3. Use YES_TOKEN with `/price`, `/book`, `/midpoint`, `/spread`
4. Use `conditionId` with `/prices-history` and Data API endpoints
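The whole chain as a shell sketch (the slug is a placeholder; assumes `curl` and `jq`):

```bash
slug="some-market-slug"   # placeholder - use a real slug from /public-search
market=$(curl -s "https://gamma-api.polymarket.com/markets?slug=$slug" | jq '.[0]')

# Steps 2-3: decode clobTokenIds, then query CLOB with the Yes token
yes_token=$(echo "$market" | jq -r '.clobTokenIds | fromjson | .[0]')
curl -s "https://clob.polymarket.com/midpoint?token_id=$yes_token" | jq -r '.mid'

# Step 4: price history via conditionId
cond=$(echo "$market" | jq -r '.conditionId')
curl -s "https://clob.polymarket.com/prices-history?market=$cond&interval=1w&fidelity=20" | jq '.history'
```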
skills/market-data/polymarket/scripts/polymarket.py (new file, 284 lines)
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""Polymarket CLI helper — query prediction market data.

Usage:
    python3 polymarket.py search "bitcoin"
    python3 polymarket.py trending [--limit 10]
    python3 polymarket.py market <slug>
    python3 polymarket.py event <slug>
    python3 polymarket.py price <token_id>
    python3 polymarket.py book <token_id>
    python3 polymarket.py history <condition_id> [--interval all] [--fidelity 50]
    python3 polymarket.py trades [--limit 10] [--market CONDITION_ID]
"""

import json
import sys
import urllib.request
import urllib.parse
import urllib.error

GAMMA = "https://gamma-api.polymarket.com"
CLOB = "https://clob.polymarket.com"
DATA = "https://data-api.polymarket.com"


def _get(url: str) -> dict | list:
    """GET request, return parsed JSON."""
    req = urllib.request.Request(url, headers={"User-Agent": "hermes-agent/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=15) as resp:
            return json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        print(f"HTTP {e.code}: {e.reason}", file=sys.stderr)
        sys.exit(1)
    except urllib.error.URLError as e:
        print(f"Connection error: {e.reason}", file=sys.stderr)
        sys.exit(1)


def _parse_json_field(val):
    """Parse double-encoded JSON fields (outcomePrices, outcomes, clobTokenIds)."""
    if isinstance(val, str):
        try:
            return json.loads(val)
        except (json.JSONDecodeError, TypeError):
            return val
    return val


def _fmt_pct(price_str: str) -> str:
    """Format price string as percentage."""
    try:
        return f"{float(price_str) * 100:.1f}%"
    except (ValueError, TypeError):
        return price_str


def _fmt_volume(vol) -> str:
    """Format volume as human-readable."""
    try:
        v = float(vol)
        if v >= 1_000_000:
            return f"${v / 1_000_000:.1f}M"
        if v >= 1_000:
            return f"${v / 1_000:.1f}K"
        return f"${v:.0f}"
    except (ValueError, TypeError):
        return str(vol)


def _print_market(m: dict, indent: str = ""):
    """Print a market summary."""
    question = m.get("question", "?")
    prices = _parse_json_field(m.get("outcomePrices", "[]"))
    outcomes = _parse_json_field(m.get("outcomes", "[]"))
    vol = _fmt_volume(m.get("volume", 0))
    closed = m.get("closed", False)
    status = " [CLOSED]" if closed else ""

    if isinstance(prices, list) and len(prices) >= 2:
        outcome_labels = outcomes if isinstance(outcomes, list) else ["Yes", "No"]
        price_str = " / ".join(
            f"{outcome_labels[i]}: {_fmt_pct(prices[i])}"
            for i in range(min(len(prices), len(outcome_labels)))
        )
        print(f"{indent}{question}{status}")
        print(f"{indent}  {price_str} | Volume: {vol}")
    else:
        print(f"{indent}{question}{status} | Volume: {vol}")

    slug = m.get("slug", "")
    if slug:
        print(f"{indent}  slug: {slug}")


def cmd_search(query: str):
    """Search for markets."""
    q = urllib.parse.quote(query)
    data = _get(f"{GAMMA}/public-search?q={q}")
    events = data.get("events", [])
    total = data.get("pagination", {}).get("totalResults", len(events))
    print(f"Found {total} results for \"{query}\":\n")
    for evt in events[:10]:
        print(f"=== {evt['title']} ===")
        print(f"  Volume: {_fmt_volume(evt.get('volume', 0))} | slug: {evt.get('slug', '')}")
        markets = evt.get("markets", [])
        for m in markets[:5]:
            _print_market(m, indent="  ")
        if len(markets) > 5:
            print(f"  ... and {len(markets) - 5} more markets")
        print()


def cmd_trending(limit: int = 10):
    """Show trending events by volume."""
    events = _get(f"{GAMMA}/events?limit={limit}&active=true&closed=false&order=volume&ascending=false")
    print(f"Top {len(events)} trending events:\n")
    for i, evt in enumerate(events, 1):
        print(f"{i}. {evt['title']}")
        print(f"   Volume: {_fmt_volume(evt.get('volume', 0))} | Markets: {len(evt.get('markets', []))}")
        print(f"   slug: {evt.get('slug', '')}")
        markets = evt.get("markets", [])
        for m in markets[:3]:
            _print_market(m, indent="   ")
        if len(markets) > 3:
            print(f"   ... and {len(markets) - 3} more markets")
        print()


def cmd_market(slug: str):
    """Get market details by slug."""
    markets = _get(f"{GAMMA}/markets?slug={urllib.parse.quote(slug)}")
    if not markets:
        print(f"No market found with slug: {slug}")
        return
    m = markets[0]
    print(f"Market: {m.get('question', '?')}")
    print(f"Status: {'CLOSED' if m.get('closed') else 'ACTIVE'}")
    _print_market(m)
    print(f"\n  conditionId: {m.get('conditionId', 'N/A')}")
    tokens = _parse_json_field(m.get("clobTokenIds", "[]"))
    if isinstance(tokens, list):
        outcomes = _parse_json_field(m.get("outcomes", "[]"))
        for i, t in enumerate(tokens):
            label = outcomes[i] if isinstance(outcomes, list) and i < len(outcomes) else f"Outcome {i}"
            print(f"  token ({label}): {t}")
    desc = m.get("description", "")
    if desc:
        print(f"\n  Description: {desc[:500]}")


def cmd_event(slug: str):
    """Get event details by slug."""
    events = _get(f"{GAMMA}/events?slug={urllib.parse.quote(slug)}")
    if not events:
        print(f"No event found with slug: {slug}")
        return
    evt = events[0]
    print(f"Event: {evt['title']}")
    print(f"Volume: {_fmt_volume(evt.get('volume', 0))}")
    print(f"Status: {'CLOSED' if evt.get('closed') else 'ACTIVE'}")
    print(f"Markets: {len(evt.get('markets', []))}\n")
    for m in evt.get("markets", []):
        _print_market(m, indent="  ")
        print()


def cmd_price(token_id: str):
    """Get current price for a token."""
    buy = _get(f"{CLOB}/price?token_id={token_id}&side=buy")
    mid = _get(f"{CLOB}/midpoint?token_id={token_id}")
    spread = _get(f"{CLOB}/spread?token_id={token_id}")
    print(f"Token: {token_id[:30]}...")
    print(f"  Buy price: {_fmt_pct(buy.get('price', '?'))}")
    print(f"  Midpoint:  {_fmt_pct(mid.get('mid', '?'))}")
    print(f"  Spread:    {spread.get('spread', '?')}")


def cmd_book(token_id: str):
    """Get orderbook for a token."""
    book = _get(f"{CLOB}/book?token_id={token_id}")
    bids = book.get("bids", [])
    asks = book.get("asks", [])
    last = book.get("last_trade_price", "?")
    print(f"Orderbook for {token_id[:30]}...")
    print(f"Last trade: {_fmt_pct(last)} | Tick size: {book.get('tick_size', '?')}")
    print(f"\n  Top bids ({len(bids)} total):")
    # Show bids sorted by price descending (best bids first)
    sorted_bids = sorted(bids, key=lambda x: float(x.get("price", 0)), reverse=True)
    for b in sorted_bids[:10]:
        print(f"    {_fmt_pct(b['price']):>7} | Size: {float(b['size']):>10.2f}")
    print(f"\n  Top asks ({len(asks)} total):")
    sorted_asks = sorted(asks, key=lambda x: float(x.get("price", 0)))
    for a in sorted_asks[:10]:
        print(f"    {_fmt_pct(a['price']):>7} | Size: {float(a['size']):>10.2f}")


def cmd_history(condition_id: str, interval: str = "all", fidelity: int = 50):
    """Get price history for a market."""
    data = _get(f"{CLOB}/prices-history?market={condition_id}&interval={interval}&fidelity={fidelity}")
    history = data.get("history", [])
    if not history:
        print("No price history available for this market.")
        return
    print(f"Price history ({len(history)} points, interval={interval}):\n")
    from datetime import datetime, timezone
    for pt in history:
        ts = datetime.fromtimestamp(pt["t"], tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
        price = _fmt_pct(pt["p"])
        bar = "█" * int(float(pt["p"]) * 40)
        print(f"  {ts}  {price:>7}  {bar}")


def cmd_trades(limit: int = 10, market: str | None = None):
    """Get recent trades."""
    url = f"{DATA}/trades?limit={limit}"
    if market:
        url += f"&market={market}"
    trades = _get(url)
    if not isinstance(trades, list):
        print(f"Unexpected response: {trades}")
        return
    print(f"Recent trades ({len(trades)}):\n")
    for t in trades:
        side = t.get("side", "?")
        price = _fmt_pct(t.get("price", "?"))
        size = t.get("size", "?")
        outcome = t.get("outcome", "?")
        title = t.get("title", "?")[:50]
        print(f"  {side:4} {price:>7} x{float(size):>8.2f} [{outcome}] {title}")


def main():
    args = sys.argv[1:]
    if not args or args[0] in ("-h", "--help", "help"):
        print(__doc__)
        return

    cmd = args[0]

    if cmd == "search" and len(args) >= 2:
        cmd_search(" ".join(args[1:]))
    elif cmd == "trending":
        limit = 10
        if "--limit" in args:
            idx = args.index("--limit")
            limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
        cmd_trending(limit)
    elif cmd == "market" and len(args) >= 2:
        cmd_market(args[1])
    elif cmd == "event" and len(args) >= 2:
        cmd_event(args[1])
    elif cmd == "price" and len(args) >= 2:
        cmd_price(args[1])
    elif cmd == "book" and len(args) >= 2:
        cmd_book(args[1])
    elif cmd == "history" and len(args) >= 2:
        interval = "all"
        fidelity = 50
        if "--interval" in args:
            idx = args.index("--interval")
            interval = args[idx + 1] if idx + 1 < len(args) else "all"
        if "--fidelity" in args:
            idx = args.index("--fidelity")
            fidelity = int(args[idx + 1]) if idx + 1 < len(args) else 50
        cmd_history(args[1], interval, fidelity)
    elif cmd == "trades":
        limit = 10
        market = None
        if "--limit" in args:
            idx = args.index("--limit")
            limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
        if "--market" in args:
            idx = args.index("--market")
            market = args[idx + 1] if idx + 1 < len(args) else None
        cmd_trades(limit, market)
    else:
        print(f"Unknown command: {cmd}")
        print(__doc__)


if __name__ == "__main__":
    main()
skills/mlops/accelerate/SKILL.md (new file, 335 lines)
@@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
  hermes:
    tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---

# HuggingFace Accelerate - Unified Distributed Training

## Quick start

Accelerate simplifies distributed training to 4 lines of code.

**Installation**:
```bash
pip install accelerate
```

**Convert PyTorch script** (4 lines):
```python
  import torch
+ from accelerate import Accelerator

+ accelerator = Accelerator()

  model = torch.nn.Transformer()
  optimizer = torch.optim.Adam(model.parameters())
  dataloader = torch.utils.data.DataLoader(dataset)

+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

  for batch in dataloader:
      optimizer.zero_grad()
      loss = model(batch)
-     loss.backward()
+     accelerator.backward(loss)
      optimizer.step()
```

**Run** (single command):
```bash
accelerate launch train.py
```

## Common workflows

### Workflow 1: From single GPU to multi-GPU

**Original script**:
```python
# train.py
import torch

model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for epoch in range(10):
    for batch in dataloader:
        batch = batch.to('cuda')
        optimizer.zero_grad()
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
```

**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator  # +1

accelerator = Accelerator()  # +2

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)  # +3

for epoch in range(10):
    for batch in dataloader:
        # No .to('cuda') needed - automatic!
        optimizer.zero_grad()
        loss = model(batch).mean()
        accelerator.backward(loss)  # +4
        optimizer.step()
```

**Configure** (interactive):
```bash
accelerate config
```

**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)

**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py

# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py

# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 0 \
    --main_process_ip $MASTER_ADDR \
    train.py
```

### Workflow 2: Mixed precision training

**Enable FP16/BF16**:
```python
from accelerate import Accelerator

# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')

# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')

# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Everything else is automatic!
for batch in dataloader:
    with accelerator.autocast():  # Optional, done automatically
        loss = model(batch)
    accelerator.backward(loss)
```

### Workflow 3: DeepSpeed ZeRO integration

**Enable DeepSpeed ZeRO-2** (via a `DeepSpeedPlugin`, which is what the Accelerator expects):
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

accelerator = Accelerator(
    mixed_precision='bf16',
    deepspeed_plugin=DeepSpeedPlugin(
        zero_stage=2,                     # ZeRO-2
        offload_optimizer_device="none",  # no CPU offload
        gradient_accumulation_steps=4,
    ),
)

# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```

**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```

**deepspeed_config.json**:
```json
{
  "fp16": {"enabled": false},
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {"device": "cpu"},
    "allgather_bucket_size": 5e8,
    "reduce_bucket_size": 5e8
  }
}
```

**Launch** (point the DeepSpeed section of your accelerate config at `deepspeed_config.json`, then launch as usual):
```bash
accelerate launch train.py
```

### Workflow 4: FSDP (Fully Sharded Data Parallel)

**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",  # ZeRO-3 equivalent
    auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
    cpu_offload=False
)

accelerator = Accelerator(
    mixed_precision='bf16',
    fsdp_plugin=fsdp_plugin
)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```

**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```

### Workflow 5: Gradient accumulation

**Accumulate gradients**:
```python
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(model):  # Handles accumulation
        optimizer.zero_grad()
        loss = model(batch)
        accelerator.backward(loss)
        optimizer.step()
```

**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps` (e.g. 32 per GPU × 8 GPUs × 4 accumulation steps = an effective batch of 1024).

## When to use vs alternatives

**Use Accelerate when**:
- You want the simplest distributed training
- You need a single script for any hardware
- You use the HuggingFace ecosystem
- You want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- You need quick prototyping

**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere

**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks, high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction

## Common issues

**Issue: Wrong device placement**

Don't manually move to device:
```python
# WRONG
batch = batch.to('cuda')

# CORRECT
# Accelerate handles it automatically after prepare()
```

**Issue: Gradient accumulation not working**

Use the context manager:
```python
# CORRECT
with accelerator.accumulate(model):
    optimizer.zero_grad()
    accelerator.backward(loss)
    optimizer.step()
```

**Issue: Checkpointing in distributed**

Use accelerator methods:
```python
# Save only on main process
if accelerator.is_main_process:
    accelerator.save_state('checkpoint/')

# Load on all processes
accelerator.load_state('checkpoint/')
```

**Issue: Different results with FSDP**

Ensure the same random seed:
```python
from accelerate.utils import set_seed
set_seed(42)
```

## Advanced topics

**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.

**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.

**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.

## Hardware requirements

- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported

**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup

## Resources

- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
skills/mlops/accelerate/references/custom-plugins.md (new file, 453 lines)
@@ -0,0 +1,453 @@
|
||||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
offload_param_device="nvme",
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
    optim_state_dict_config=None,
    limit_all_gathers=False,
    use_orig_params=True,  # Use original param shapes
)

accelerator = Accelerator(
    fsdp_plugin=fsdp_plugin,
    mixed_precision='bf16'
)
```

### FSDP with Transformer Auto-Wrap

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

# Wrap at transformer block level
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block}  # Wrap GPT2Block layers
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=wrap_policy
)
```

## Creating Custom Training Strategy

### Example: Custom Gradient Accumulation

```python
from accelerate import Accelerator

class CustomGradientAccumulation:
    def __init__(self, steps=4, adaptive=False, threshold=2.0):
        self.steps = steps
        self.adaptive = adaptive
        self.threshold = threshold  # Loss value that triggers an early sync
        self.current_step = 0

    def should_sync(self, loss):
        """Decide whether to sync gradients."""
        self.current_step += 1

        # Adaptive: sync on high loss
        if self.adaptive and loss > self.threshold:
            self.current_step = 0
            return True

        # Regular: sync every N steps
        if self.current_step >= self.steps:
            self.current_step = 0
            return True

        return False

# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss

    # Scale loss
    loss = loss / custom_accum.steps
    accelerator.backward(loss)

    # Conditional sync
    if custom_accum.should_sync(loss.item()):
        optimizer.step()
        optimizer.zero_grad()
```

### Example: Custom Mixed Precision

```python
import torch

class CustomMixedPrecision:
    """Custom mixed precision with dynamic loss scaling."""

    def __init__(self, init_scale=2**16, scale_window=2000):
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=init_scale,
            growth_interval=scale_window
        )
        self.scale_history = []

    def scale_loss(self, loss):
        """Scale loss for backward."""
        return self.scaler.scale(loss)

    def unscale_and_clip(self, optimizer, max_norm=1.0):
        """Unscale gradients and clip."""
        self.scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            optimizer.param_groups[0]['params'],
            max_norm
        )

    def step(self, optimizer):
        """Optimizer step with scaler update."""
        scale_before = self.scaler.get_scale()
        self.scaler.step(optimizer)
        self.scaler.update()
        scale_after = self.scaler.get_scale()

        # Track scale changes
        if scale_before != scale_after:
            self.scale_history.append(scale_after)

# Usage
custom_mp = CustomMixedPrecision()

for batch in dataloader:
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(**batch).loss

    scaled_loss = custom_mp.scale_loss(loss)
    scaled_loss.backward()

    custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
    custom_mp.step(optimizer)
    optimizer.zero_grad()
```

## Advanced: Custom Distributed Backend

### Custom AllReduce Strategy

```python
import torch
import torch.distributed as dist

class CustomAllReduce:
    """Custom all-reduce with compression."""

    def __init__(self, compression_ratio=0.1):
        self.compression_ratio = compression_ratio

    def compress_gradients(self, tensor):
        """Top-k gradient compression: keep signed values at the top-|g| positions."""
        k = int(tensor.numel() * self.compression_ratio)
        _, indices = torch.topk(tensor.abs().view(-1), k)
        values = tensor.view(-1)[indices]
        return values, indices

    def all_reduce_compressed(self, tensor):
        """All-reduce with gradient compression.

        Note: this sketch assumes ranks agree on the selected indices;
        exact schemes exchange indices as well as values.
        """
        # Compress
        values, indices = self.compress_gradients(tensor)

        # All-reduce compressed gradients
        dist.all_reduce(values, op=dist.ReduceOp.SUM)

        # Decompress
        tensor_compressed = torch.zeros_like(tensor).view(-1)
        tensor_compressed[indices] = values / dist.get_world_size()

        return tensor_compressed.view_as(tensor)

# Usage in training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)

for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()

    # Custom all-reduce
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)

    optimizer.step()
    optimizer.zero_grad()
```

## Plugin Best Practices

### 1. Validation in `__post_init__`

```python
from dataclasses import dataclass

@dataclass
class CustomPlugin:
    learning_rate: float = 1e-3
    warmup_steps: int = 1000

    def __post_init__(self):
        # Validate parameters
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        if self.warmup_steps < 0:
            raise ValueError("warmup_steps must be non-negative")

        # Compute derived values
        self.min_lr = self.learning_rate * 0.1
```

### 2. Compatibility Checks

```python
@dataclass
class CustomPlugin:
    feature_enabled: bool = True

    def is_compatible(self, accelerator):
        """Check if plugin is compatible with accelerator config."""
        if self.feature_enabled and accelerator.mixed_precision == 'fp8':
            raise ValueError("Custom plugin not compatible with FP8")
        return True
```

### 3. State Management

```python
@dataclass
class CustomPlugin:
    counter: int = 0
    history: list = None

    def __post_init__(self):
        if self.history is None:
            self.history = []

    def update_state(self, value):
        """Update plugin state during training."""
        self.counter += 1
        self.history.append(value)
```
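
These three practices compose naturally: validate at construction, check compatibility once the `Accelerator` exists, then update state from the training loop. A minimal sketch, assuming a `CustomPlugin` that combines the snippets above (the explicit `is_compatible` call is this guide's convention, not an Accelerate hook):

```python
from accelerate import Accelerator

plugin = CustomPlugin()            # __post_init__ validates fields
accelerator = Accelerator(mixed_precision='bf16')
plugin.is_compatible(accelerator)  # fail fast before any training step

for batch in dataloader:
    loss = model(**batch).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    plugin.update_state(loss.item())  # plugin state tracks training progress
```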

## Resources

- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
skills/mlops/accelerate/references/megatron-integration.md (new file, 489 lines)

# Megatron Integration with Accelerate

## Overview

Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.

**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts

## Setup

### Install Megatron-LM

```bash
# Clone Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .

# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
    --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```

### Accelerate Configuration

```bash
accelerate config
```

**Questions**:
```
In which compute environment are you running?
> This machine

Which type of machine are you using?
> Multi-GPU

How many different machines will you use?
> 1

Do you want to use DeepSpeed/FSDP?
> No

Do you want to use Megatron-LM?
> Yes

What is the Tensor Parallelism degree? [1-8]
> 2

Do you want to enable Sequence Parallelism?
> No

What is the Pipeline Parallelism degree? [1-8]
> 2

What is the Data Parallelism degree? [1-8]
> 2

Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE

Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```

**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
  megatron_lm_gradient_clipping: 1.0
  megatron_lm_learning_rate_decay_iters: 320000
  megatron_lm_num_micro_batches: 1
  megatron_lm_pp_degree: 2
  megatron_lm_recompute_activations: true
  megatron_lm_sequence_parallelism: false
  megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

## Parallelism Strategies

### Tensor Parallelism (TP)

**Splits each transformer layer across GPUs**:

```python
# Layer split across 2 GPUs
# GPU 0: First half of attention heads
# GPU 1: Second half of attention heads

# Each GPU computes partial outputs
# All-reduce combines results
```

**TP degree recommendations**:
- **TP=1**: No tensor parallelism (single GPU per layer)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)

**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)

**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead per layer

### Pipeline Parallelism (PP)

**Splits model depth across GPUs**:

```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```

**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)

**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (slower interconnect OK)

**Drawbacks**:
- Pipeline bubbles (idle time)
- Requires micro-batching

### Data Parallelism (DP)

**Replicates model across GPU groups**:

```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```

**DP degree** (see the helper sketch below):
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2

**Benefits**:
- Increases throughput
- Scales batch size
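
The degree arithmetic is easy to get wrong when juggling several configs, so a tiny check up front saves a failed launch. A minimal sketch (`data_parallel_degree` is a hypothetical utility, not part of Accelerate or Megatron-LM):

```python
def data_parallel_degree(total_gpus: int, tp: int, pp: int) -> int:
    """DP = total_gpus / (TP x PP); raises if the split is not exact."""
    model_parallel = tp * pp
    if total_gpus % model_parallel != 0:
        raise ValueError(f"{total_gpus} GPUs cannot be split by TP*PP = {model_parallel}")
    return total_gpus // model_parallel

assert data_parallel_degree(8, tp=2, pp=2) == 2  # matches the example above
```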

### Sequence Parallelism

**Splits long sequences across GPUs** (extends TP):

```python
# 8K sequence, TP=2, Sequence Parallel=True
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```

**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory

**Requirements**:
- Must use with TP > 1
- RoPE/ALiBi position encodings work best

## Accelerate Code Example

### Basic Setup

```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin

# Configure Megatron
megatron_plugin = MegatronLMPlugin(
    tp_degree=2,                         # Tensor parallelism degree
    pp_degree=2,                         # Pipeline parallelism degree
    num_micro_batches=4,                 # Micro-batches for pipeline
    gradient_clipping=1.0,               # Gradient clipping value
    sequence_parallelism=False,          # Enable sequence parallelism
    recompute_activations=True,          # Activation checkpointing
    use_distributed_optimizer=True,      # Distributed optimizer
    custom_prepare_model_function=None,  # Custom model prep
)

# Initialize accelerator
accelerator = Accelerator(
    mixed_precision='bf16',
    megatron_lm_plugin=megatron_plugin
)

# Prepare model and optimizer
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Training loop (same as DDP!)
for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
```

### Full Training Script

```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel

def main():
    # Megatron configuration
    megatron_plugin = MegatronLMPlugin(
        tp_degree=2,
        pp_degree=2,
        num_micro_batches=4,
        gradient_clipping=1.0,
    )

    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=8,
        megatron_lm_plugin=megatron_plugin
    )

    # Model
    config = GPT2Config(
        n_layer=24,
        n_head=16,
        n_embd=1024,
    )
    model = GPT2LMHeadModel(config)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)

    # Prepare (train_loader and num_epochs are assumed to be defined elsewhere)
    model, optimizer, train_loader = accelerator.prepare(
        model, optimizer, train_loader
    )

    # Training loop
    for epoch in range(num_epochs):
        for batch in train_loader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

        # Save checkpoint
        accelerator.wait_for_everyone()
        accelerator.save_state(f'checkpoint-epoch-{epoch}')

if __name__ == '__main__':
    main()
```

### Launch Command

```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py

# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 0 \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    train.py

# Node 1
accelerate launch --multi_gpu --num_processes 16 \
    --num_machines 2 --machine_rank 1 \
    --main_process_ip $MASTER_ADDR \
    --main_process_port 29500 \
    train.py
```

## Activation Checkpointing

**Reduces memory by recomputing activations**:

```python
megatron_plugin = MegatronLMPlugin(
    recompute_activations=True,                # Enable checkpointing
    checkpoint_num_layers=1,                   # Checkpoint every N layers
    distribute_checkpointed_activations=True,  # Distribute across TP
    partition_activations=True,                # Partition in PP
    check_for_nan_in_loss_and_grad=True,       # Stability check
)
```

**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing

**Memory savings**: 30-50% with 10-15% slowdown

## Distributed Optimizer

**Shards optimizer state across DP ranks**:

```python
megatron_plugin = MegatronLMPlugin(
    use_distributed_optimizer=True,  # Enable sharded optimizer
)
```

**Benefits**:
- Reduces optimizer memory by DP degree
- Example: DP=4 → 4× less optimizer memory per GPU (see the arithmetic sketch below)

**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training
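
To see why the savings scale with DP, note that AdamW keeps two FP32 state tensors (momentum and variance) per parameter, i.e. about 8 bytes per parameter. A rough back-of-envelope sketch (the function is illustrative; master weights and padding add further overhead):

```python
def adamw_state_gb(num_params: int, dp: int = 1) -> float:
    """AdamW momentum + variance in FP32 = 8 bytes/param, sharded across DP ranks."""
    return num_params * 8 / dp / 1e9

print(adamw_state_gb(7_000_000_000, dp=1))  # ~56 GB replicated on every GPU
print(adamw_state_gb(7_000_000_000, dp=4))  # ~14 GB per GPU with the distributed optimizer
```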

## Performance Tuning

### Micro-Batch Size

```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,  # 16 micro-batches per pipeline
)

# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```

**Recommendations**:
- More micro-batches → less pipeline bubble (quantified in the sketch below)
- Typical: 4-16 micro-batches
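
The bubble fraction explains the recommendation: with `p` pipeline stages and `m` micro-batches, the idle fraction of a GPipe-style schedule is roughly `(p - 1) / (m + p - 1)`. A quick illustration (a standard approximation from the pipeline-parallelism literature; exact overhead depends on the schedule):

```python
def pipeline_bubble(pp: int, micro_batches: int) -> float:
    """Approximate idle fraction of a pipeline with pp stages and m micro-batches."""
    return (pp - 1) / (micro_batches + pp - 1)

print(f"{pipeline_bubble(4, 4):.0%}")   # ~43% idle
print(f"{pipeline_bubble(4, 16):.0%}")  # ~16% idle: more micro-batches, less bubble
```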

### Sequence Length

```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
    tp_degree=4,
    sequence_parallelism=True,  # Required: TP > 1
)

# Enables sequences up to TP × normal limit
# Example: TP=4, 8K normal → 32K with sequence parallel
```

### GPU Topology

**NVLink required for TP**:
```bash
# Check NVLink topology
nvidia-smi topo -m

# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)

# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow, avoid TP across these)
```

**Recommendations**:
- **TP**: Within same node (NVLink)
- **PP**: Across nodes (slower interconnect OK)
- **DP**: Any topology

## Model Size Guidelines

| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|---------------|
| 7B | 8 | 1 | 1 | 8 | 1 |
| 13B | 8 | 2 | 1 | 4 | 1 |
| 20B | 16 | 4 | 1 | 4 | 1 |
| 40B | 32 | 4 | 2 | 4 | 4 |
| 70B | 64 | 8 | 2 | 4 | 8 |
| 175B | 128 | 8 | 4 | 4 | 16 |

**Assumptions**: BF16, 2K sequence length, A100 80GB
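
Each row satisfies the identity from the Data Parallelism section, GPUs = TP × PP × DP; a two-line self-check (illustrative only):

```python
rows = [(8, 1, 1, 8), (8, 2, 1, 4), (16, 4, 1, 4), (32, 4, 2, 4), (64, 8, 2, 4), (128, 8, 4, 4)]
assert all(gpus == tp * pp * dp for gpus, tp, pp, dp in rows)
```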

## Checkpointing

### Save Checkpoint

```python
# Save full model state
accelerator.save_state('checkpoint-1000')

# Megatron saves separate files per rank
# checkpoint-1000/
#   pytorch_model_tp_0_pp_0.bin
#   pytorch_model_tp_0_pp_1.bin
#   pytorch_model_tp_1_pp_0.bin
#   pytorch_model_tp_1_pp_1.bin
#   optimizer_tp_0_pp_0.bin
#   ...
```

### Load Checkpoint

```python
# Resume training
accelerator.load_state('checkpoint-1000')

# Automatically loads correct shard per rank
```

### Convert to Standard PyTorch

```bash
# Merge Megatron checkpoint to single file
python merge_megatron_checkpoint.py \
    --checkpoint-dir checkpoint-1000 \
    --output pytorch_model.bin
```

## Common Issues

### Issue: OOM with Pipeline Parallelism

**Solution**: Increase micro-batches
```python
megatron_plugin = MegatronLMPlugin(
    pp_degree=4,
    num_micro_batches=16,  # Increase from 4
)
```

### Issue: Slow Training

**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4  # Increase
pp_degree=2  # Decrease
```

**Check 2**: Micro-batch size too small
```python
num_micro_batches=8  # Increase
```

### Issue: NVLink Not Detected

```bash
# Verify NVLink
nvidia-smi nvlink -s

# If no NVLink, avoid TP > 1
# Use PP or DP instead
```

## Resources

- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex
skills/mlops/accelerate/references/performance.md (new file, 525 lines)

# Accelerate Performance Tuning

## Profiling

### Basic Profiling

```python
import time

from accelerate import Accelerator

accelerator = Accelerator()

# Warmup
for _ in range(10):
    batch = next(iter(dataloader))
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

# Profile training loop
start = time.time()
total_batches = 100

for i, batch in enumerate(dataloader):
    if i >= total_batches:
        break

    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

accelerator.wait_for_everyone()  # Sync all processes
elapsed = time.time() - start

# Metrics (batch_size is the per-process batch size used by the dataloader)
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed

print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```

### PyTorch Profiler Integration

```python
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True
) as prof:
    for i, batch in enumerate(dataloader):
        if i >= 10:  # Profile first 10 batches
            break

        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

# Print profiling results
print(prof.key_averages().table(
    sort_by="cuda_time_total", row_limit=20
))

# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```

## Memory Optimization

### 1. Gradient Accumulation

**Problem**: Large batch size causes OOM

**Solution**: Accumulate gradients across micro-batches

```python
accelerator = Accelerator(gradient_accumulation_steps=8)

# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256

for batch in dataloader:
    with accelerator.accumulate(model):  # Handles accumulation logic
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

**Memory savings**: 8× less activation memory (with 8 accumulation steps)

### 2. Gradient Checkpointing

**Enable in model**:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    use_cache=False  # Required for gradient checkpointing
)

# Enable checkpointing
model.gradient_checkpointing_enable()

# Prepare with Accelerate
model = accelerator.prepare(model)
```

**Memory savings**: 30-50% with 10-15% slowdown

### 3. Mixed Precision

**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')

# Automatic mixed precision
for batch in dataloader:
    outputs = model(**batch)    # Forward in BF16
    loss = outputs.loss
    accelerator.backward(loss)  # Backward in FP32
    optimizer.step()
```

**FP16 (V100, older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs

scaler_kwargs = GradScalerKwargs(
    init_scale=2.**16,
    growth_interval=2000
)

accelerator = Accelerator(
    mixed_precision='fp16',
    kwargs_handlers=[scaler_kwargs]
)
```

**Memory savings**: 50% compared to FP32

### 4. CPU Offloading (DeepSpeed)

```python
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_optimizer_device="cpu",  # Offload optimizer to CPU
    offload_param_device="cpu",      # Offload parameters to CPU
)

accelerator = Accelerator(
    deepspeed_plugin=ds_plugin,
    mixed_precision='bf16'
)
```

**Memory savings**: 10-20× for optimizer state, 5-10× for parameters

**Trade-off**: 20-30% slower due to CPU-GPU transfers

### 5. Flash Attention

```python
# Install flash-attn
# pip install flash-attn

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    attn_implementation="flash_attention_2"  # Enable Flash Attention 2
)

model = accelerator.prepare(model)
```

**Memory savings**: 50% for attention, 2× faster

**Requirements**: A100/H100, sequence length must be multiple of 128
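
When those requirements are not met, a graceful fallback keeps scripts portable across machines. A minimal sketch, assuming Transformers' `attn_implementation` options (`sdpa` is PyTorch's built-in scaled-dot-product attention):

```python
from transformers import AutoModelForCausalLM

# Prefer Flash Attention 2; fall back to PyTorch SDPA if flash-attn
# is not installed or the GPU does not support it.
try:
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", attn_implementation="flash_attention_2"
    )
except (ImportError, ValueError):
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", attn_implementation="sdpa"
    )
```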

## Communication Optimization

### 1. Gradient Bucketing (DDP)

```python
from accelerate.utils import DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(
    bucket_cap_mb=25,              # Bucket size for gradient reduction
    gradient_as_bucket_view=True,  # Reduce memory copies
    static_graph=False             # Set True if model doesn't change
)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```

**Recommended bucket sizes** (see the sketch below for picking one programmatically):
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB
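
A small helper can turn the rule of thumb above into code. A minimal sketch using the thresholds from the list (the helper itself is illustrative, not an Accelerate API):

```python
from accelerate.utils import DistributedDataParallelKwargs

def ddp_kwargs_for(num_params: int) -> DistributedDataParallelKwargs:
    """Map parameter count to a bucket size per the guidelines above."""
    if num_params < 1_000_000_000:
        cap = 25
    elif num_params < 10_000_000_000:
        cap = 100
    else:
        cap = 200
    return DistributedDataParallelKwargs(
        bucket_cap_mb=cap,
        gradient_as_bucket_view=True,
    )

# Usage:
# handler = ddp_kwargs_for(sum(p.numel() for p in model.parameters()))
# accelerator = Accelerator(kwargs_handlers=[handler])
```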

### 2. Find Unused Parameters

```python
# Only enable if model has unused parameters (slower!)
ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True
)
```

**Use case**: Models with conditional branches (e.g., mixture of experts)

**Cost**: 10-20% slower

### 3. NCCL Tuning

```bash
# Set environment variables before launch
export NCCL_DEBUG=INFO          # Debug info
export NCCL_IB_DISABLE=0        # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0  # Network interface
export NCCL_P2P_LEVEL=NVL       # Use NVLink

accelerate launch train.py
```

**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within node)
- `PIX`: PCIe (fast, within node)
- `PHB`: PCIe host bridge (slow, traffic crosses the host bridge)

## Data Loading Optimization

### 1. DataLoader Workers

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Pin memory for faster GPU transfer
    prefetch_factor=2,        # Prefetch batches per worker
    persistent_workers=True   # Keep workers alive between epochs
)

train_loader = accelerator.prepare(train_loader)
```

**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher for slow data loading)

### 2. Data Preprocessing

```python
from datasets import load_dataset, load_from_disk

# Bad: Preprocess during training (slow)
dataset = load_dataset("openwebtext")

for batch in dataset:
    tokens = tokenizer(batch['text'])  # Slow!
    ...

# Good: Preprocess once, save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
    lambda x: tokenizer(x['text']),
    batched=True,
    num_proc=8,  # Parallel preprocessing
    remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")

# Load preprocessed
dataset = load_from_disk("preprocessed_data")
```

### 3. Faster Tokenization

```python
import os

# Enable tokenizer parallelism (the fast tokenizers are Rust-based and ~10× faster)
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    use_fast=True  # Use fast Rust tokenizer
)
```

## Compilation (PyTorch 2.0+)

### Compile Model

```python
import torch

# Compile model for faster execution
model = torch.compile(
    model,
    mode="reduce-overhead",  # Options: default, reduce-overhead, max-autotune
    fullgraph=False,         # Compile entire graph (stricter)
    dynamic=True             # Support dynamic shapes
)

model = accelerator.prepare(model)
```

**Speedup**: 10-50% depending on model

**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Min overhead (best for small batches)
- `max-autotune`: Max performance (slow compile, best for production)

### Compilation Best Practices

```python
# Bad: Compile after prepare (won't work)
model = accelerator.prepare(model)
model = torch.compile(model)  # Error!

# Good: Compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)

# Training loop
for batch in dataloader:
    # First iteration: slow (compilation)
    # Subsequent iterations: fast (compiled)
    outputs = model(**batch)
    ...
```

## Benchmarking Different Strategies

### Script Template

```python
import time

import torch
from accelerate import Accelerator

def benchmark_strategy(strategy_name, accelerator_kwargs):
    """Benchmark a specific training strategy."""
    accelerator = Accelerator(**accelerator_kwargs)

    # Setup (create_model, create_dataloader, and batch_size are assumed
    # to be defined elsewhere in the script)
    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    dataloader = create_dataloader()

    model, optimizer, dataloader = accelerator.prepare(
        model, optimizer, dataloader
    )

    # Warmup
    for i, batch in enumerate(dataloader):
        if i >= 10:
            break
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # Benchmark
    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    start = time.time()

    num_batches = 100
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break

        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    accelerator.wait_for_everyone()
    torch.cuda.synchronize()
    elapsed = time.time() - start

    # Metrics
    throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
    memory_used = torch.cuda.max_memory_allocated() / 1e9  # GB

    if accelerator.is_main_process:
        print(f"\n{strategy_name}:")
        print(f"  Throughput: {throughput:.2f} samples/sec")
        print(f"  Memory: {memory_used:.2f} GB")
        print(f"  Time: {elapsed:.2f} sec")

    torch.cuda.reset_peak_memory_stats()

# Benchmark different strategies (fsdp_plugin, ds_plugin_stage2, and
# ds_plugin_stage3 are assumed to be configured beforehand)
strategies = [
    ("DDP + FP32", {}),
    ("DDP + BF16", {"mixed_precision": "bf16"}),
    ("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
    ("FSDP", {"fsdp_plugin": fsdp_plugin}),
    ("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
    ("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]

for name, kwargs in strategies:
    benchmark_strategy(name, kwargs)
```

## Performance Checklist

**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)

**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation

**For multi-node**:
- [ ] Check network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism

**Profiling**:
- [ ] Profile first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of iteration)
- [ ] Identify communication bottlenecks

## Common Performance Issues

### Issue: Low GPU Utilization (<80%)

**Cause 1**: Data loading bottleneck
```python
# Solution: Increase workers and prefetch
num_workers=8
prefetch_factor=4
```

**Cause 2**: Small batch size
```python
# Solution: Increase batch size or use gradient accumulation
batch_size=32                  # Increase
gradient_accumulation_steps=4  # Or accumulate
```

### Issue: High Memory Usage

**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```

**Solution 2**: Reduce batch size, increase accumulation
```python
batch_size=8                    # Reduce from 32
gradient_accumulation_steps=16  # Maintain effective batch
```

**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

### Issue: Slow Multi-GPU Training

**Cause**: Communication bottleneck

**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```

**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
```

**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```

## Resources

- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention
skills/mlops/audiocraft/SKILL.md (new file, 567 lines)

---
name: audiocraft-audio-generation
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
metadata:
  hermes:
    tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
---

# AudioCraft: Audio Generation

Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.

## When to use AudioCraft

**Use AudioCraft when:**
- Need to generate music from text descriptions
- Creating sound effects and environmental audio
- Building music generation applications
- Need melody-conditioned music generation
- Want stereo audio output
- Require controllable music generation with style transfer

**Key features:**
- **MusicGen**: Text-to-music generation with melody conditioning
- **AudioGen**: Text-to-sound effects generation
- **EnCodec**: High-fidelity neural audio codec
- **Multiple model sizes**: Small (300M) to Large (3.3B)
- **Stereo support**: Full stereo audio generation
- **Style conditioning**: MusicGen-Style for reference-based generation

**Use alternatives instead:**
- **Stable Audio**: For longer commercial music generation
- **Bark**: For text-to-speech with music/sound effects
- **Riffusion**: For spectrogram-based music generation
- **OpenAI Jukebox**: For raw audio generation with lyrics

## Quick start

### Installation

```bash
# From PyPI
pip install audiocraft

# From GitHub (latest)
pip install git+https://github.com/facebookresearch/audiocraft.git

# Or use HuggingFace Transformers
pip install transformers torch torchaudio
```

### Basic text-to-music (AudioCraft)

```python
import torchaudio
from audiocraft.models import MusicGen

# Load model
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Set generation parameters
model.set_generation_params(
    duration=8,  # seconds
    top_k=250,
    temperature=1.0
)

# Generate from text
descriptions = ["happy upbeat electronic dance music with synths"]
wav = model.generate(descriptions)

# Save audio
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
```

### Using HuggingFace Transformers

```python
import scipy
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Load model and processor
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.to("cuda")

# Generate music
inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True,
    return_tensors="pt"
).to("cuda")

audio_values = model.generate(
    **inputs,
    do_sample=True,
    guidance_scale=3,
    max_new_tokens=256
)

# Save
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
```

### Text-to-sound with AudioGen

```python
import torchaudio
from audiocraft.models import AudioGen

# Load AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')

model.set_generation_params(duration=5)

# Generate sound effects
descriptions = ["dog barking in a park with birds chirping"]
wav = model.generate(descriptions)

torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
```

## Core concepts

### Architecture overview

```
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│                     Text Encoder (T5)                        │
│                           │                                  │
│                    Text Embeddings                           │
└────────────────────────┬─────────────────────────────────────┘
                         │
┌────────────────────────▼─────────────────────────────────────┐
│                 Transformer Decoder (LM)                     │
│          Auto-regressively generates audio tokens            │
│         Using efficient token interleaving patterns          │
└────────────────────────┬─────────────────────────────────────┘
                         │
┌────────────────────────▼─────────────────────────────────────┐
│                  EnCodec Audio Decoder                       │
│           Converts tokens back to audio waveform             │
└──────────────────────────────────────────────────────────────┘
```

### Model variants

| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |

### Generation parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |

## MusicGen usage

### Text-to-music generation

```python
import torchaudio
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-medium')

# Configure generation
model.set_generation_params(
    duration=30,      # Up to 30 seconds
    top_k=250,        # Sampling diversity
    top_p=0.0,        # 0 = use top_k only
    temperature=1.0,  # Creativity (higher = more varied)
    cfg_coef=3.0      # Text adherence (higher = stricter)
)

# Generate multiple samples
descriptions = [
    "epic orchestral soundtrack with strings and brass",
    "chill lo-fi hip hop beat with jazzy piano",
    "energetic rock song with electric guitar"
]

# Generate (returns [batch, channels, samples])
wav = model.generate(descriptions)

# Save each
for i, audio in enumerate(wav):
    torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
```

### Melody-conditioned generation

```python
import torchaudio
from audiocraft.models import MusicGen

# Load melody model
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=30)

# Load melody audio
melody, sr = torchaudio.load("melody.wav")

# Generate with melody conditioning
descriptions = ["acoustic guitar folk song"]
wav = model.generate_with_chroma(descriptions, melody, sr)

torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
```

### Stereo generation

```python
import torchaudio
from audiocraft.models import MusicGen

# Load stereo model
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
model.set_generation_params(duration=15)

descriptions = ["ambient electronic music with wide stereo panning"]
wav = model.generate(descriptions)

# wav shape: [batch, 2, samples] for stereo
print(f"Stereo shape: {wav.shape}")  # [1, 2, 480000]
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
```

### Audio continuation

```python
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")

# Load audio to continue
audio, sr = torchaudio.load("intro.wav")

# Process with text and audio
inputs = processor(
    audio=audio.squeeze().numpy(),
    sampling_rate=sr,
    text=["continue with an epic chorus"],
    padding=True,
    return_tensors="pt"
)

# Generate continuation
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
```

## MusicGen-Style usage

### Style-conditioned generation

```python
import torchaudio
from audiocraft.models import MusicGen

# Load style model
model = MusicGen.get_pretrained('facebook/musicgen-style')

# Configure generation with style
model.set_generation_params(
    duration=30,
    cfg_coef=3.0,
    cfg_coef_beta=5.0  # Style influence
)

# Configure style conditioner
model.set_style_conditioner_params(
    eval_q=3,           # RVQ quantizers (1-6)
    excerpt_length=3.0  # Style excerpt length
)

# Load style reference
style_audio, sr = torchaudio.load("reference_style.wav")

# Generate with text + style
descriptions = ["upbeat dance track"]
wav = model.generate_with_style(descriptions, style_audio, sr)
```

### Style-only generation (no text)

```python
# Generate matching style without text prompt
model.set_generation_params(
    duration=30,
    cfg_coef=3.0,
    cfg_coef_beta=None  # Disable double CFG for style-only
)

wav = model.generate_with_style([None], style_audio, sr)
```

## AudioGen usage

### Sound effect generation

```python
import torchaudio
from audiocraft.models import AudioGen

model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=10)

# Generate various sounds
descriptions = [
    "thunderstorm with heavy rain and lightning",
    "busy city traffic with car horns",
    "ocean waves crashing on rocks",
    "crackling campfire in forest"
]

wav = model.generate(descriptions)

for i, audio in enumerate(wav):
    torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
```

## EnCodec usage

### Audio compression

```python
import torch
import torchaudio
from audiocraft.models import CompressionModel

# Load EnCodec
model = CompressionModel.get_pretrained('facebook/encodec_32khz')

# Load audio
wav, sr = torchaudio.load("audio.wav")

# Ensure correct sample rate
if sr != 32000:
    resampler = torchaudio.transforms.Resample(sr, 32000)
    wav = resampler(wav)

# Encode to tokens
with torch.no_grad():
    encoded = model.encode(wav.unsqueeze(0))
    codes = encoded[0]  # Audio codes

# Decode back to audio
with torch.no_grad():
    decoded = model.decode(codes)

torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
```

## Common workflows

### Workflow 1: Music generation pipeline

```python
import torch
import torchaudio
from audiocraft.models import MusicGen

class MusicGenerator:
    def __init__(self, model_name="facebook/musicgen-medium"):
        self.model = MusicGen.get_pretrained(model_name)
        self.sample_rate = 32000

    def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
        self.model.set_generation_params(
            duration=duration,
            top_k=250,
            temperature=temperature,
            cfg_coef=cfg
        )

        with torch.no_grad():
            wav = self.model.generate([prompt])

        return wav[0].cpu()

    def generate_batch(self, prompts, duration=30):
        self.model.set_generation_params(duration=duration)

        with torch.no_grad():
            wav = self.model.generate(prompts)

        return wav.cpu()

    def save(self, audio, path):
        torchaudio.save(path, audio, sample_rate=self.sample_rate)

# Usage
generator = MusicGenerator()
audio = generator.generate(
    "epic cinematic orchestral music",
    duration=30,
    temperature=1.0
)
generator.save(audio, "epic_music.wav")
```

### Workflow 2: Sound design batch processing

```python
from pathlib import Path

import torchaudio
from audiocraft.models import AudioGen

def batch_generate_sounds(sound_specs, output_dir):
    """
    Generate multiple sounds from specifications.

    Args:
        sound_specs: list of {"name": str, "description": str, "duration": float}
        output_dir: output directory path
    """
    model = AudioGen.get_pretrained('facebook/audiogen-medium')
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    results = []

    for spec in sound_specs:
        model.set_generation_params(duration=spec.get("duration", 5))

        wav = model.generate([spec["description"]])

        output_path = output_dir / f"{spec['name']}.wav"
        torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)

        results.append({
            "name": spec["name"],
            "path": str(output_path),
            "description": spec["description"]
        })

    return results

# Usage
sounds = [
    {"name": "explosion", "description": "massive explosion with debris", "duration": 3},
    {"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
    {"name": "door", "description": "wooden door creaking and closing", "duration": 2}
]

results = batch_generate_sounds(sounds, "sound_effects/")
```

### Workflow 3: Gradio demo

```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')

def generate_music(prompt, duration, temperature, cfg_coef):
    model.set_generation_params(
        duration=duration,
        temperature=temperature,
        cfg_coef=cfg_coef
    )

    with torch.no_grad():
        wav = model.generate([prompt])

    # Save to temp file
    path = "temp_output.wav"
    torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
    return path

demo = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
        gr.Slider(1, 30, value=8, label="Duration (seconds)"),
        gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
        gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
    ],
    outputs=gr.Audio(label="Generated Music"),
    title="MusicGen Demo"
)

demo.launch()
```

## Performance optimization

### Memory optimization

```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Clear cache between generations
torch.cuda.empty_cache()

# Generate shorter durations
model.set_generation_params(duration=10)  # Instead of 30

# Use half precision
model = model.half()
```

### Batch processing efficiency

```python
# Process multiple prompts at once (more efficient)
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
wav = model.generate(descriptions)  # Single batch

# Instead of
for desc in descriptions:
    wav = model.generate([desc])  # Multiple batches (slower)
```

### GPU memory requirements

| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
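
These numbers make model selection mechanical: query free VRAM and pick the largest checkpoint that fits. A minimal sketch using the FP32 column above (the helper and its cutoffs are illustrative, not an AudioCraft API):

```python
import torch
from audiocraft.models import MusicGen

def load_musicgen_for_vram():
    """Pick the largest MusicGen checkpoint that fits free GPU memory."""
    free_bytes, _total = torch.cuda.mem_get_info()
    free_gb = free_bytes / 1e9
    if free_gb >= 16:
        name = 'facebook/musicgen-large'
    elif free_gb >= 8:
        name = 'facebook/musicgen-medium'
    else:
        name = 'facebook/musicgen-small'
    return MusicGen.get_pretrained(name)
```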

## Common issues

| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |

## References

- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions

## Resources

- **GitHub**: https://github.com/facebookresearch/audiocraft
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
skills/mlops/audiocraft/references/advanced-usage.md (new file, 666 lines)
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||

## Custom Conditioning

### Adding new conditioners

```python
from audiocraft.modules.conditioners import BaseConditioner
import torch

class CustomConditioner(BaseConditioner):
    """Custom conditioner for additional control signals."""

    def __init__(self, dim, output_dim):
        super().__init__(dim, output_dim)
        self.embed = torch.nn.Linear(dim, output_dim)

    def forward(self, x):
        return self.embed(x)

    def tokenize(self, x):
        # Tokenize input for conditioning
        return x

# Use with MusicGen
from audiocraft.models.builders import get_lm_model

# Modify model config to include custom conditioner
# This requires editing the model configuration
```
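
A quick smoke test of the conditioner defined above (dimensions are arbitrary examples, assuming `BaseConditioner.__init__` takes only `dim` and `output_dim`):

```python
cond = CustomConditioner(dim=16, output_dim=512)
signal = torch.randn(2, 10, 16)    # [batch, steps, dim] control signal
out = cond(cond.tokenize(signal))
print(out.shape)                   # expected: torch.Size([2, 10, 512])
```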

### Melody conditioning internals

```python
from audiocraft.models import MusicGen
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
import torch
import torchaudio

model = MusicGen.get_pretrained('facebook/musicgen-melody')

# Access the chroma conditioner (registered under the 'self_wav' key)
chroma_extractor = model.lm.condition_provider.conditioners['self_wav']

# Manual chroma extraction
def extract_chroma(audio, sr):
    """Extract chroma features from mono audio."""
    import librosa

    # Compute chroma
    chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
    return torch.from_numpy(chroma).float()

# Use extracted chroma for conditioning
melody_audio, sample_rate = torchaudio.load("melody.wav")
chroma = extract_chroma(melody_audio.mean(dim=0), sample_rate)
```

## EnCodec Deep Dive

### Custom compression settings

```python
from audiocraft.models import CompressionModel
import torch

# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')

# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}")  # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")

# Encode with specific bandwidth
# Lower bandwidth = more compression, lower quality
encodec.set_target_bandwidth(6.0)  # 6 kbps

audio = torch.randn(1, 1, 32000)  # 1 second
encoded = encodec.encode(audio)
decoded = encodec.decode(encoded[0])
```

### Streaming encoding

```python
import torch
from audiocraft.models import CompressionModel

encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')

def encode_streaming(audio_stream, chunk_size=32000):
    """Encode audio in streaming fashion."""
    all_codes = []

    for chunk in audio_stream:
        # Ensure chunk is the right shape
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            codes = encodec.encode(chunk)[0]
            all_codes.append(codes)

    return torch.cat(all_codes, dim=-1)

def decode_streaming(codes_stream, output_stream):
    """Decode codes in streaming fashion."""
    for codes in codes_stream:
        with torch.no_grad():
            audio = encodec.decode(codes)
        output_stream.write(audio.cpu().numpy())
```
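
For illustration, a generator such as the hypothetical `iter_chunks` helper below (not part of AudioCraft) can feed `encode_streaming` from a file on disk:

```python
import torchaudio

def iter_chunks(path, chunk_size=32000):
    """Yield fixed-size mono chunks from an audio file (illustrative helper)."""
    wav, sr = torchaudio.load(path)
    wav = wav.mean(dim=0)  # downmix to mono
    for start in range(0, wav.shape[-1], chunk_size):
        yield wav[start:start + chunk_size]

codes = encode_streaming(iter_chunks("input.wav"))
```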

## MultiBand Diffusion

### Using MBD for enhanced quality

```python
import torch
from audiocraft.models import MusicGen, MultiBandDiffusion

# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')

# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()

model.set_generation_params(duration=10)

# Generate with standard decoder
descriptions = ["epic orchestral music"]
wav_standard = model.generate(descriptions)

# Generate tokens and use MBD decoder
with torch.no_grad():
    # Get tokens
    gen_tokens = model.generate_tokens(descriptions)

    # Decode with MBD
    wav_mbd = mbd.tokens_to_wav(gen_tokens)

# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
```

## API Server Deployment

### FastAPI server

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torchaudio
from audiocraft.models import MusicGen
import io
import base64

app = FastAPI()

# Load model at startup
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = MusicGen.get_pretrained('facebook/musicgen-small')
    model.set_generation_params(duration=10)

class GenerateRequest(BaseModel):
    prompt: str
    duration: float = 10.0
    temperature: float = 1.0
    cfg_coef: float = 3.0

class GenerateResponse(BaseModel):
    audio_base64: str
    sample_rate: int
    duration: float

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        model.set_generation_params(
            duration=min(request.duration, 30),
            temperature=request.temperature,
            cfg_coef=request.cfg_coef
        )

        with torch.no_grad():
            wav = model.generate([request.prompt])

        # Convert to bytes
        buffer = io.BytesIO()
        torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
        buffer.seek(0)

        audio_base64 = base64.b64encode(buffer.read()).decode()

        return GenerateResponse(
            audio_base64=audio_base64,
            sample_rate=32000,
            duration=wav.shape[-1] / 32000
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": model is not None}

# Run: uvicorn server:app --host 0.0.0.0 --port 8000
```
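
A minimal client for the `/generate` endpoint above (assumes the server is running locally on port 8000):

```python
import base64
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "lofi hip hop beat", "duration": 10},
)
resp.raise_for_status()

payload = resp.json()
with open("generated.wav", "wb") as f:
    f.write(base64.b64decode(payload["audio_base64"]))
print(f"Saved {payload['duration']:.1f}s of audio at {payload['sample_rate']} Hz")
```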

### Batch processing service

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen

class MusicGenService:
    def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
        self.model = MusicGen.get_pretrained(model_name)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = asyncio.Lock()

    async def generate_async(self, prompt, duration=10):
        """Async generation with thread pool."""
        loop = asyncio.get_event_loop()

        def _generate():
            with torch.no_grad():
                self.model.set_generation_params(duration=duration)
                return self.model.generate([prompt])

        # Run in thread pool
        wav = await loop.run_in_executor(self.executor, _generate)
        return wav[0].cpu()

    async def generate_batch_async(self, prompts, duration=10):
        """Process multiple prompts concurrently."""
        tasks = [self.generate_async(p, duration) for p in prompts]
        return await asyncio.gather(*tasks)

# Usage
service = MusicGenService()

async def main():
    prompts = ["jazz piano", "rock guitar", "electronic beats"]
    results = await service.generate_batch_async(prompts)
    return results
```
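
To actually run the coroutine above from a script:

```python
if __name__ == "__main__":
    wavs = asyncio.run(main())
    print(f"Generated {len(wavs)} clips")
```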

## Integration Patterns

### LangChain tool

```python
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile

class MusicGeneratorTool(BaseTool):
    name: str = "music_generator"
    description: str = (
        "Generate music from a text description. Input should be a detailed "
        "description of the music style, mood, and instruments."
    )
    # Note: recent LangChain versions may require declaring `model` as a
    # pydantic field (e.g. `model: Any = None`) before assigning it below.

    def __init__(self):
        super().__init__()
        self.model = MusicGen.get_pretrained('facebook/musicgen-small')
        self.model.set_generation_params(duration=15)

    def _run(self, description: str) -> str:
        with torch.no_grad():
            wav = self.model.generate([description])

        # Save to temp file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
            return f"Generated music saved to: {f.name}"

    async def _arun(self, description: str) -> str:
        return self._run(description)
```
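
Used directly (outside an agent), the tool behaves like any other LangChain tool:

```python
tool = MusicGeneratorTool()
print(tool.run("calm acoustic guitar with soft percussion"))
```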

### Gradio with advanced controls

```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen

models = {}

def load_model(model_size):
    if model_size not in models:
        model_name = f"facebook/musicgen-{model_size}"
        models[model_size] = MusicGen.get_pretrained(model_name)
    return models[model_size]

def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
    model = load_model(model_size)

    model.set_generation_params(
        duration=duration,
        temperature=temperature,
        cfg_coef=cfg_coef,
        top_k=int(top_k)  # sliders return floats
    )

    with torch.no_grad():
        wav = model.generate([prompt])

    # Save
    path = "output.wav"
    torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
    return path

demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", lines=3),
        gr.Slider(1, 30, value=10, label="Duration (s)"),
        gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
        gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
        gr.Slider(50, 500, value=250, step=50, label="Top-K"),
        gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
    ],
    outputs=gr.Audio(label="Generated Music"),
    title="MusicGen Advanced",
    allow_flagging="never"
)

demo.launch(share=True)
```

## Audio Processing Pipeline

### Post-processing chain

```python
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np

class AudioPostProcessor:
    def __init__(self, sample_rate=32000):
        self.sample_rate = sample_rate

    def normalize(self, audio, target_db=-14.0):
        """Normalize audio to target loudness."""
        rms = torch.sqrt(torch.mean(audio ** 2))
        target_rms = 10 ** (target_db / 20)
        gain = target_rms / (rms + 1e-8)
        return audio * gain

    def fade_in_out(self, audio, fade_duration=0.1):
        """Apply fade in/out."""
        fade_samples = int(fade_duration * self.sample_rate)

        # Create fade curves
        fade_in = torch.linspace(0, 1, fade_samples)
        fade_out = torch.linspace(1, 0, fade_samples)

        # Apply fades
        audio[..., :fade_samples] *= fade_in
        audio[..., -fade_samples:] *= fade_out

        return audio

    def apply_reverb(self, audio, decay=0.5):
        """Apply simple reverb effect."""
        impulse = torch.zeros(int(self.sample_rate * 0.5))
        impulse[0] = 1.0
        impulse[int(self.sample_rate * 0.1)] = decay * 0.5
        impulse[int(self.sample_rate * 0.2)] = decay * 0.25

        # Convolve
        audio = torch.nn.functional.conv1d(
            audio.unsqueeze(0),
            impulse.unsqueeze(0).unsqueeze(0),
            padding=len(impulse) // 2
        ).squeeze(0)

        return audio

    def process(self, audio):
        """Full processing pipeline."""
        audio = self.normalize(audio)
        audio = self.fade_in_out(audio)
        return audio

# Usage with MusicGen
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)

wav = model.generate(["chill ambient music"])
processor = AudioPostProcessor()
wav_processed = processor.process(wav[0].cpu())

torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
```

## Evaluation

### Audio quality metrics

```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric
from audiocraft.data.audio import audio_read

def evaluate_generation(audio_path, text_prompt):
    """Evaluate generated audio quality."""
    # Load audio
    wav, sr = audio_read(audio_path)

    # CLAP consistency (text-audio alignment); may require a local CLAP
    # checkpoint depending on your audiocraft version
    clap_metric = CLAPTextConsistencyMetric()
    clap_score = clap_metric.compute(wav, [text_prompt])

    return {
        "clap_score": clap_score,
        "duration": wav.shape[-1] / sr
    }

# Batch evaluation
def evaluate_batch(generations):
    """Evaluate multiple generations."""
    results = []
    for gen in generations:
        result = evaluate_generation(gen["path"], gen["prompt"])
        result["prompt"] = gen["prompt"]
        results.append(result)

    # Aggregate
    avg_clap = sum(r["clap_score"] for r in results) / len(results)
    return {
        "individual": results,
        "average_clap": avg_clap
    }
```
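
For example, with a couple of saved generations (paths and prompts are illustrative):

```python
generations = [
    {"path": "out1.wav", "prompt": "jazz piano"},
    {"path": "out2.wav", "prompt": "rock guitar"},
]

report = evaluate_batch(generations)
print(f"Average CLAP score: {report['average_clap']:.3f}")
```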

## Model Comparison

### MusicGen variants benchmark

| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|-----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
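
The numbers above depend heavily on hardware; a rough sketch for reproducing the generation-time and VRAM columns on your own GPU:

```python
import time
import torch
from audiocraft.models import MusicGen

for size in ["small", "medium"]:
    model = MusicGen.get_pretrained(f"facebook/musicgen-{size}")
    model.set_generation_params(duration=10)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model.generate(["upbeat electronic music"])
    torch.cuda.synchronize()

    print(f"{size}: {time.time() - start:.1f}s, "
          f"{torch.cuda.max_memory_allocated() / 1e9:.1f}GB peak")
```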

### Prompt engineering tips

```python
# Good prompts - specific and descriptive
good_prompts = [
    "upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
    "melancholic piano ballad with strings, slow tempo, emotional and cinematic",
    "funky disco groove with slap bass, brass section, and rhythmic guitar"
]

# Bad prompts - too vague
bad_prompts = [
    "nice music",
    "song",
    "good beat"
]

# Structure: [mood] [genre] with [instruments] at [tempo/style]
```

504 skills/mlops/audiocraft/references/troubleshooting.md Normal file
@@ -0,0 +1,504 @@
# AudioCraft Troubleshooting Guide

## Installation Issues

### Import errors

**Error**: `ModuleNotFoundError: No module named 'audiocraft'`

**Solutions**:
```bash
# Install from PyPI
pip install audiocraft

# Or from GitHub
pip install git+https://github.com/facebookresearch/audiocraft.git

# Verify installation
python -c "from audiocraft.models import MusicGen; print('OK')"
```

### FFmpeg not found

**Error**: `RuntimeError: ffmpeg not found`

**Solutions**:
```bash
# Ubuntu/Debian
sudo apt-get install ffmpeg

# macOS
brew install ffmpeg

# Windows (using conda)
conda install -c conda-forge ffmpeg

# Verify
ffmpeg -version
```

### PyTorch CUDA mismatch

**Error**: `RuntimeError: CUDA error: no kernel image is available`

**Solutions**:
```bash
# Check CUDA version
nvcc --version
python -c "import torch; print(torch.version.cuda)"

# Install matching PyTorch (CUDA 12.1)
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121

# For CUDA 11.8
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
```

### xformers issues

**Error**: `ImportError: xformers` related errors

**Solutions**:
```bash
# Install xformers for memory efficiency
pip install xformers

# Or disable xformers
export AUDIOCRAFT_USE_XFORMERS=0
```

```python
# In Python
import os
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
from audiocraft.models import MusicGen
```

## Model Loading Issues

### Out of memory during load

**Error**: `torch.cuda.OutOfMemoryError` during model loading

**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Force CPU loading first
import torch
device = "cpu"
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
model = model.to("cuda")

# Use HuggingFace with device_map
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    device_map="auto"
)
```

### Download failures

**Error**: Connection errors or incomplete downloads

**Solutions**:
```python
# Set cache directory
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"

# Or for HuggingFace
os.environ["HF_HOME"] = "/path/to/hf_cache"

# Resume download
from huggingface_hub import snapshot_download
snapshot_download("facebook/musicgen-small", resume_download=True)

# Use local files
model = MusicGen.get_pretrained('/local/path/to/model')
```

### Wrong model type

**Error**: Loading the wrong model for the task

**Solutions**:
```python
# For text-to-music: use MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')

# For text-to-sound: use AudioGen
from audiocraft.models import AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')

# For melody conditioning: use melody variant
model = MusicGen.get_pretrained('facebook/musicgen-melody')

# For stereo: use stereo variant
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```

## Generation Issues

### Empty or silent output

**Problem**: Generated audio is silent or very quiet

**Solutions**:
```python
import torch

# Check output
wav = model.generate(["upbeat music"])
print(f"Shape: {wav.shape}")
print(f"Max amplitude: {wav.abs().max().item()}")
print(f"Mean amplitude: {wav.abs().mean().item()}")

# If too quiet, normalize
def normalize_audio(audio, target_db=-14.0):
    rms = torch.sqrt(torch.mean(audio ** 2))
    target_rms = 10 ** (target_db / 20)
    gain = target_rms / (rms + 1e-8)
    return audio * gain

wav_normalized = normalize_audio(wav)
```

### Poor quality output

**Problem**: Generated music sounds bad or noisy

**Solutions**:
```python
# Use larger model
model = MusicGen.get_pretrained('facebook/musicgen-large')

# Adjust generation parameters
model.set_generation_params(
    duration=15,
    top_k=250,        # Increase for more diversity
    temperature=0.8,  # Lower for more focused output
    cfg_coef=4.0      # Increase for better text adherence
)

# Use better prompts
# Bad: "music"
# Good: "upbeat electronic dance music with synthesizers and punchy drums"

# Try MultiBand Diffusion
from audiocraft.models import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = model.generate_tokens(["prompt"])
wav = mbd.tokens_to_wav(tokens)
```

### Generation too short

**Problem**: Audio shorter than expected

**Solutions**:
```python
# Check duration setting
model.set_generation_params(duration=30)  # Set before generate

# Verify the setting
print(f"Duration setting: {model.generation_params}")

# Check output shape
wav = model.generate(["prompt"])
actual_duration = wav.shape[-1] / 32000
print(f"Actual duration: {actual_duration}s")

# Note: max duration is typically 30s
```

### Melody conditioning fails

**Error**: Issues with melody-conditioned generation

**Solutions**:
```python
import torchaudio
from audiocraft.models import MusicGen

# Load melody model (not base model)
model = MusicGen.get_pretrained('facebook/musicgen-melody')

# Load and prepare melody
melody, sr = torchaudio.load("melody.wav")

# Resample to model sample rate if needed
if sr != 32000:
    resampler = torchaudio.transforms.Resample(sr, 32000)
    melody = resampler(melody)

# Ensure correct shape [batch, channels, samples]
if melody.dim() == 1:
    melody = melody.unsqueeze(0).unsqueeze(0)
elif melody.dim() == 2:
    melody = melody.unsqueeze(0)

# Convert stereo to mono
if melody.shape[1] > 1:
    melody = melody.mean(dim=1, keepdim=True)

# Generate with melody
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
```

## Memory Issues

### CUDA out of memory

**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`

**Solutions**:
```python
import torch

# Clear cache before generation
torch.cuda.empty_cache()

# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Reduce duration
model.set_generation_params(duration=10)  # Instead of 30

# Generate one at a time
for prompt in prompts:
    wav = model.generate([prompt])
    save_audio(wav)
    torch.cuda.empty_cache()

# Use CPU for very large generations
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
```

### Memory leak during batch processing

**Problem**: Memory grows over time

**Solutions**:
```python
import gc
import torch

def generate_with_cleanup(model, prompts):
    results = []

    for prompt in prompts:
        with torch.no_grad():
            wav = model.generate([prompt])
        results.append(wav.cpu())

        # Cleanup
        del wav
        gc.collect()
        torch.cuda.empty_cache()

    return results

# Use context manager
with torch.inference_mode():
    wav = model.generate(["prompt"])
```

## Audio Format Issues

### Wrong sample rate

**Problem**: Audio plays at the wrong speed

**Solutions**:
```python
import torchaudio

# MusicGen outputs at 32kHz
sample_rate = 32000

# AudioGen outputs at 16kHz
sample_rate = 16000

# Always use the correct rate when saving
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)

# Resample if needed
resampler = torchaudio.transforms.Resample(32000, 44100)
wav_resampled = resampler(wav)
```

### Stereo/mono mismatch

**Problem**: Wrong number of channels

**Solutions**:
```python
# Check model type
print(f"Audio channels: {wav.shape}")
# Mono: [batch, 1, samples]
# Stereo: [batch, 2, samples]

# Convert mono to stereo
if wav.shape[1] == 1:
    wav_stereo = wav.repeat(1, 2, 1)

# Convert stereo to mono
if wav.shape[1] == 2:
    wav_mono = wav.mean(dim=1, keepdim=True)

# Use stereo model for stereo output
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```

### Clipping and distortion

**Problem**: Audio has clipping or distortion

**Solutions**:
```python
import torch

# Check for clipping
max_val = wav.abs().max().item()
print(f"Max amplitude: {max_val}")

# Normalize to prevent clipping
if max_val > 1.0:
    wav = wav / max_val

# Apply soft clipping
def soft_clip(x, threshold=0.9):
    return torch.tanh(x / threshold) * threshold

wav_clipped = soft_clip(wav)

# Lower temperature during generation
model.set_generation_params(temperature=0.7)  # More controlled
```

## HuggingFace Transformers Issues

### Processor errors

**Error**: Issues with MusicgenProcessor

**Solutions**:
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Load matching processor and model
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

# Ensure inputs are on the same device
inputs = processor(
    text=["prompt"],
    padding=True,
    return_tensors="pt"
).to("cuda")

# Check processor configuration
print(processor.tokenizer)
print(processor.feature_extractor)
```

### Generation parameter errors

**Error**: Invalid generation parameters

**Solutions**:
```python
# HuggingFace uses different parameter names
audio_values = model.generate(
    **inputs,
    do_sample=True,       # Enable sampling
    guidance_scale=3.0,   # CFG (not cfg_coef)
    max_new_tokens=256,   # Token limit (not duration)
    temperature=1.0
)

# Calculate tokens from duration
# ~50 tokens per second
duration_seconds = 10
max_tokens = duration_seconds * 50
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
```

## Performance Issues

### Slow generation

**Problem**: Generation takes too long

**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')

# Reduce duration
model.set_generation_params(duration=10)

# Use GPU
model.to("cuda")

# Enable flash attention if available
# (requires compatible hardware)

# Batch multiple prompts
prompts = ["prompt1", "prompt2", "prompt3"]
wav = model.generate(prompts)  # A single batch is faster than a loop

# Use compile (PyTorch 2.0+)
model.lm = torch.compile(model.lm)
```

### CPU fallback

**Problem**: Generation running on CPU instead of GPU

**Solutions**:
```python
import torch

# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Explicitly move to GPU
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.to("cuda")

# Verify model device
print(f"Model device: {next(model.lm.parameters()).device}")
```

## Common Error Messages

| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |

## Getting Help

1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2306.05284

### Reporting Issues

Include (a quick way to collect most of this is sketched below):
- Python version
- PyTorch version
- CUDA version
- AudioCraft version: `pip show audiocraft`
- Full error traceback
- Minimal reproducible code
- Hardware (GPU model, VRAM)
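
A quick way to gather the environment details above (adjust for your shell):

```bash
python -c "import sys, torch; print(sys.version); print(torch.__version__, torch.version.cuda)"
pip show audiocraft | head -n 2
nvidia-smi --query-gpu=name,memory.total --format=csv
```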

81 skills/mlops/code-review/SKILL.md Normal file
@@ -0,0 +1,81 @@
---
name: code-review
description: Guidelines for performing thorough code reviews with security and quality focus
---

# Code Review Skill

Use this skill when reviewing code changes, pull requests, or auditing existing code.

## Review Checklist

### 1. Security First
- [ ] No hardcoded secrets, API keys, or credentials
- [ ] Input validation on all user-provided data
- [ ] SQL queries use parameterized statements (no string concatenation)
- [ ] File operations validate paths (no path traversal)
- [ ] Authentication/authorization checks present where needed

### 2. Error Handling
- [ ] All external calls (API, DB, file) have try/catch
- [ ] Errors are logged with context (but no sensitive data)
- [ ] User-facing errors are helpful but don't leak internals
- [ ] Resources are cleaned up in finally blocks or context managers

### 3. Code Quality
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
- [ ] Variable names are descriptive (no single letters except loop indices)
- [ ] No commented-out code left behind
- [ ] Complex logic has explanatory comments
- [ ] No duplicate code (DRY principle)

### 4. Testing Considerations
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
- [ ] Happy path and error paths both work
- [ ] New code has corresponding tests (if a test suite exists)

## Review Response Format

When providing review feedback, structure it as:

```
## Summary
[1-2 sentence overall assessment]

## Critical Issues (Must Fix)
- Issue 1: [description + suggested fix]
- Issue 2: ...

## Suggestions (Nice to Have)
- Suggestion 1: [description]

## Questions
- [Any clarifying questions about intent]
```

## Common Patterns to Flag

### Python
```python
# Bad: SQL injection risk
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")

# Good: Parameterized query
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
```

### JavaScript
```javascript
// Bad: XSS risk
element.innerHTML = userInput;

// Good: Safe text content
element.textContent = userInput;
```

## Tone Guidelines

- Be constructive, not critical
- Explain *why* something is an issue, not just *what*
- Offer solutions, not just problems
- Acknowledge good patterns you see

224 skills/mlops/faiss/SKILL.md Normal file
@@ -0,0 +1,224 @@
---
name: faiss
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [faiss-cpu, faiss-gpu, numpy]
metadata:
  hermes:
    tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
---

# FAISS - Efficient Similarity Search

Facebook AI's library for billion-scale vector similarity search.

## When to use FAISS

**Use FAISS when:**
- Need fast similarity search on large vector datasets (millions/billions)
- GPU acceleration required
- Pure vector similarity (no metadata filtering needed)
- High throughput, low latency critical
- Offline/batch processing of embeddings

**Metrics**:
- **31,700+ GitHub stars**
- Meta/Facebook AI Research
- **Handles billions of vectors**
- **C++** with Python bindings

**Use alternatives instead**:
- **Chroma/Pinecone**: Need metadata filtering
- **Weaviate**: Need full database features
- **Annoy**: Simpler, fewer features

## Quick start

### Installation

```bash
# CPU only
pip install faiss-cpu

# GPU support
pip install faiss-gpu
```

### Basic usage

```python
import faiss
import numpy as np

# Create sample data (1000 vectors, 128 dimensions)
d = 128
nb = 1000
vectors = np.random.random((nb, d)).astype('float32')

# Create index
index = faiss.IndexFlatL2(d)  # L2 distance
index.add(vectors)  # Add vectors

# Search
k = 5  # Find 5 nearest neighbors
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)

print(f"Nearest neighbors: {indices}")
print(f"Distances: {distances}")
```

## Index types

### 1. Flat (exact search)

```python
# L2 (Euclidean) distance
index = faiss.IndexFlatL2(d)

# Inner product (cosine similarity if normalized)
index = faiss.IndexFlatIP(d)

# Slowest, most accurate
```

### 2. IVF (inverted file) - Fast approximate

```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)

# IVF index with 100 clusters
nlist = 100
index = faiss.IndexIVFFlat(quantizer, d, nlist)

# Train on data
index.train(vectors)

# Add vectors
index.add(vectors)

# Search (nprobe = clusters to search)
index.nprobe = 10
distances, indices = index.search(query, k)
```

### 3. HNSW (Hierarchical NSW) - Best quality/speed

```python
# HNSW index
M = 32  # Number of connections per layer
index = faiss.IndexHNSWFlat(d, M)

# No training needed
index.add(vectors)

# Search
distances, indices = index.search(query, k)
```

### 4. Product Quantization - Memory efficient

```python
# PQ reduces memory by 16-32×
m = 8  # Number of subquantizers
nbits = 8
index = faiss.IndexPQ(d, m, nbits)

# Train and add
index.train(vectors)
index.add(vectors)
```

## Save and load

```python
# Save index
faiss.write_index(index, "large.index")

# Load index
index = faiss.read_index("large.index")

# Continue using
distances, indices = index.search(query, k)
```

## GPU acceleration

```python
# Single GPU
res = faiss.StandardGpuResources()
index_cpu = faiss.IndexFlatL2(d)
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu)  # GPU 0

# Multi-GPU
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)

# 10-100× faster than CPU
```

## LangChain integration

```python
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

# Create FAISS vector store
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())

# Save
vectorstore.save_local("faiss_index")

# Load
vectorstore = FAISS.load_local(
    "faiss_index",
    OpenAIEmbeddings(),
    allow_dangerous_deserialization=True
)

# Search
results = vectorstore.similarity_search("query", k=5)
```

## LlamaIndex integration

```python
from llama_index.vector_stores.faiss import FaissVectorStore
import faiss

# Create FAISS index
d = 1536
faiss_index = faiss.IndexFlatL2(d)

vector_store = FaissVectorStore(faiss_index=faiss_index)
```

## Best practices

1. **Choose the right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors (see the snippet below)
3. **Use GPU for large datasets** - 10-100× faster
4. **Save trained indices** - Training is expensive
5. **Tune nprobe/efSearch** - Balance speed/accuracy
6. **Monitor memory** - PQ for large datasets
7. **Batch queries** - Better GPU utilization
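
For practice 2, inner product only equals cosine similarity when both the database and the query vectors are L2-normalized first:

```python
import faiss
import numpy as np

d = 128
vectors = np.random.random((1000, d)).astype('float32')
query = np.random.random((1, d)).astype('float32')

index = faiss.IndexFlatIP(d)
faiss.normalize_L2(vectors)  # in-place normalization
index.add(vectors)

faiss.normalize_L2(query)
scores, ids = index.search(query, 5)  # inner product == cosine similarity
```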

## Performance

| Index Type | Build Time | Search Time | Memory | Accuracy |
|------------|------------|-------------|--------|----------|
| Flat | Fast | Slow | High | 100% |
| IVF | Medium | Fast | Medium | 95-99% |
| HNSW | Slow | Fastest | High | 99% |
| PQ | Medium | Fast | Low | 90-95% |

## Resources

- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **License**: MIT

280 skills/mlops/faiss/references/index_types.md Normal file
@@ -0,0 +1,280 @@
# FAISS Index Types Guide

Complete guide to choosing and using FAISS index types.

## Index selection guide

| Dataset Size | Index Type | Training | Accuracy | Speed |
|--------------|------------|----------|----------|-------|
| < 10K | Flat | No | 100% | Slow |
| 10K-1M | IVF | Yes | 95-99% | Fast |
| 1M-10M | HNSW | No | 99% | Fastest |
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |

## Flat indices (exact search)

### IndexFlatL2 - L2 (Euclidean) distance

```python
import faiss
import numpy as np

d = 128  # Dimension
index = faiss.IndexFlatL2(d)

# Add vectors
vectors = np.random.random((1000, d)).astype('float32')
index.add(vectors)

# Search
k = 5
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
```

**Use when:**
- Dataset < 10,000 vectors
- Need 100% accuracy
- Serving as a baseline

### IndexFlatIP - Inner product (cosine similarity)

```python
# For cosine similarity, normalize vectors first
import faiss

d = 128
index = faiss.IndexFlatIP(d)

# Normalize vectors (required for cosine similarity)
faiss.normalize_L2(vectors)
index.add(vectors)

# Search
faiss.normalize_L2(query)
distances, indices = index.search(query, k)
```

**Use when:**
- Need cosine similarity
- Recommendation systems
- Text embeddings

## IVF indices (inverted file)

### IndexIVFFlat - Cluster-based search

```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)

# Create IVF index with 100 clusters
nlist = 100  # Number of clusters
index = faiss.IndexIVFFlat(quantizer, d, nlist)

# Train on data (required!)
index.train(vectors)

# Add vectors
index.add(vectors)

# Search (nprobe = clusters to search)
index.nprobe = 10  # Search 10 closest clusters
distances, indices = index.search(query, k)
```

**Parameters:**
- `nlist`: Number of clusters (√N to 4√N recommended)
- `nprobe`: Clusters to search (1 to nlist, higher = more accurate)

**Use when:**
- Dataset 10K-1M vectors
- Need fast approximate search
- Can afford training time

### Tuning nprobe

```python
# Test different nprobe values
for nprobe in [1, 5, 10, 20, 50]:
    index.nprobe = nprobe
    distances, indices = index.search(query, k)
    # Measure recall/speed trade-off
```

**Guidelines:**
- `nprobe=1`: Fastest, ~50% recall
- `nprobe=10`: Good balance, ~95% recall
- `nprobe=nlist`: Exact search (same as Flat)

## HNSW indices (graph-based)

### IndexHNSWFlat - Hierarchical NSW

```python
# HNSW index
M = 32  # Number of connections per layer (16-64)
index = faiss.IndexHNSWFlat(d, M)

# Optional: Set efConstruction (build-time parameter)
index.hnsw.efConstruction = 40  # Higher = better quality, slower build

# Add vectors (no training needed!)
index.add(vectors)

# Search
index.hnsw.efSearch = 16  # Search-time parameter
distances, indices = index.search(query, k)
```

**Parameters:**
- `M`: Connections per layer (16-64, default 32)
- `efConstruction`: Build quality (40-200, higher = better)
- `efSearch`: Search quality (16-512, higher = more accurate)

**Use when:**
- Need best-quality approximate search
- Can afford higher memory (more connections)
- Dataset 1M-10M vectors

## PQ indices (product quantization)

### IndexPQ - Memory-efficient

```python
# PQ reduces memory by 16-32×
m = 8  # Number of subquantizers (divides d)
nbits = 8  # Bits per subquantizer

index = faiss.IndexPQ(d, m, nbits)

# Train (required!)
index.train(vectors)

# Add vectors
index.add(vectors)

# Search
distances, indices = index.search(query, k)
```

**Parameters:**
- `m`: Subquantizers (d must be divisible by m)
- `nbits`: Bits per code (8 or 16)

**Memory savings:**
- Original: d × 4 bytes (float32)
- PQ: m bytes
- Compression ratio: 4d/m (e.g. d=128, m=8: 512 bytes → 8 bytes per vector, a 64× reduction)

**Use when:**
- Limited memory
- Large datasets (> 10M vectors)
- Can accept ~90-95% accuracy

### IndexIVFPQ - IVF + PQ combined

```python
# Best for very large datasets
nlist = 4096
m = 8
nbits = 8

quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)

# Train
index.train(vectors)
index.add(vectors)

# Search
index.nprobe = 32
distances, indices = index.search(query, k)
```

**Use when:**
- Dataset > 10M vectors
- Need fast search + low memory
- Can accept 90-95% accuracy

## GPU indices

### Single GPU

```python
import faiss

# Create CPU index
index_cpu = faiss.IndexFlatL2(d)

# Move to GPU
res = faiss.StandardGpuResources()  # GPU resources
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu)  # GPU 0

# Use normally
index_gpu.add(vectors)
distances, indices = index_gpu.search(query, k)
```

### Multi-GPU

```python
# Use all available GPUs
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)

# Or specific GPUs (pass the list as a keyword argument)
gpus = [0, 1, 2, 3]  # Use GPUs 0-3
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus=gpus)
```

**Speedup:**
- Single GPU: 10-50× faster than CPU
- Multi-GPU: Near-linear scaling

## Index factory

```python
# Easy index creation with string descriptors
index = faiss.index_factory(d, "IVF100,Flat")
index = faiss.index_factory(d, "HNSW32")
index = faiss.index_factory(d, "IVF4096,PQ8")

# Train and use
index.train(vectors)
index.add(vectors)
```

**Common descriptors:**
- `"Flat"`: Exact search
- `"IVF100,Flat"`: IVF with 100 clusters
- `"HNSW32"`: HNSW with M=32
- `"IVF4096,PQ8"`: IVF + PQ compression

## Performance comparison

### Search speed (1M vectors, k=10)

| Index | Build Time | Search Time | Memory | Recall |
|-------|------------|-------------|--------|--------|
| Flat | 0s | 50ms | 512 MB | 100% |
| IVF100 | 5s | 2ms | 512 MB | 95% |
| HNSW32 | 60s | 1ms | 1GB | 99% |
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |

*CPU (16 cores), 128-dim vectors*

## Best practices

1. **Start with Flat** - Baseline for comparison
2. **Use IVF for medium datasets** - Good balance
3. **Use HNSW for best quality** - If memory allows
4. **Add PQ for memory savings** - Large datasets
5. **GPU for > 100K vectors** - 10-50× speedup
6. **Tune nprobe/efSearch** - Trade off speed/accuracy
7. **Train on representative data** - Better clustering
8. **Save trained indices** - Avoid retraining

## Resources

- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **Paper**: https://arxiv.org/abs/1702.08734

370 skills/mlops/flash-attention/SKILL.md Normal file
@@ -0,0 +1,370 @@
---
|
||||
name: optimizing-attention-flash
|
||||
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [flash-attn, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
|
||||
|
||||
---
|
||||
|
||||
# Flash Attention - Fast Memory-Efficient Attention
|
||||
|
||||
## Quick start
|
||||
|
||||
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
|
||||
|
||||
**PyTorch native (easiest, PyTorch 2.2+)**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
|
||||
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Automatically uses Flash Attention if available
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**flash-attn library (more features)**:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q, k, v: [batch, seqlen, nheads, headdim]
|
||||
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Enable in existing PyTorch model
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Flash Attention Integration:
|
||||
- [ ] Step 1: Check PyTorch version (≥2.2)
|
||||
- [ ] Step 2: Enable Flash Attention backend
|
||||
- [ ] Step 3: Verify speedup with profiling
|
||||
- [ ] Step 4: Test accuracy matches baseline
|
||||
```
|
||||
|
||||
**Step 1: Check PyTorch version**
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
# Should be ≥2.2.0
|
||||
```
|
||||
|
||||
If <2.2, upgrade:
|
||||
```bash
|
||||
pip install --upgrade torch
|
||||
```
|
||||
|
||||
**Step 2: Enable Flash Attention backend**
|
||||
|
||||
Replace standard attention:
|
||||
```python
|
||||
# Before (standard attention)
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
|
||||
out = attn_weights @ v
|
||||
|
||||
# After (Flash Attention)
|
||||
import torch.nn.functional as F
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
```
|
||||
|
||||
Force Flash Attention backend:
|
||||
```python
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=False,
|
||||
enable_mem_efficient=False
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**Step 3: Verify speedup with profiling**
|
||||
|
||||
```python
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def test_attention(use_flash):
|
||||
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
if use_flash:
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True):
|
||||
return F.scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
|
||||
return attn @ v
|
||||
|
||||
# Benchmark
|
||||
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
|
||||
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
|
||||
|
||||
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
|
||||
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
|
||||
```
|
||||
|
||||
Expected: 2-4x speedup for sequences >512 tokens.
|
||||
|
||||
**Step 4: Test accuracy matches baseline**
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Flash Attention
|
||||
out_flash = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
# Standard attention
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
|
||||
out_standard = attn_weights @ v
|
||||
|
||||
# Check difference
|
||||
diff = (out_flash - out_standard).abs().max()
|
||||
print(f"Max difference: {diff:.6f}")
|
||||
# Should be <1e-3 for float16
|
||||
```
|
||||
|
||||
### Workflow 2: Use flash-attn library for advanced features
|
||||
|
||||
For multi-query attention, sliding window, or H100 FP8.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
flash-attn Library Setup:
|
||||
- [ ] Step 1: Install flash-attn library
|
||||
- [ ] Step 2: Modify attention code
|
||||
- [ ] Step 3: Enable advanced features
|
||||
- [ ] Step 4: Benchmark performance
|
||||
```
|
||||
|
||||
**Step 1: Install flash-attn library**
|
||||
|
||||
```bash
|
||||
# NVIDIA GPUs (CUDA 12.0+)
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Verify installation
|
||||
python -c "from flash_attn import flash_attn_func; print('Success')"
|
||||
```
|
||||
|
||||
**Step 2: Modify attention code**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# Input: [batch_size, seq_len, num_heads, head_dim]
|
||||
# Transpose from [batch, heads, seq, dim] if needed
|
||||
q = q.transpose(1, 2) # [batch, seq, heads, dim]
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
dropout_p=0.1,
|
||||
causal=True, # For autoregressive models
|
||||
window_size=(-1, -1), # No sliding window
|
||||
softmax_scale=None # Auto-scale
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
|
||||
```
|
||||
|
||||
**Step 3: Enable advanced features**
|
||||
|
||||
Multi-query attention (shared K/V across heads):
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q: [batch, seq, num_q_heads, dim]
|
||||
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
|
||||
out = flash_attn_func(q, k, v) # Automatically handles MQA
|
||||
```
|
||||
|
||||
Sliding window attention (local attention):
|
||||
```python
|
||||
# Only attend to window of 256 tokens before/after
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
window_size=(256, 256), # (left, right) window
|
||||
causal=True
|
||||
)
|
||||
```
|
||||
|
||||
**Step 4: Benchmark performance**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from flash_attn import flash_attn_func
|
||||
import time
|
||||
|
||||
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = flash_attn_func(q, k, v)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
out = flash_attn_func(q, k, v)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
|
||||
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
|
||||
```
|
||||
|
||||
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
|
||||
|
||||
For maximum performance on H100 GPUs.
|
||||
|
||||
```
|
||||
FP8 Setup:
|
||||
- [ ] Step 1: Verify H100 GPU available
|
||||
- [ ] Step 2: Install flash-attn with FP8 support
|
||||
- [ ] Step 3: Convert inputs to FP8
|
||||
- [ ] Step 4: Run with FP8 attention
|
||||
```
|
||||
|
||||
**Step 1: Verify H100 GPU**
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name --format=csv
|
||||
# Should show "H100" or "H800"
|
||||
```
|
||||
|
||||
**Step 2: Install flash-attn with FP8 support**
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
# FP8 support included for H100
|
||||
```
|
||||
|
||||
**Step 3: Convert inputs to FP8**
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Convert to float8_e4m3 (FP8)
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
k_fp8 = k.to(torch.float8_e4m3fn)
|
||||
v_fp8 = v.to(torch.float8_e4m3fn)
|
||||
```
|
||||
|
||||
**Step 4: Run with FP8 attention**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# FlashAttention-3 automatically uses FP8 kernels on H100
|
||||
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
|
||||
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Flash Attention when:**
|
||||
- Training transformers with sequences >512 tokens
|
||||
- Running inference with long context (>2K tokens)
|
||||
- GPU memory constrained (OOM with standard attention)
|
||||
- Need 2-4x speedup without accuracy loss
|
||||
- Using PyTorch 2.2+ or can install flash-attn
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
|
||||
- **xFormers**: Need more attention variants (not just speed)
|
||||
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: ImportError: cannot import flash_attn**
|
||||
|
||||
Install with no-build-isolation flag:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or install CUDA toolkit first:
|
||||
```bash
|
||||
conda install cuda -c nvidia
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
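After installing, a quick one-liner confirms the package imports cleanly:

```bash
python -c "import flash_attn; print(flash_attn.__version__)"
```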
|
||||
|
||||
**Issue: Slower than expected (no speedup)**
|
||||
|
||||
Flash Attention benefits increase with sequence length:
|
||||
- <512 tokens: Minimal speedup (10-20%)
|
||||
- 512-2K tokens: 2-3x speedup
|
||||
- >2K tokens: 3-4x speedup
|
||||
|
||||
Check that your sequence length is long enough to benefit; a quick way to verify is sketched below.
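A minimal timing comparison against PyTorch's built-in `scaled_dot_product_attention` (a sketch; the shapes and `seq_len` are illustrative, and note that SDPA may itself dispatch to a flash kernel unless you force the math backend):

```python
import time
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def time_fn(fn, *args, iters=50):
    # Warm up, then time with proper GPU synchronization
    for _ in range(5):
        fn(*args)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1000  # ms per call

seq_len = 4096  # try your actual sequence length
q, k, v = [torch.randn(1, seq_len, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# SDPA expects [batch, heads, seq, dim]; flash_attn_func expects [batch, seq, heads, dim]
qt, kt, vt = [t.transpose(1, 2) for t in (q, k, v)]

print(f"flash-attn: {time_fn(flash_attn_func, q, k, v):.2f} ms")
print(f"SDPA:       {time_fn(F.scaled_dot_product_attention, qt, kt, vt):.2f} ms")
```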
|
||||
|
||||
**Issue: RuntimeError: CUDA error**
|
||||
|
||||
Verify GPU supports Flash Attention:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.get_device_capability())
|
||||
# (8, 0) or higher means Ampere or newer, which flash-attn 2.x requires
|
||||
```
|
||||
|
||||
Flash Attention requires:
|
||||
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ⚠️ Not supported by flash-attn 2.x; use FlashAttention 1 or PyTorch SDPA
- Volta (V100): ❌ Not supported
|
||||
|
||||
**Issue: Accuracy degradation**
|
||||
|
||||
Check dtype is float16 or bfloat16 (not float32):
|
||||
```python
|
||||
q = q.to(torch.float16) # Or torch.bfloat16
|
||||
```
|
||||
|
||||
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
|
||||
|
||||
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
|
||||
|
||||
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
|
||||
|
||||
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
|
||||
- **VRAM**: No more than standard attention; activation memory is usually lower because the N×N attention matrix is never materialized
|
||||
- **CUDA**: 12.0+ (11.8 minimum)
|
||||
- **PyTorch**: 2.2+ for native support
|
||||
|
||||
**Not supported**: V100 (Volta), CPU inference
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
|
||||
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
|
||||
- Blog: https://tridao.me/blog/2024/flash3/
|
||||
- GitHub: https://github.com/Dao-AILab/flash-attention
|
||||
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
|
||||
|
||||
|
||||
skills/mlops/flash-attention/references/benchmarks.md (new file, 215 lines)
@@ -0,0 +1,215 @@
|
||||
# Performance Benchmarks
|
||||
|
||||
## Contents
|
||||
- Speed comparisons across GPUs
|
||||
- Memory usage analysis
|
||||
- Scaling with sequence length
|
||||
- Training vs inference performance
|
||||
- Flash Attention versions comparison
|
||||
|
||||
## Speed comparisons across GPUs
|
||||
|
||||
### A100 80GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|
||||
|------------|----------|--------------|--------------|---------------|
|
||||
| 512 | 1.2 | 0.9 | N/A | 1.3x |
|
||||
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
|
||||
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
|
||||
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
|
||||
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
|
||||
|
||||
### H100 80GB (Hopper)
|
||||
|
||||
**Forward pass time** (milliseconds, same config):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|
||||
|------------|----------|--------------|---------------------|--------------------|--------------|
|
||||
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
|
||||
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
|
||||
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
|
||||
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
|
||||
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
|
||||
|
||||
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
|
||||
|
||||
### A10G 24GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=4):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Speedup |
|
||||
|------------|----------|--------------|---------|
|
||||
| 512 | 2.1 | 1.6 | 1.3x |
|
||||
| 1024 | 6.8 | 2.8 | 2.4x |
|
||||
| 2048 | 25.9 | 9.4 | 2.8x |
|
||||
| 4096 | 102.1 | 35.2 | 2.9x |
|
||||
|
||||
## Memory usage analysis
|
||||
|
||||
### GPU memory consumption (batch=8, heads=32, dim=64)
|
||||
|
||||
**Standard attention memory**:
|
||||
|
||||
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|
||||
|------------|------------------|----------|-------|-------|
|
||||
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
|
||||
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
|
||||
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
|
||||
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
|
||||
|
||||
**Flash Attention 2 memory**:
|
||||
|
||||
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|
||||
|------------|---------------------|----------|-------|-----------|
|
||||
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
|
||||
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
|
||||
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
|
||||
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
|
||||
|
||||
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
|
||||
|
||||
### Memory scaling comparison
|
||||
|
||||
**Llama 2 7B model memory** (float16, batch=1):
|
||||
|
||||
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|
||||
|----------------|-------------------|-------------------|-------------------|
|
||||
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
|
||||
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
|
||||
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
|
||||
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
|
||||
| 32K | OOM | 14.2 GB | Only Flash: Yes |
|
||||
|
||||
### Training memory (Llama 2 7B, batch=4)
|
||||
|
||||
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|
||||
|---------|---------------|-----------------|-----------|
|
||||
| 2K | 18.2 | 12.4 | 32% |
|
||||
| 4K | 34.8 | 16.8 | 52% |
|
||||
| 8K | OOM (>40GB) | 26.2 | Fits! |
|
||||
|
||||
## Scaling with sequence length
|
||||
|
||||
### Computational complexity
|
||||
|
||||
**Standard attention**:
|
||||
- Time: O(N² × d)
|
||||
- Memory: O(N² + N × d)
|
||||
|
||||
**Flash Attention**:
|
||||
- Time: O(N² × d) (same, but with better constants)
|
||||
- Memory: O(N × d) (linear!)
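To make the difference concrete, a back-of-envelope calculation (pure arithmetic; the shapes are illustrative):

```python
def attn_matrix_gb(batch, heads, seq_len, bytes_per_el=2):
    # Standard attention materializes a [batch, heads, seq, seq] score matrix (fp16 = 2 bytes)
    return batch * heads * seq_len**2 * bytes_per_el / 1e9

def qkvo_gb(batch, heads, seq_len, dim, bytes_per_el=2):
    # Flash Attention keeps only Q, K, V and the output, each [batch, seq, heads, dim]
    return 4 * batch * seq_len * heads * dim * bytes_per_el / 1e9

for n in (512, 2048, 8192, 32768):
    print(f"seq={n:>6}: score matrix {attn_matrix_gb(8, 32, n):8.2f} GB "
          f"vs Q/K/V/O {qkvo_gb(8, 32, n, 64):5.2f} GB")
```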
|
||||
|
||||
### Empirical scaling (A100, batch=1, heads=32, dim=64)
|
||||
|
||||
**Time per token (milliseconds)**:
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
|
||||
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
|
||||
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
|
||||
|
||||
**Observation**: Speedup grows roughly linearly with sequence length; each doubling of the sequence roughly doubles the speedup.
|
||||
|
||||
### Memory per token (MB)
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
|
||||
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
|
||||
|
||||
**Observation**: Flash Attention memory per token is constant!
|
||||
|
||||
## Training vs inference performance
|
||||
|
||||
### Training (forward + backward, Llama 2 7B, A100)
|
||||
|
||||
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------------|------------------------|--------------------------|---------|
|
||||
| 4 × 2K | 1.2 | 3.1 | 2.6x |
|
||||
| 8 × 2K | 2.1 | 5.8 | 2.8x |
|
||||
| 4 × 4K | 0.4 | 1.3 | 3.3x |
|
||||
| 8 × 4K | OOM | 2.4 | Enabled |
|
||||
| 2 × 8K | 0.1 | 0.4 | 4.0x |
|
||||
|
||||
### Inference (generation, Llama 2 7B, A100)
|
||||
|
||||
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|----------------|----------------------|-------------------------|---------|
|
||||
| 512 | 48 | 52 | 1.1x |
|
||||
| 2K | 42 | 62 | 1.5x |
|
||||
| 4K | 31 | 58 | 1.9x |
|
||||
| 8K | 18 | 51 | 2.8x |
|
||||
| 16K | OOM | 42 | Enabled |
|
||||
|
||||
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
|
||||
|
||||
## Flash Attention versions comparison
|
||||
|
||||
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
|
||||
|
||||
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|
||||
|--------|-----|-----|------------|-----------|
|
||||
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
|
||||
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
|
||||
| TFLOPS | 180 | 420 | 740 | 1150 |
|
||||
| GPU util % | 35% | 55% | 75% | 82% |
|
||||
|
||||
**Key improvements**:
|
||||
- FA2: 2.3x faster than FA1 (better parallelism)
|
||||
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
|
||||
- FA3 (FP8): 2.6x faster than FA2 (low precision)
|
||||
|
||||
### Features by version
|
||||
|
||||
| Feature | FA1 | FA2 | FA3 |
|
||||
|---------|-----|-----|-----|
|
||||
| Basic attention | ✅ | ✅ | ✅ |
|
||||
| Causal masking | ✅ | ✅ | ✅ |
|
||||
| Multi-query attention | ❌ | ✅ | ✅ |
|
||||
| Sliding window | ❌ | ✅ | ✅ |
|
||||
| Paged KV cache | ❌ | ✅ | ✅ |
|
||||
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
|
||||
| Work partitioning | Basic | Advanced | Optimal |
|
||||
|
||||
## Real-world model benchmarks
|
||||
|
||||
### Llama 2 models (A100 80GB, batch=4, seq=2048)
|
||||
|
||||
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------|--------|------------------------|--------------------------|---------|
|
||||
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
|
||||
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
|
||||
|
||||
### GPT-style models (seq=1024)
|
||||
|
||||
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|-------|----------------------|-------------------------|---------|
|
||||
| GPT-2 (124M) | 520 | 680 | 1.3x |
|
||||
| GPT-J (6B) | 42 | 98 | 2.3x |
|
||||
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
|
||||
|
||||
## Recommendations by use case
|
||||
|
||||
**Training large models (>7B parameters)**:
|
||||
- Use Flash Attention 2 on A100
|
||||
- Use Flash Attention 3 FP8 on H100 for maximum speed
|
||||
- Expected: 2.5-3x speedup
|
||||
|
||||
**Long context inference (>4K tokens)**:
|
||||
- Flash Attention essential (enables contexts standard attention can't handle)
|
||||
- Expected: 2-4x speedup, 5-10x memory reduction
|
||||
|
||||
**Short sequences (<512 tokens)**:
|
||||
- Flash Attention provides 1.2-1.5x speedup
|
||||
- Minimal memory benefit
|
||||
- Still worth enabling (no downside)
|
||||
|
||||
**Multi-user serving**:
|
||||
- Flash Attention reduces per-request memory
|
||||
- Allows higher concurrent batch sizes
|
||||
- Can serve 2-3x more users on same hardware
|
||||
@@ -0,0 +1,293 @@
|
||||
# HuggingFace Transformers Integration
|
||||
|
||||
## Contents
|
||||
- Enabling Flash Attention in Transformers
|
||||
- Supported model architectures
|
||||
- Configuration examples
|
||||
- Performance comparisons
|
||||
- Troubleshooting model-specific issues
|
||||
|
||||
## Enabling Flash Attention in Transformers
|
||||
|
||||
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
|
||||
|
||||
**Simple enable for any supported model**:
|
||||
```python
|
||||
import torch
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
**Install requirements**:
|
||||
```bash
|
||||
pip install transformers>=4.36
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Supported model architectures
|
||||
|
||||
As of Transformers 4.40:
|
||||
|
||||
**Fully supported**:
|
||||
- Llama / Llama 2 / Llama 3
|
||||
- Mistral / Mixtral
|
||||
- Falcon
|
||||
- GPT-NeoX
|
||||
- Phi / Phi-2 / Phi-3
|
||||
- Qwen / Qwen2
|
||||
- Gemma
|
||||
- Starcoder2
|
||||
- GPT-J
|
||||
- OPT
|
||||
- BLOOM
|
||||
|
||||
**Partially supported** (encoder-decoder):
|
||||
- BART
|
||||
- T5 / Flan-T5
|
||||
- Whisper
|
||||
|
||||
**Check support**:
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("model-name")
|
||||
print(config._attn_implementation_internal)
|
||||
# 'flash_attention_2' if supported
|
||||
```
|
||||
|
||||
## Configuration examples
|
||||
|
||||
### Llama 2 with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Generate
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_length=100)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### Mistral with Flash Attention for long context
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for long context
|
||||
device_map="auto",
|
||||
max_position_embeddings=32768 # Extended context
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# Process long document (32K tokens)
|
||||
long_text = "..." * 10000
|
||||
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
```
|
||||
|
||||
### Fine-tuning with Flash Attention
|
||||
|
||||
```python
|
||||
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=3,
|
||||
fp16=True, # Must match model dtype
|
||||
optim="adamw_torch_fused" # Fast optimizer
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Multi-GPU training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
# Model parallelism with Flash Attention
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto", # Automatic multi-GPU placement
|
||||
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Performance comparisons
|
||||
|
||||
### Memory usage (Llama 2 7B, batch=1)
|
||||
|
||||
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|
||||
|-----------------|-------------------|-------------------|-----------|
|
||||
| 512 | 1.2 GB | 0.9 GB | 25% |
|
||||
| 2048 | 3.8 GB | 1.4 GB | 63% |
|
||||
| 8192 | 14.2 GB | 3.2 GB | 77% |
|
||||
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
|
||||
|
||||
### Speed (tokens/sec, A100 80GB)
|
||||
|
||||
| Model | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|----------|--------------|---------|
|
||||
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
|
||||
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
|
||||
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
|
||||
|
||||
### Training throughput (samples/sec)
|
||||
|
||||
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|------------|----------|--------------|---------|
|
||||
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
|
||||
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
|
||||
|
||||
## Troubleshooting model-specific issues
|
||||
|
||||
### Issue: Model doesn't support Flash Attention
|
||||
|
||||
Check support list above. If not supported, use PyTorch SDPA as fallback:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="sdpa", # PyTorch native (still faster)
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory during loading
|
||||
|
||||
Reduce memory footprint:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory={0: "18GB"}, # Reserve memory for KV cache
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slower inference than expected
|
||||
|
||||
Ensure dtype matches:
|
||||
|
||||
```python
|
||||
# Model weights must be float16/bfloat16; cast any float32 feature tensors in the inputs
|
||||
model = model.to(torch.float16)
|
||||
inputs = tokenizer(..., return_tensors="pt").to("cuda")
|
||||
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()}
|
||||
```
|
||||
|
||||
### Issue: Different outputs vs standard attention
|
||||
|
||||
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
model_standard = AutoModelForCausalLM.from_pretrained(
    "model-name", torch_dtype=torch.float16, device_map="cuda"
)
|
||||
model_flash = AutoModelForCausalLM.from_pretrained(
    "model-name",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="cuda"
)
|
||||
|
||||
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
out_standard = model_standard(**inputs).logits
|
||||
out_flash = model_flash(**inputs).logits
|
||||
|
||||
diff = (out_standard - out_flash).abs().max()
|
||||
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
|
||||
```
|
||||
|
||||
### Issue: ImportError during model loading
|
||||
|
||||
Install flash-attn:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or disable Flash Attention:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="eager", # Standard PyTorch
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use float16/bfloat16** with Flash Attention (not float32)
|
||||
2. **Set device_map="auto"** for automatic memory management
|
||||
3. **Use bfloat16 for long context** (better numerical stability)
|
||||
4. **Enable gradient checkpointing** for training large models
|
||||
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
|
||||
|
||||
**Example with all best practices**:
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for training
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing for memory
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Training with optimizations
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=2,
|
||||
bf16=True, # Match model dtype
|
||||
optim="adamw_torch_fused",
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
skills/mlops/gguf/SKILL.md (new file, 430 lines)
@@ -0,0 +1,430 @@
|
||||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
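The sizes above follow from simple arithmetic: file size ≈ parameter count × effective bits per weight / 8, plus a small metadata overhead. A quick estimator (a sketch; the effective bit counts are approximate and already include per-block scale overhead):

```python
EFFECTIVE_BITS = {"Q2_K": 2.5, "Q4_K_M": 4.5, "Q5_K_M": 5.5, "Q6_K": 6.0, "Q8_0": 8.0, "F16": 16.0}

def gguf_size_gb(n_params_billions, quant):
    # billions of params * bits per weight / 8 bits per byte = GB
    return n_params_billions * EFFECTIVE_BITS[quant] / 8

for q in ("Q4_K_M", "Q5_K_M", "Q8_0", "F16"):
    print(f"7B @ {q}: ~{gguf_size_gb(7, q):.1f} GB")
```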
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
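Streaming works through the same endpoint (a sketch; the server emits OpenAI-style chunks):

```python
stream = client.chat.completions.create(
    model="local-model",
    messages=[{"role": "user", "content": "Tell me a story."}],
    max_tokens=256,
    stream=True
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk may carry no content
        print(delta, end="", flush=True)
```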
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing (see the sketch below)
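A configuration sketch applying practices 3-6 (values are illustrative; `psutil` is an extra dependency used here only to count physical cores):

```python
import psutil
from llama_cpp import Llama

physical_cores = psutil.cpu_count(logical=False) or 4

llm = Llama(
    model_path="model-q4_k_m.gguf",
    n_gpu_layers=35,           # offload as many layers as VRAM allows
    n_ctx=4096,                # start at 4096, increase if needed
    n_threads=physical_cores,  # physical cores, not logical
    n_batch=512                # larger batch speeds up prompt processing
)
```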
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Use mmap for faster loading
|
||||
./llama-cli -m model.gguf --mmap
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
skills/mlops/gguf/references/advanced-usage.md (new file, 504 lines)
@@ -0,0 +1,504 @@
|
||||
# GGUF Advanced Usage Guide
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
### Draft Model Approach
|
||||
|
||||
```bash
|
||||
# Use smaller model as draft for faster generation
|
||||
./llama-speculative \
|
||||
-m large-model-q4_k_m.gguf \
|
||||
-md draft-model-q4_k_m.gguf \
|
||||
-p "Write a story about AI" \
|
||||
-n 500 \
|
||||
--draft 8 # Draft tokens before verification
|
||||
```
|
||||
|
||||
### Self-Speculative Decoding
|
||||
|
||||
```bash
|
||||
# Use same model with different context for speculation
|
||||
./llama-cli -m model-q4_k_m.gguf \
|
||||
--lookup-cache-static lookup.bin \
|
||||
--lookup-cache-dynamic lookup-dynamic.bin \
|
||||
-p "Hello world"
|
||||
```
|
||||
|
||||
## Batched Inference
|
||||
|
||||
### Process Multiple Prompts
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
n_batch=512 # Larger batch for parallel processing
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"What is Python?",
|
||||
"Explain machine learning.",
|
||||
"Describe neural networks."
|
||||
]
|
||||
|
||||
# Process in batch (each prompt gets separate context)
|
||||
for prompt in prompts:
|
||||
output = llm(prompt, max_tokens=100)
|
||||
print(f"Q: {prompt}")
|
||||
print(f"A: {output['choices'][0]['text']}\n")
|
||||
```
|
||||
|
||||
### Server Batching
|
||||
|
||||
```bash
|
||||
# Start server with batching
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
  --parallel 4 \
  --cont-batching   # 4 concurrent slots with continuous batching
|
||||
```
|
||||
|
||||
## Custom Model Conversion
|
||||
|
||||
### Convert with Vocabulary Modifications
|
||||
|
||||
```python
|
||||
# custom_convert.py
|
||||
import sys
|
||||
sys.path.insert(0, './llama.cpp')
|
||||
|
||||
from convert_hf_to_gguf import main  # entry point; its argument handling varies across llama.cpp versions
|
||||
|
||||
# Custom conversion with modified vocab
|
||||
def convert_with_custom_vocab(model_path, output_path):
|
||||
# Load and modify tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Add special tokens if needed
|
||||
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
# Then run standard conversion
|
||||
main([model_path, "--outfile", output_path])
|
||||
```
|
||||
|
||||
### Convert Specific Architecture
|
||||
|
||||
```bash
|
||||
# For Mistral-style models
|
||||
python convert_hf_to_gguf.py ./mistral-model \
|
||||
--outfile mistral-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Qwen models
|
||||
python convert_hf_to_gguf.py ./qwen-model \
|
||||
--outfile qwen-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Phi models
|
||||
python convert_hf_to_gguf.py ./phi-model \
|
||||
--outfile phi-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
## Advanced Quantization
|
||||
|
||||
### Mixed Quantization
|
||||
|
||||
```bash
|
||||
# Quantize different layer types differently
|
||||
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
|
||||
--allow-requantize \
|
||||
--leave-output-tensor
|
||||
```
|
||||
|
||||
### Quantization with Token Embeddings
|
||||
|
||||
```bash
|
||||
# Keep embeddings at higher precision
|
||||
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
|
||||
--token-embedding-type f16
|
||||
```
|
||||
|
||||
### IQ Quantization (Importance-aware)
|
||||
|
||||
```bash
|
||||
# Ultra-low bit quantization with importance
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
|
||||
|
||||
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Memory Mapping
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use memory mapping for large models
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
use_mmap=True, # Memory map the model
|
||||
use_mlock=False, # Don't lock in RAM
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Partial GPU Offload
|
||||
|
||||
```python
|
||||
# Calculate layers to offload based on VRAM
|
||||
import subprocess
|
||||
|
||||
def get_free_vram_gb():
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
|
||||
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
|
||||
free_vram = get_free_vram_gb()
|
||||
layers_to_offload = int(free_vram / 0.5)
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
|
||||
)
|
||||
```
|
||||
|
||||
### KV Cache Optimization
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Optimize KV cache for long contexts
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=8192, # Large context
|
||||
n_gpu_layers=35,
|
||||
    type_k=8,  # GGML_TYPE_Q8_0 for the K cache
    type_v=8,  # GGML_TYPE_Q8_0 for the V cache
    # Or use type 2 (GGML_TYPE_Q4_0) for more compression
|
||||
)
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
### Context Shifting
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Handle long conversations with context shifting
|
||||
conversation = []
|
||||
max_history = 10
|
||||
|
||||
def chat(user_message):
    global conversation  # reassigned below, so it must be declared global
    conversation.append({"role": "user", "content": user_message})

    # Keep only the most recent turns so the prompt fits in the context window
    if len(conversation) > max_history * 2:
        conversation = conversation[-max_history * 2:]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=conversation,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
conversation.append({"role": "assistant", "content": assistant_message})
|
||||
return assistant_message
|
||||
```
|
||||
|
||||
### Save and Load State
|
||||
|
||||
```bash
|
||||
# Save state to file
|
||||
./llama-cli -m model.gguf \
|
||||
-p "Once upon a time" \
|
||||
--save-session session.bin \
|
||||
-n 100
|
||||
|
||||
# Load and continue
|
||||
./llama-cli -m model.gguf \
|
||||
--load-session session.bin \
|
||||
-p " and they lived" \
|
||||
-n 100
|
||||
```
|
||||
|
||||
## Grammar Constrained Generation
|
||||
|
||||
### JSON Output
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
|
||||
# Define JSON grammar
|
||||
json_grammar = LlamaGrammar.from_string('''
|
||||
root ::= object
|
||||
object ::= "{" ws pair ("," ws pair)* "}" ws
|
||||
pair ::= string ":" ws value
|
||||
value ::= string | number | object | array | "true" | "false" | "null"
|
||||
array ::= "[" ws value ("," ws value)* "]" ws
|
||||
string ::= "\\"" [^"\\\\]* "\\""
|
||||
number ::= [0-9]+
|
||||
ws ::= [ \\t\\n]*
|
||||
''')
|
||||
|
||||
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
output = llm(
|
||||
"Output a JSON object with name and age:",
|
||||
grammar=json_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Custom Grammar
|
||||
|
||||
```python
|
||||
# Grammar for specific format
|
||||
answer_grammar = LlamaGrammar.from_string('''
|
||||
root ::= "Answer: " letter "\\n" "Explanation: " explanation
|
||||
letter ::= [A-D]
|
||||
explanation ::= [a-zA-Z0-9 .,!?]+
|
||||
''')
|
||||
|
||||
output = llm(
|
||||
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
|
||||
grammar=answer_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Integration
|
||||
|
||||
### Load LoRA Adapter
|
||||
|
||||
```bash
|
||||
# Apply LoRA at runtime
|
||||
./llama-cli -m base-model-q4_k_m.gguf \
|
||||
--lora lora-adapter.gguf \
|
||||
--lora-scale 1.0 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Multiple LoRA Adapters
|
||||
|
||||
```bash
|
||||
# Stack multiple adapters
|
||||
./llama-cli -m base-model.gguf \
|
||||
--lora adapter1.gguf --lora-scale 0.5 \
|
||||
--lora adapter2.gguf --lora-scale 0.5 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python LoRA Usage
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="base-model-q4_k_m.gguf",
|
||||
lora_path="lora-adapter.gguf",
|
||||
lora_scale=1.0,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
### Extract Embeddings
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
embedding=True, # Enable embedding mode
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(embeddings)}")
|
||||
```
|
||||
|
||||
### Batch Embeddings
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"Machine learning is fascinating.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Python is a programming language."
|
||||
]
|
||||
|
||||
embeddings = [llm.embed(text) for text in texts]
|
||||
|
||||
# Calculate similarity
|
||||
import numpy as np
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
sim = cosine_similarity(embeddings[0], embeddings[1])
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Script
|
||||
|
||||
```python
|
||||
import time
|
||||
from llama_cpp import Llama
|
||||
|
||||
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=35,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm(prompt, max_tokens=10)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.time()
|
||||
output = llm(prompt, max_tokens=n_tokens)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
tokens_per_sec = n_tokens / avg_time
|
||||
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Avg time: {avg_time:.2f}s")
|
||||
print(f"Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
return tokens_per_sec
|
||||
|
||||
# Compare quantizations
|
||||
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
|
||||
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
|
||||
```
|
||||
|
||||
### Optimal Configuration Finder
|
||||
|
||||
```python
|
||||
def find_optimal_config(model_path, target_vram_gb=8):
|
||||
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
|
||||
    from llama_cpp import Llama
    import gc
    import time
|
||||
|
||||
best_config = None
|
||||
best_speed = 0
|
||||
|
||||
for n_gpu_layers in range(0, 50, 5):
|
||||
for n_batch in [128, 256, 512, 1024]:
|
||||
try:
|
||||
gc.collect()
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=n_gpu_layers,
|
||||
n_batch=n_batch,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Quick benchmark
|
||||
start = time.time()
|
||||
llm("Hello", max_tokens=50)
|
||||
speed = 50 / (time.time() - start)
|
||||
|
||||
if speed > best_speed:
|
||||
best_speed = speed
|
||||
best_config = {
|
||||
"n_gpu_layers": n_gpu_layers,
|
||||
"n_batch": n_batch,
|
||||
"speed": speed
|
||||
}
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
|
||||
break
|
||||
|
||||
return best_config
|
||||
```
|
||||
|
||||
## Multi-GPU Setup
|
||||
|
||||
### Distribute Across GPUs
|
||||
|
||||
```bash
|
||||
# Split model across multiple GPUs
|
||||
./llama-cli -m large-model.gguf \
|
||||
--tensor-split 0.5,0.5 \
|
||||
-ngl 60 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python Multi-GPU
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="large-model-q4_k_m.gguf",
|
||||
n_gpu_layers=60,
|
||||
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Builds
|
||||
|
||||
### Build with All Optimizations
|
||||
|
||||
```bash
|
||||
# Clean build with all CPU optimizations
|
||||
make clean
|
||||
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
|
||||
|
||||
# With CUDA and cuBLAS
|
||||
make clean
|
||||
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
|
||||
|
||||
# With specific CUDA architecture
|
||||
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
|
||||
```
|
||||
|
||||
### CMake Build
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build . --config Release -j
|
||||
```
|
||||
skills/mlops/gguf/references/troubleshooting.md (new file, 442 lines)
@@ -0,0 +1,442 @@
|
||||
# GGUF Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Build Fails
|
||||
|
||||
**Error**: `make: *** No targets specified and no makefile found`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Ensure you're in llama.cpp directory
|
||||
cd llama.cpp
|
||||
make
|
||||
```
|
||||
|
||||
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install CUDA toolkit
|
||||
# Ubuntu
|
||||
sudo apt install nvidia-cuda-toolkit
|
||||
|
||||
# Or set CUDA path
|
||||
export CUDA_PATH=/usr/local/cuda
|
||||
export PATH=$CUDA_PATH/bin:$PATH
|
||||
make GGML_CUDA=1
|
||||
```
|
||||
|
||||
### Python Bindings Issues
|
||||
|
||||
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install cmake scikit-build-core
|
||||
|
||||
# For CUDA support
|
||||
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
|
||||
# For Metal (macOS)
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Add CUDA libraries to path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Or reinstall with correct CUDA version
|
||||
pip uninstall llama-cpp-python
|
||||
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
|
||||
```
|
||||
|
||||
## Conversion Issues
|
||||
|
||||
### Model Not Supported
|
||||
|
||||
**Error**: `KeyError: 'model.embed_tokens.weight'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check model architecture
|
||||
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
|
||||
|
||||
# Use appropriate conversion script
|
||||
# For most models:
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# For older models, check if legacy script needed
|
||||
```
|
||||
|
||||
### Vocabulary Mismatch
|
||||
|
||||
**Error**: `RuntimeError: Vocabulary size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure tokenizer matches model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./model")
|
||||
model = AutoModelForCausalLM.from_pretrained("./model")
|
||||
|
||||
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
||||
print(f"Model vocab size: {model.config.vocab_size}")
|
||||
|
||||
# If mismatch, resize embeddings before conversion
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.save_pretrained("./model-fixed")
|
||||
```
|
||||
|
||||
### Out of Memory During Conversion
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during conversion
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Use CPU for conversion
|
||||
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# Or use low memory mode
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
|
||||
```
|
||||
|
||||
## Quantization Issues
|
||||
|
||||
### Wrong Output File Size
|
||||
|
||||
**Problem**: Quantized file is larger than expected
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify quantization type
|
||||
./llama-cli -m model.gguf --verbose
|
||||
|
||||
# Expected sizes for 7B model:
|
||||
# Q4_K_M: ~4.1 GB
|
||||
# Q5_K_M: ~4.8 GB
|
||||
# Q8_0: ~7.2 GB
|
||||
# F16: ~13.5 GB
|
||||
```
|
||||
|
||||
### Quantization Crashes
|
||||
|
||||
**Error**: `Segmentation fault` during quantization
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Increase stack size
|
||||
ulimit -s unlimited
|
||||
|
||||
# Or use less threads
|
||||
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Poor Quality After Quantization
|
||||
|
||||
**Problem**: Model outputs gibberish after quantization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use importance matrix**:
|
||||
```bash
|
||||
# Generate imatrix with good calibration data
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f wiki_sample.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix
|
||||
|
||||
# Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
2. **Try higher precision**:
|
||||
```bash
|
||||
# Use Q5_K_M or Q6_K instead of Q4
|
||||
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
|
||||
```
|
||||
|
||||
3. **Check original model**:
|
||||
```bash
|
||||
# Test FP16 version first
|
||||
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Generation is slower than expected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable GPU offload**:
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
```
|
||||
|
||||
2. **Optimize batch size**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_batch=512, # Increase for faster prompt processing
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use appropriate threads**:
|
||||
```bash
|
||||
# Match physical cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
```
|
||||
|
||||
4. **Enable Flash Attention** (if supported):
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
|
||||
```
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `CUDA out of memory` or system freeze
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce GPU layers**:
|
||||
```python
|
||||
# Start low and increase
|
||||
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
|
||||
```
|
||||
|
||||
2. **Use smaller quantization**:
|
||||
```bash
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
3. **Reduce context length**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_ctx=2048, # Reduce from 4096
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
4. **Quantize KV cache**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
type_k=2, # Q4_0 for K cache
|
||||
type_v=2, # Q4_0 for V cache
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Garbage Output
|
||||
|
||||
**Problem**: Model outputs random characters or nonsense
|
||||
|
||||
**Diagnose**:
|
||||
```python
|
||||
# Check model loading
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
|
||||
# Test with simple prompt
|
||||
output = llm("1+1=", max_tokens=5, temperature=0)
|
||||
print(output)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check model integrity**:
|
||||
```bash
|
||||
# Verify GGUF file
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -50
|
||||
```
|
||||
|
||||
2. **Use correct chat format**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
chat_format="llama-3" # Match your model: chatml, mistral, etc.
|
||||
)
|
||||
```
|
||||
|
||||
3. **Check temperature**:
|
||||
```python
|
||||
# Use lower temperature for deterministic output
|
||||
output = llm("Hello", max_tokens=50, temperature=0.1)
|
||||
```
|
||||
|
||||
### Token Issues
|
||||
|
||||
**Error**: `RuntimeError: unknown token` or encoding errors
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure the prompt is valid UTF-8 (relevant when reading from files or raw bytes)
raw = open("prompt.txt", "rb").read()
prompt = raw.decode("utf-8", errors="replace")  # replace undecodable bytes instead of crashing
output = llm(prompt, max_tokens=50)
|
||||
```
|
||||
|
||||
## Server Issues
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Error**: `Connection refused` when accessing server
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Bind to all interfaces
|
||||
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8080
|
||||
```
|
||||
|
||||
### Server Crashes Under Load
|
||||
|
||||
**Problem**: Server crashes with multiple concurrent requests
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit parallelism**:
|
||||
```bash
|
||||
./llama-server -m model.gguf \
|
||||
--parallel 2 \
|
||||
-c 4096 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
2. **Add request timeout**:
|
||||
```bash
|
||||
./llama-server -m model.gguf --timeout 300
|
||||
```
|
||||
|
||||
3. **Monitor memory**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi # For GPU
|
||||
watch -n 1 free -h # For RAM
|
||||
```
|
||||
|
||||
### API Compatibility Issues
|
||||
|
||||
**Problem**: OpenAI client not working with server
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Use correct base URL format
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1", # Include /v1
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
# Use correct model name
|
||||
response = client.chat.completions.create(
|
||||
model="local", # Or the actual model name
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Apple Silicon Issues
|
||||
|
||||
### Metal Not Working
|
||||
|
||||
**Problem**: Metal acceleration not enabled
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify Metal support
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
|
||||
```
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Rebuild with Metal
|
||||
make clean
|
||||
make GGML_METAL=1
|
||||
|
||||
# Python bindings
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
|
||||
```
|
||||
|
||||
### Incorrect Memory Usage on M1/M2
|
||||
|
||||
**Problem**: Model uses too much unified memory
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Offload all layers for Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Enable Verbose Output
|
||||
|
||||
```bash
|
||||
# CLI verbose mode
|
||||
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
|
||||
|
||||
# Python verbose
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
```
|
||||
|
||||
### Check Model Metadata
|
||||
|
||||
```bash
|
||||
# View GGUF metadata
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -100
|
||||
```
|
||||
|
||||
### Validate GGUF File
|
||||
|
||||
```python
|
||||
import struct
|
||||
|
||||
def validate_gguf(filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'GGUF':
|
||||
print(f"Invalid magic: {magic}")
|
||||
return False
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
print(f"GGUF version: {version}")
|
||||
|
||||
tensor_count = struct.unpack('<Q', f.read(8))[0]
|
||||
metadata_count = struct.unpack('<Q', f.read(8))[0]
|
||||
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
|
||||
|
||||
return True
|
||||
|
||||
validate_gguf("model.gguf")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
|
||||
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
|
||||
3. **Reddit**: r/LocalLLaMA
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- llama.cpp version/commit hash
|
||||
- Build command used
|
||||
- Model name and quantization
|
||||
- Full error message/stack trace
|
||||
- Hardware: CPU/GPU model, RAM, VRAM
|
||||
- OS version
|
||||
- Minimal reproduction steps
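
A quick way to collect several of these from Python (a minimal sketch; it assumes the `llama-cpp-python` bindings are installed, and for a raw llama.cpp build `./llama-cli --version` reports the build info instead):

```python
# Sketch: gather environment details for a bug report
import platform

import llama_cpp  # assumes the llama-cpp-python bindings

print("llama-cpp-python:", llama_cpp.__version__)
print("OS:", platform.platform())
print("Machine:", platform.machine())
```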

97	skills/mlops/grpo-rl-training/README.md	Normal file
@@ -0,0 +1,97 @@

# GRPO/RL Training Skill

**Expert-level guidance for Group Relative Policy Optimization with TRL**

## 📁 Skill Structure

```
grpo-rl-training/
├── SKILL.md                        # Main skill documentation (READ THIS FIRST)
├── README.md                       # This file
├── templates/
│   └── basic_grpo_training.py      # Production-ready training template
└── examples/
    └── reward_functions_library.py # 20+ reward function examples
```

## 🚀 Quick Start

1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
4. **Modify for your use case** - Adapt dataset, rewards, and config

## 💡 What's Inside

### SKILL.md (Main Documentation)
- Core GRPO concepts and algorithm fundamentals
- Complete implementation workflow (dataset → rewards → training → deployment)
- 10+ reward function examples with code
- Hyperparameter tuning guide
- Training insights (loss behavior, metrics, debugging)
- Troubleshooting guide
- Production best practices

### Templates
- **basic_grpo_training.py**: Minimal, production-ready training script
  - Uses Qwen 2.5 1.5B Instruct
  - 3 reward functions (format + correctness)
  - LoRA for efficient training
  - Fully documented and ready to run

### Examples
- **reward_functions_library.py**: 20+ battle-tested reward functions
  - Correctness rewards (exact match, fuzzy match, numeric, code execution)
  - Format rewards (XML, JSON, strict/soft)
  - Length rewards (ideal length, min/max)
  - Style rewards (reasoning quality, citations, repetition penalty)
  - Combined rewards (multi-objective optimization)
  - Preset collections for common tasks

## 📖 Usage for Agents

When this skill is loaded in your agent's context:

1. **Always read SKILL.md first** before implementing
2. **Start simple** - Use a length-based reward to validate setup
3. **Build incrementally** - Add one reward function at a time
4. **Reference examples** - Copy patterns from reward_functions_library.py
5. **Monitor training** - Watch reward metrics (not loss!)

## 🎯 Common Use Cases

| Task Type | Recommended Rewards | Template |
|-----------|---------------------|----------|
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |

## ⚠️ Critical Reminders

- **Loss goes UP during training** - This is normal (it's KL divergence)
- **Use 3-5 reward functions** - Single rewards often fail
- **Test rewards before training** - Debug each function independently
- **Monitor reward_std** - Should stay > 0.1 to avoid mode collapse (see the sketch after this list)
- **Start with num_generations=4-8** - Scale up if GPU allows
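
A minimal sketch of the `reward_std` check (illustrative numbers; in practice the value comes from your training logs):

```python
import statistics

# Rewards for one prompt's group of completions (illustrative numbers)
group_rewards = [1.5, 0.0, 2.0, 0.5]

reward_std = statistics.stdev(group_rewards)
if reward_std < 0.1:
    print(f"Warning: possible mode collapse (reward_std={reward_std:.3f})")
```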

## 🔗 External Resources

- [TRL Documentation](https://huggingface.co/docs/trl)
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)

## 📝 Version

**v1.0.0** - Initial release (January 2025)

## 👨‍💻 Maintained By

Orchestra Research

For questions or improvements, see https://orchestra.com

---

**License:** MIT
**Last Updated:** January 2025

575	skills/mlops/grpo-rl-training/SKILL.md	Normal file
@@ -0,0 +1,575 @@

---
name: grpo-rl-training
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
metadata:
  hermes:
    tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]

---

# GRPO/RL Training with TRL

Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.

## When to Use This Skill

Use GRPO training when you need to:
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
- **Align models to domain-specific behaviors** without labeled preference data
- **Optimize for multiple objectives** simultaneously (format + correctness + style)

**Do NOT use GRPO for:**
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- When you already have high-quality preference pairs (use DPO/PPO instead)

---

## Core Concepts

### 1. GRPO Algorithm Fundamentals

**Key Mechanism:**
- Generates **multiple completions** for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group

**Critical Difference from PPO:**
- No separate reward model needed
- More sample-efficient (learns from within-group comparisons)
- Simpler to implement and debug

**Mathematical Intuition:**
```
For each prompt p:
  1. Generate N completions: {c₁, c₂, ..., cₙ}
  2. Compute rewards: {r₁, r₂, ..., rₙ}
  3. Learn to increase probability of high-reward completions
     relative to low-reward ones in the same group
```
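
A minimal numeric sketch of that within-group comparison (the exact normalization is implementation-specific; this shows the common mean/std form):

```python
import torch

# Toy rewards for N = 4 completions of one prompt (illustrative numbers)
rewards = torch.tensor([2.0, 0.5, 0.0, 1.5])

# Group-relative advantages: compare each completion to its own group
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages)  # Positive above the group mean, negative below
```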

### 2. Reward Function Design Philosophy

**Golden Rules:**
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
2. **Scale rewards appropriately** - Higher weight = stronger signal
3. **Use incremental rewards** - Partial credit for partial compliance
4. **Test rewards independently** - Debug each reward function in isolation

**Reward Function Types:**

| Type | Use Case | Example Weight |
|------|----------|----------------|
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
| **Format** | Strict structure enforcement | 0.5-1.0 |
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |

---

## Implementation Workflow

### Step 1: Dataset Preparation

**Critical Requirements:**
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns

**Example Structure:**
```python
from datasets import load_dataset, Dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""

def prepare_dataset(raw_data):
    """
    Transform raw data into GRPO-compatible format.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content (system + user messages)
    - 'answer': str (ground truth, optional but recommended)
    """
    return raw_data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_answer(x['raw_answer'])
    })
```

**Pro Tips:**
- Use one-shot or few-shot examples in the system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out); a quick check follows below
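
A quick pre-training sanity check (a sketch; `dataset` and `tokenizer` are assumed to exist from the surrounding setup):

```python
# Validate one sample's structure and prompt length before training
sample = dataset[0]
assert isinstance(sample['prompt'], list)
assert all('role' in m and 'content' in m for m in sample['prompt'])

prompt_text = tokenizer.apply_chat_template(sample['prompt'], tokenize=False)
n_tokens = len(tokenizer(prompt_text)['input_ids'])
print(f"Prompt tokens: {n_tokens}")  # Should fit within max_prompt_length
```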

### Step 2: Reward Function Implementation

**Template Structure:**
```python
def reward_function_name(
    prompts,      # List[List[Dict]]: Original prompts
    completions,  # List[List[Dict]]: Model generations
    answer=None,  # Optional: Ground truth from dataset
    **kwargs      # Additional dataset columns
) -> list[float]:
    """
    Evaluate completions and return rewards.

    Returns: List of floats (one per completion)
    """
    # Extract completion text
    responses = [comp[0]['content'] for comp in completions]

    # Compute rewards
    rewards = []
    for response in responses:
        score = compute_score(response)
        rewards.append(score)

    return rewards
```

**Example 1: Correctness Reward (Math/Coding)**
```python
def correctness_reward(prompts, completions, answer, **kwargs):
    """Reward correct answers with high score."""
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_final_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0
            for ans, gt in zip(extracted, answer)]
```

**Example 2: Format Reward (Structured Output)**
```python
import re

def format_reward(completions, **kwargs):
    """Reward XML-like structured format."""
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
            for r in responses]
```

**Example 3: Incremental Format Reward (Partial Credit)**
```python
def incremental_format_reward(completions, **kwargs):
    """Award partial credit for format compliance."""
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.25
        if '</reasoning>' in r:
            score += 0.25
        if '<answer>' in r:
            score += 0.25
        if '</answer>' in r:
            score += 0.25
        # Penalize extra text after closing tag
        if r.count('</answer>') == 1:
            extra_text = r.split('</answer>')[-1].strip()
            score -= len(extra_text) * 0.001
        rewards.append(score)

    return rewards
```

**Critical Insight:**
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
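
If you prefer explicit weights over baking them into each function, recent TRL releases expose a `reward_weights` field on `GRPOConfig` (an assumption to verify against your installed version):

```python
from trl import GRPOConfig

# One weight per reward function, in the same order as reward_funcs
args = GRPOConfig(
    output_dir="outputs/grpo-model",
    reward_weights=[0.5, 0.5, 2.0],  # e.g. incremental format, format, correctness
)
```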

### Step 3: Training Configuration

**Memory-Optimized Config (Small GPU)**
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="outputs/grpo-model",

    # Learning rate
    learning_rate=5e-6,  # Lower = more stable
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',

    # Batch settings
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # Effective batch = 4

    # GRPO-specific
    num_generations=8,  # Group size: 8-16 recommended
    max_prompt_length=256,
    max_completion_length=512,

    # Training duration
    num_train_epochs=1,
    max_steps=None,  # Or set fixed steps (e.g., 500)

    # Optimization
    bf16=True,           # Faster on A100/H100
    optim="adamw_8bit",  # Memory-efficient optimizer
    max_grad_norm=0.1,

    # Logging
    logging_steps=1,
    save_steps=100,
    report_to="wandb",  # Or "none" for no logging
)
```

**High-Performance Config (Large GPU)**
```python
training_args = GRPOConfig(
    output_dir="outputs/grpo-model",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_generations=16,  # Larger groups = better signal
    max_prompt_length=512,
    max_completion_length=1024,
    num_train_epochs=1,
    bf16=True,
    use_vllm=True,  # Fast generation with vLLM
    logging_steps=10,
)
```

**Critical Hyperparameters:**

| Parameter | Impact | Tuning Advice |
|-----------|--------|---------------|
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
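
To make the table concrete, here is the arithmetic behind the memory-optimized config above (illustrative numbers only):

```python
# Effective batch size from the memory-optimized config above
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_generations = 8

effective_batch = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch)  # 4, matching the "Effective batch = 4" comment

# Roughly, each prompt contributes a group of num_generations completions,
# so more generations means more generation work and memory per update
print(effective_batch * num_generations)  # ~32 completions scored per update
```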

### Step 4: Model Setup and Training

**Standard Setup (Transformers)**
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer

# Load model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # 2-3x faster
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Optional: LoRA for parameter-efficient training
peft_config = LoraConfig(
    r=16,           # Rank (higher = more capacity)
    lora_alpha=32,  # Scaling factor (typically 2*r)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        incremental_format_reward,
        format_reward,
        correctness_reward,
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # Remove for full fine-tuning
)

# Train
trainer.train()

# Save
trainer.save_model("final_model")
```

**Unsloth Setup (2-3x Faster)**
```python
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/gemma-3-1b-it",
    max_seq_length=1024,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=32,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=32,
    use_gradient_checkpointing="unsloth",
)

# Rest is identical to standard setup
trainer = GRPOTrainer(model=model, ...)
trainer.train()
```

---

## Critical Training Insights

### 1. Loss Behavior (EXPECTED PATTERN)
- **Loss starts near 0 and INCREASES during training**
- This is CORRECT - loss measures KL divergence from the initial policy
- The model is learning (diverging from original behavior to optimize rewards)
- Monitor reward metrics instead of loss for progress

### 2. Reward Tracking
Key metrics to watch:
- `reward`: Average across all completions
- `reward_std`: Diversity within groups (should remain > 0)
- `kl`: KL divergence from reference (should grow moderately)

**Healthy Training Pattern:**
```
Step   Reward   Reward_Std   KL
100    0.5      0.3          0.02
200    0.8      0.25         0.05
300    1.2      0.2          0.08   ← Good progression
400    1.5      0.15         0.12
```

**Warning Signs:**
- Reward std → 0: model collapsing to a single response
- KL exploding (> 0.5): diverging too much, reduce the LR
- Reward stuck: reward functions too harsh, or a model capacity issue
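
A sketch for pulling these metrics out of a running job (`trainer` is the `GRPOTrainer` from Step 4; the exact key names can vary across TRL versions):

```python
# Inspect the last few logged steps for reward and divergence metrics
history = trainer.state.log_history
for entry in history[-3:]:
    print(entry.get('step'), entry.get('reward'),
          entry.get('reward_std'), entry.get('kl'))
```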

### 3. Common Pitfalls and Solutions

| Problem | Symptom | Solution |
|---------|---------|----------|
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
| **No learning** | Flat rewards | Check reward function logic, increase LR |
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |

---

## Advanced Patterns

### 1. Multi-Stage Training
For complex tasks, train in stages:

```python
# Stage 1: Format compliance (epochs=1)
trainer_stage1 = GRPOTrainer(
    model=model,
    reward_funcs=[incremental_format_reward, format_reward],
    ...
)
trainer_stage1.train()

# Stage 2: Correctness (epochs=1)
trainer_stage2 = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, correctness_reward],
    ...
)
trainer_stage2.train()
```

### 2. Adaptive Reward Scaling
```python
class AdaptiveReward:
    def __init__(self, base_reward_func, initial_weight=1.0):
        self.func = base_reward_func
        self.weight = initial_weight

    def __call__(self, *args, **kwargs):
        rewards = self.func(*args, **kwargs)
        return [r * self.weight for r in rewards]

    def adjust_weight(self, success_rate):
        """Increase weight if model struggling, decrease if succeeding."""
        if success_rate < 0.3:
            self.weight *= 1.2
        elif success_rate > 0.8:
            self.weight *= 0.9
```
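
A usage sketch for the class above (`format_reward` is the function from Step 2; the success rate would come from your own eval pass):

```python
# Wrap a base reward, then adapt its weight between evaluation rounds
adaptive_format = AdaptiveReward(format_reward, initial_weight=1.0)

success_rate = 0.25  # illustrative: fraction of completions passing format
adaptive_format.adjust_weight(success_rate)  # struggling -> weight *= 1.2
print(adaptive_format.weight)                # 1.2
```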

### 3. Custom Dataset Integration
```python
def load_custom_knowledge_base(csv_path):
    """Example: School communication platform docs."""
    import pandas as pd
    df = pd.read_csv(csv_path)

    dataset = Dataset.from_pandas(df).map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': x['expert_answer']
    })
    return dataset
```

---

## Deployment and Inference

### Save and Merge LoRA
```python
# Merge LoRA adapters into base model
if hasattr(trainer.model, 'merge_and_unload'):
    merged_model = trainer.model.merge_and_unload()
    merged_model.save_pretrained("production_model")
    tokenizer.save_pretrained("production_model")
```

### Inference Example
```python
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="production_model",
    tokenizer=tokenizer
)

result = generator(
    [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': "What is 15 + 27?"}
    ],
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9
)
print(result[0]['generated_text'])
```

---

## Best Practices Checklist

**Before Training:**
- [ ] Validate dataset format (prompts as List[Dict])
- [ ] Test reward functions on sample data
- [ ] Calculate expected max_prompt_length from data
- [ ] Choose appropriate num_generations based on GPU memory
- [ ] Set up logging (wandb recommended)

**During Training:**
- [ ] Monitor reward progression (should increase)
- [ ] Check reward_std (should stay > 0.1)
- [ ] Watch for OOM errors (reduce batch size if needed)
- [ ] Sample generations every 50-100 steps
- [ ] Validate format compliance on holdout set

**After Training:**
- [ ] Merge LoRA weights if using PEFT
- [ ] Test on diverse prompts
- [ ] Compare to baseline model
- [ ] Document reward weights and hyperparameters
- [ ] Save reproducibility config

---

## Troubleshooting Guide

### Debugging Workflow
1. **Isolate reward functions** - Test each independently
2. **Check data distribution** - Ensure diversity in prompts
3. **Reduce complexity** - Start with single reward, add gradually
4. **Monitor generations** - Print samples every N steps
5. **Validate extraction logic** - Ensure answer parsing works

### Quick Fixes
```python
# Debug reward function
def debug_reward(completions, **kwargs):
    responses = [comp[0]['content'] for comp in completions]
    for i, r in enumerate(responses[:2]):  # Print first 2
        print(f"Response {i}: {r[:200]}...")
    return [1.0] * len(responses)  # Dummy rewards

# Test without training
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
trainer.generate_completions(dataset[:1])  # Generate without updating
```

---

## References and Resources

**Official Documentation:**
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
- Unsloth Docs: https://docs.unsloth.ai/

**Example Repositories:**
- Open R1 Implementation: https://github.com/huggingface/open-r1
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples

**Recommended Reading:**
- Progressive Disclosure Pattern for agent instructions
- Reward shaping in RL (Ng et al.)
- LoRA paper (Hu et al., 2021)

---

## Usage Instructions for Agents

When this skill is loaded:

1. **Read this entire file** before implementing GRPO training
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
3. **Use the templates** in `templates/` directory as starting points
4. **Reference examples** in `examples/` for task-specific implementations
5. **Follow the workflow** sequentially (don't skip steps)
6. **Debug incrementally** - add one reward function at a time

**Critical Reminders:**
- Always use multiple reward functions (3-5 is optimal)
- Monitor reward metrics, not loss
- Test reward functions before training
- Start small (num_generations=4), scale up gradually
- Save checkpoints frequently (every 100 steps)

This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.

228	skills/mlops/grpo-rl-training/templates/basic_grpo_training.py	Normal file
@@ -0,0 +1,228 @@

"""
Basic GRPO Training Template
=============================

A minimal, production-ready template for GRPO training with TRL.
Adapt this for your specific task by modifying:
1. Dataset loading (get_dataset function)
2. Reward functions (reward_*_func)
3. System prompt (SYSTEM_PROMPT)
4. Hyperparameters (GRPOConfig)
"""

import torch
import re
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig

# ==================== CONFIGURATION ====================

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = "outputs/grpo-model"
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 512

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""

# ==================== DATASET ====================

def get_dataset(split="train"):
    """
    Load and prepare your dataset.

    Returns: Dataset with columns:
    - 'prompt': List[Dict] with role/content
    - 'answer': str (ground truth, optional)
    """
    # Example: GSM8K math dataset
    data = load_dataset('openai/gsm8k', 'main')[split]

    def process_example(x):
        # Extract ground truth answer
        answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None

        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']}
            ],
            'answer': answer
        }

    return data.map(process_example)

# ==================== HELPER FUNCTIONS ====================

def extract_xml_tag(text: str, tag: str) -> str:
    """Extract content between XML tags."""
    pattern = f'<{tag}>(.*?)</{tag}>'
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

def extract_answer(text: str) -> str:
    """Extract the final answer from structured output."""
    return extract_xml_tag(text, 'answer')

# ==================== REWARD FUNCTIONS ====================

def correctness_reward_func(prompts, completions, answer, **kwargs):
    """
    Reward correct answers.
    Weight: 2.0 (highest priority)
    """
    responses = [comp[0]['content'] for comp in completions]
    extracted = [extract_answer(r) for r in responses]
    return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]

def format_reward_func(completions, **kwargs):
    """
    Reward proper XML format.
    Weight: 0.5
    """
    pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
    responses = [comp[0]['content'] for comp in completions]
    return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

def incremental_format_reward_func(completions, **kwargs):
    """
    Incremental reward for partial format compliance.
    Weight: up to 0.5
    """
    responses = [comp[0]['content'] for comp in completions]
    rewards = []

    for r in responses:
        score = 0.0
        if '<reasoning>' in r:
            score += 0.125
        if '</reasoning>' in r:
            score += 0.125
        if '<answer>' in r:
            score += 0.125
        if '</answer>' in r:
            score += 0.125

        # Penalize extra content after closing tag
        if '</answer>' in r:
            extra = r.split('</answer>')[-1].strip()
            score -= len(extra) * 0.001

        rewards.append(score)

    return rewards

# ==================== MODEL SETUP ====================

def setup_model_and_tokenizer():
    """Load model and tokenizer with optimizations."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def get_peft_config():
    """LoRA configuration for parameter-efficient training."""
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

# ==================== TRAINING ====================

def main():
    """Main training function."""

    # Load data
    print("Loading dataset...")
    dataset = get_dataset()
    print(f"Dataset size: {len(dataset)}")

    # Setup model
    print("Loading model...")
    model, tokenizer = setup_model_and_tokenizer()

    # Training configuration
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        run_name="grpo-training",

        # Learning rate
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',

        # Batch settings
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,

        # GRPO specific
        num_generations=8,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_completion_length=MAX_COMPLETION_LENGTH,

        # Training duration
        num_train_epochs=1,

        # Optimization
        bf16=True,
        optim="adamw_8bit",
        max_grad_norm=0.1,

        # Logging
        logging_steps=1,
        save_steps=100,
        report_to="wandb",  # Change to "none" to disable logging
    )

    # Initialize trainer
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            incremental_format_reward_func,
            format_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        peft_config=get_peft_config(),
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save final model
    print(f"Saving model to {OUTPUT_DIR}/final")
    trainer.save_model(f"{OUTPUT_DIR}/final")

    print("Training complete!")

if __name__ == "__main__":
    main()

575	skills/mlops/guidance/SKILL.md	Normal file
@@ -0,0 +1,575 @@

---
name: guidance
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [guidance, transformers]
metadata:
  hermes:
    tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]

---

# Guidance: Constrained LLM Generation

## When to Use This Skill

Use Guidance when you need to:
- **Control LLM output syntax** with regex or grammars
- **Guarantee valid JSON/XML/code** generation
- **Reduce latency** vs traditional prompting approaches
- **Enforce structured formats** (dates, emails, IDs, etc.)
- **Build multi-step workflows** with Pythonic control flow
- **Prevent invalid outputs** through grammatical constraints

**GitHub Stars**: 18,000+ | **From**: Microsoft Research

## Installation

```bash
# Base installation
pip install guidance

# With specific backends
pip install guidance[transformers]  # Hugging Face models
pip install guidance[llama_cpp]     # llama.cpp models
```

## Quick Start

### Basic Example: Structured Generation

```python
from guidance import models, gen

# Load model (supports OpenAI, Transformers, llama.cpp)
lm = models.OpenAI("gpt-4")

# Generate with constraints
result = lm + "The capital of France is " + gen("capital", max_tokens=5)

print(result["capital"])  # "Paris"
```

### With Anthropic Claude

```python
from guidance import models, gen, system, user, assistant

# Configure Claude
lm = models.Anthropic("claude-sonnet-4-5-20250929")

# Use context managers for chat format
with system():
    lm += "You are a helpful assistant."

with user():
    lm += "What is the capital of France?"

with assistant():
    lm += gen(max_tokens=20)
```

## Core Concepts

### 1. Context Managers

Guidance uses Pythonic context managers for chat-style interactions.

```python
from guidance import models, system, user, assistant, gen

lm = models.Anthropic("claude-sonnet-4-5-20250929")

# System message
with system():
    lm += "You are a JSON generation expert."

# User message
with user():
    lm += "Generate a person object with name and age."

# Assistant response
with assistant():
    lm += gen("response", max_tokens=100)

print(lm["response"])
```

**Benefits:**
- Natural chat flow
- Clear role separation
- Easy to read and maintain

### 2. Constrained Generation

Guidance ensures outputs match specified patterns using regex or grammars.

#### Regex Constraints

```python
from guidance import models, gen

lm = models.Anthropic("claude-sonnet-4-5-20250929")

# Constrain to valid email format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")

# Constrain to date format (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")

# Constrain to phone number
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")

print(lm["email"])  # Guaranteed valid email
print(lm["date"])   # Guaranteed YYYY-MM-DD format
```

**How it works:**
- Regex converted to grammar at token level
- Invalid tokens filtered during generation
- Model can only produce matching outputs

#### Selection Constraints

```python
from guidance import models, gen, select

lm = models.Anthropic("claude-sonnet-4-5-20250929")

# Constrain to specific choices
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")

# Multiple-choice selection
lm += "Best answer: " + select(
    ["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
    name="answer"
)

print(lm["sentiment"])  # One of: positive, negative, neutral
print(lm["answer"])     # One of: A, B, C, or D
```

### 3. Token Healing

Guidance automatically "heals" token boundaries between prompt and generation.

**Problem:** Tokenization creates unnatural boundaries.

```python
# Without token healing
prompt = "The capital of France is "
# Last token: " is "
# First generated token might be " Par" (with leading space)
# Result: "The capital of France is  Paris" (double space!)
```

**Solution:** Guidance backs up one token and regenerates.

```python
from guidance import models, gen

lm = models.Anthropic("claude-sonnet-4-5-20250929")

# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Result: "The capital of France is Paris" (correct spacing)
```

**Benefits:**
- Natural text boundaries
- No awkward spacing issues
- Better model performance (sees natural token sequences)

### 4. Grammar-Based Generation

Define complex structures using context-free grammars.

```python
from guidance import models, gen

lm = models.Anthropic("claude-sonnet-4-5-20250929")

# JSON grammar (simplified)
json_grammar = """
{
    "name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
    "age": <gen age regex="[0-9]+" max_tokens=3>,
    "email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
}
"""

# Generate valid JSON
lm += gen("person", grammar=json_grammar)

print(lm["person"])  # Guaranteed valid JSON structure
```

**Use cases:**
- Complex structured outputs
- Nested data structures
- Programming language syntax
- Domain-specific languages

### 5. Guidance Functions

Create reusable generation patterns with the `@guidance` decorator.

```python
from guidance import guidance, gen, models

@guidance
def generate_person(lm):
    """Generate a person with name and age."""
    lm += "Name: " + gen("name", max_tokens=20, stop="\n")
    lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
    return lm

# Use the function
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_person(lm)

print(lm["name"])
print(lm["age"])
```

**Stateful Functions:**

```python
from guidance import guidance, gen, select

@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
    """ReAct agent with tool use."""
    lm += f"Question: {question}\n\n"

    for i in range(max_rounds):
        # Thought
        lm += f"Thought {i+1}: " + gen("thought", stop="\n")

        # Action
        lm += "\nAction: " + select(list(tools.keys()), name="action")

        # Execute tool
        tool_result = tools[lm["action"]]()
        lm += f"\nObservation: {tool_result}\n\n"

        # Check if done
        lm += "Done? " + select(["Yes", "No"], name="done")
        if lm["done"] == "Yes":
            break

    # Final answer
    lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
    return lm
```

## Backend Configuration

### Anthropic Claude

```python
from guidance import models

lm = models.Anthropic(
    model="claude-sonnet-4-5-20250929",
    api_key="your-api-key"  # Or set ANTHROPIC_API_KEY env var
)
```

### OpenAI

```python
lm = models.OpenAI(
    model="gpt-4o-mini",
    api_key="your-api-key"  # Or set OPENAI_API_KEY env var
)
```

### Local Models (Transformers)

```python
from guidance.models import Transformers

lm = Transformers(
    "microsoft/Phi-4-mini-instruct",
    device="cuda"  # Or "cpu"
)
```

### Local Models (llama.cpp)

```python
from guidance.models import LlamaCpp

lm = LlamaCpp(
    model_path="/path/to/model.gguf",
    n_ctx=4096,
    n_gpu_layers=35
)
```

## Common Patterns

### Pattern 1: JSON Generation

```python
from guidance import models, gen, system, user, assistant

lm = models.Anthropic("claude-sonnet-4-5-20250929")

with system():
    lm += "You generate valid JSON."

with user():
    lm += "Generate a user profile with name, age, and email."

with assistant():
    lm += """{
  "name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
  "age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
  "email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
}"""

print(lm)  # Valid JSON guaranteed
```

### Pattern 2: Classification

```python
from guidance import models, gen, select

lm = models.Anthropic("claude-sonnet-4-5-20250929")

text = "This product is amazing! I love it."

lm += f"Text: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"

print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
```

### Pattern 3: Multi-Step Reasoning

```python
from guidance import models, gen, guidance

@guidance
def chain_of_thought(lm, question):
    """Generate answer with step-by-step reasoning."""
    lm += f"Question: {question}\n\n"

    # Generate multiple reasoning steps
    for i in range(3):
        lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"

    # Final answer
    lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)

    return lm

lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "What is 15% of 200?")

print(lm["answer"])
```

### Pattern 4: ReAct Agent

```python
from guidance import models, gen, select, guidance

@guidance(stateless=False)
def react_agent(lm, question):
    """ReAct agent with tool use."""
    tools = {
        "calculator": lambda expr: eval(expr),
        "search": lambda query: f"Search results for: {query}",
    }

    lm += f"Question: {question}\n\n"

    for round in range(5):
        # Thought
        lm += "Thought: " + gen("thought", stop="\n") + "\n"

        # Action selection
        lm += "Action: " + select(["calculator", "search", "answer"], name="action")

        if lm["action"] == "answer":
            lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
            break

        # Action input
        lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"

        # Execute tool
        if lm["action"] in tools:
            result = tools[lm["action"]](lm["action_input"])
            lm += f"Observation: {result}\n\n"

    return lm

lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is 25 * 4 + 10?")
print(lm["answer"])
```

### Pattern 5: Data Extraction

```python
from guidance import models, gen, guidance

@guidance
def extract_entities(lm, text):
    """Extract structured entities from text."""
    lm += f"Text: {text}\n\n"

    # Extract person
    lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"

    # Extract organization
    lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"

    # Extract date
    lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"

    # Extract location
    lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"

    return lm

text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."

lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)

print(f"Person: {lm['person']}")
print(f"Organization: {lm['organization']}")
print(f"Date: {lm['date']}")
print(f"Location: {lm['location']}")
```

## Best Practices

### 1. Use Regex for Format Validation

```python
# ✅ Good: Regex ensures valid format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")

# ❌ Bad: Free generation may produce invalid emails
lm += "Email: " + gen("email", max_tokens=50)
```

### 2. Use select() for Fixed Categories

```python
# ✅ Good: Guaranteed valid category
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")

# ❌ Bad: May generate typos or invalid values
lm += "Status: " + gen("status", max_tokens=20)
```

### 3. Leverage Token Healing

```python
# Token healing is enabled by default
# No special action needed - just concatenate naturally
lm += "The capital is " + gen("capital")  # Automatic healing
```

### 4. Use stop Sequences

```python
# ✅ Good: Stop at newline for single-line outputs
lm += "Name: " + gen("name", stop="\n")

# ❌ Bad: May generate multiple lines
lm += "Name: " + gen("name", max_tokens=50)
```

### 5. Create Reusable Functions

```python
# ✅ Good: Reusable pattern
@guidance
def generate_person(lm):
    lm += "Name: " + gen("name", stop="\n")
    lm += "\nAge: " + gen("age", regex=r"[0-9]+")
    return lm

# Use multiple times
lm = generate_person(lm)
lm += "\n\n"
lm = generate_person(lm)
```

### 6. Balance Constraints

```python
# ✅ Good: Reasonable constraints
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)

# ❌ Too strict: May fail or be very slow
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
```

## Comparison to Alternatives

| Feature | Guidance | Instructor | Outlines | LMQL |
|---------|----------|------------|----------|------|
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
| Learning Curve | Low | Low | Medium | High |

**When to choose Guidance:**
- Need regex/grammar constraints
- Want token healing
- Building complex workflows with control flow
- Using local models (Transformers, llama.cpp)
- Prefer Pythonic syntax

**When to choose alternatives:**
- Instructor: Need Pydantic validation with automatic retrying
- Outlines: Need JSON schema validation
- LMQL: Prefer declarative query syntax

## Performance Characteristics

**Latency Reduction:**
- 30-50% faster than traditional prompting for constrained outputs
- Token healing reduces unnecessary regeneration
- Grammar constraints prevent invalid token generation

**Memory Usage:**
- Minimal overhead vs unconstrained generation
- Grammar compilation cached after first use
- Efficient token filtering at inference time

**Token Efficiency:**
- Prevents wasted tokens on invalid outputs
- No need for retry loops
- Direct path to valid outputs

## Resources

- **Documentation**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Discord**: Community support available

## See Also

- `references/constraints.md` - Comprehensive regex and grammar patterns
- `references/backends.md` - Backend-specific configuration
- `references/examples.md` - Production-ready examples

554	skills/mlops/guidance/references/backends.md	Normal file
@@ -0,0 +1,554 @@
|
||||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Guidance with different LLM backends.
|
||||
|
||||
## Table of Contents
|
||||
- API-Based Models (Anthropic, OpenAI)
|
||||
- Local Models (Transformers, llama.cpp)
|
||||
- Backend Comparison
|
||||
- Performance Tuning
|
||||
- Advanced Configuration
|
||||
|
||||
## API-Based Models
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
# Reads ANTHROPIC_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# Claude 3.5 Sonnet (Latest, recommended)
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Claude 3.7 Sonnet (Fast, cost-effective)
|
||||
lm = models.Anthropic("claude-sonnet-3.7-20250219")
|
||||
|
||||
# Claude 3 Opus (Most capable)
|
||||
lm = models.Anthropic("claude-3-opus-20240229")
|
||||
|
||||
# Claude 3.5 Haiku (Fastest, cheapest)
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key",
|
||||
max_tokens=4096, # Max tokens to generate
|
||||
temperature=0.7, # Sampling temperature (0-1)
|
||||
top_p=0.9, # Nucleus sampling
|
||||
timeout=30, # Request timeout (seconds)
|
||||
max_retries=3 # Retry failed requests
|
||||
)
|
||||
```
|
||||
|
||||
#### With Context Managers
|
||||
|
||||
```python
|
||||
from guidance import models, system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=50)
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
# Reads OPENAI_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (Latest, multimodal)
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (Fast, cost-effective)
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
lm = models.OpenAI("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo (Cheapest)
|
||||
lm = models.OpenAI("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key",
|
||||
max_tokens=2048,
|
||||
temperature=0.7,
|
||||
top_p=1.0,
|
||||
frequency_penalty=0.0,
|
||||
presence_penalty=0.0,
|
||||
timeout=30
|
||||
)
|
||||
```
|
||||
|
||||
#### Chat Format
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# OpenAI uses chat format
|
||||
lm += [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
|
||||
# Generate response
|
||||
lm += gen(max_tokens=50)
|
||||
```
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.AzureOpenAI(
|
||||
model="gpt-4o",
|
||||
azure_endpoint="https://your-resource.openai.azure.com/",
|
||||
api_key="your-azure-api-key",
|
||||
api_version="2024-02-15-preview",
|
||||
deployment_name="your-deployment-name"
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models
|
||||
|
||||
### Transformers (Hugging Face)
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load model from Hugging Face
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16", # Use FP16 (faster, less memory)
|
||||
load_in_8bit=True, # 8-bit quantization
|
||||
max_memory={0: "20GB"}, # GPU memory limit
|
||||
offload_folder="./offload" # Offload to disk if needed
|
||||
)
|
||||
```
|
||||
|
||||
#### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3 (Meta)
|
||||
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
lm = Transformers("google/gemma-2-9b-it")
|
||||
```
|
||||
|
||||
#### Generation Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Configure generation
|
||||
from guidance import gen
|
||||
|
||||
result = lm + gen(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Load GGUF model
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU acceleration
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=8 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1 # Offload all layers
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map the model file
|
||||
use_mlock=False, # Don't lock model in RAM (set True to pin it)
|
||||
seed=42, # Random seed
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
#### Quantized Models
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
|
||||
|
||||
# Q8_0 (8-bit, high quality)
|
||||
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, highest quality)
|
||||
lm = LlamaCpp("/path/to/model.F16.gguf")
|
||||
```
|
||||
|
||||
#### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|
||||
|---------|-----------|--------|--------------|-----------|
|
||||
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Cost | $$$ | $$$ | Free | Free |
|
||||
| Latency | Low | Low | Medium | Low |
|
||||
| Setup Difficulty | Easy | Easy | Medium | Medium |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Anthropic Claude:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $3-15 per 1M input tokens
|
||||
- **Best for**: Production systems, high-quality outputs
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-400ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $0.15-30 per 1M input tokens
|
||||
- **Best for**: Cost-sensitive production, gpt-4o-mini
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (local inference)
|
||||
- **Throughput**: GPU-dependent (10-100 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Privacy-sensitive, high-volume, experimentation
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (local inference)
|
||||
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Edge deployment, Apple Silicon, CPU inference
|
||||
|
||||
### Memory Requirements
|
||||
|
||||
**Transformers (FP16):**
|
||||
- 7B model: ~14GB GPU VRAM
|
||||
- 13B model: ~26GB GPU VRAM
|
||||
- 70B model: ~140GB GPU VRAM (multi-GPU)
|
||||
|
||||
**llama.cpp (Q4_K_M):**
|
||||
- 7B model: ~4.5GB RAM
|
||||
- 13B model: ~8GB RAM
|
||||
- 70B model: ~40GB RAM
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use quantized models (Q4_K_M) for lower memory
|
||||
- Use GPU offloading for faster inference
|
||||
- Use CPU inference for smaller models (<7B)
|
||||
|
||||
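The figures above are essentially parameter count times bytes per parameter. A rough sizing sketch of that arithmetic (the bits-per-parameter averages are assumptions, and KV cache and activations add more on top):

```python
def estimate_weight_memory_gb(params_billions: float, bits_per_param: float) -> float:
    """Approximate weight memory: parameters x bytes per parameter.
    KV cache and activations add more on top (not modeled here)."""
    return params_billions * 1e9 * (bits_per_param / 8) / 1e9

print(f"7B  FP16   : {estimate_weight_memory_gb(7, 16):.1f} GB")   # ~14 GB
print(f"7B  Q4_K_M : {estimate_weight_memory_gb(7, 4.8):.1f} GB")  # ~4 GB (assuming ~4.8 bits/param)
print(f"70B Q4_K_M : {estimate_weight_memory_gb(70, 4.8):.1f} GB") # ~42 GB
```
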
## Performance Tuning
|
||||
|
||||
### API Models (Anthropic, OpenAI)
|
||||
|
||||
#### Reduce Latency
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use lower max_tokens (faster response)
|
||||
lm += gen(max_tokens=100) # Instead of 1000
|
||||
|
||||
# Use streaming (perceived latency reduction)
|
||||
for chunk in lm.stream(gen(max_tokens=500)):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
#### Reduce Cost
|
||||
|
||||
```python
|
||||
# Use cheaper models
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
|
||||
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
|
||||
|
||||
# Reduce context size
|
||||
# - Keep prompts concise
|
||||
# - Avoid large few-shot examples
|
||||
# - Use max_tokens limits
|
||||
```
|
||||
|
||||
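To see how much prompt trimming matters, multiply tokens by the per-million rate. A back-of-envelope sketch (the prices here are assumptions; verify against current pricing):

```python
# Assumed input prices in $ per 1M tokens (check current pricing)
PRICE_PER_MTOK_INPUT = {"gpt-4o": 2.50, "gpt-4o-mini": 0.15}

def estimate_input_cost_usd(model: str, input_tokens: int, calls: int = 1) -> float:
    """Input-side cost: tokens x ($ / 1M tokens) x number of calls."""
    return PRICE_PER_MTOK_INPUT[model] * (input_tokens / 1e6) * calls

# Trimming a 2,000-token prompt to 800 tokens over 100,000 calls:
print(f"before: ${estimate_input_cost_usd('gpt-4o-mini', 2_000, 100_000):,.2f}")  # $30.00
print(f"after : ${estimate_input_cost_usd('gpt-4o-mini', 800, 100_000):,.2f}")    # $12.00
```
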
### Local Models (Transformers, llama.cpp)
|
||||
|
||||
#### Optimize GPU Usage
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Use FP16 for 2x speedup
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use 8-bit quantization (~4x smaller than FP32, ~2x smaller than FP16)
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
load_in_8bit=True
|
||||
)
|
||||
|
||||
# Use flash attention (requires flash-attn package)
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
use_flash_attention_2=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Optimize llama.cpp
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Maximize GPU layers
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
|
||||
# Optimize batch size
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
n_gpu_layers=-1
|
||||
)
|
||||
|
||||
# Use Metal (Apple Silicon)
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Batch Processing
|
||||
|
||||
```python
|
||||
from guidance import gen
from guidance.models import Transformers

# Process multiple requests efficiently
|
||||
requests = [
|
||||
"What is 2+2?",
|
||||
"What is the capital of France?",
|
||||
"What is photosynthesis?"
|
||||
]
|
||||
|
||||
# Bad: Sequential processing
|
||||
for req in requests:
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm += req + gen(max_tokens=50)
|
||||
|
||||
# Good: Reuse loaded model
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
for req in requests:
|
||||
lm += req + gen(max_tokens=50)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom Model Configurations
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load custom model
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Guidance
|
||||
lm = Transformers(model=model, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# API keys
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
|
||||
```
|
||||
|
||||
### Debugging
|
||||
|
||||
```python
|
||||
import logging
import torch

from guidance import models
from guidance.models import Transformers

# Enable verbose logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Check backend info
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
print(f"Model: {lm.model_name}")
|
||||
print(f"Backend: {lm.backend}")
|
||||
|
||||
# Check GPU usage (Transformers)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
|
||||
print(f"Device: {lm.device}")
|
||||
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Anthropic Docs**: https://docs.anthropic.com
|
||||
- **OpenAI Docs**: https://platform.openai.com/docs
|
||||
- **Hugging Face Models**: https://huggingface.co/models
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **GGUF Models**: https://huggingface.co/models?library=gguf
|
||||
674  skills/mlops/guidance/references/constraints.md  Normal file
@@ -0,0 +1,674 @@
|
||||
# Comprehensive Constraint Patterns
|
||||
|
||||
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
|
||||
|
||||
## Table of Contents
|
||||
- Regex Constraints
|
||||
- Grammar-Based Generation
|
||||
- Token Healing
|
||||
- Selection Constraints
|
||||
- Complex Patterns
|
||||
- Performance Optimization
|
||||
|
||||
## Regex Constraints
|
||||
|
||||
### Basic Patterns
|
||||
|
||||
#### Numeric Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Integer (positive)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Integer (with negatives)
|
||||
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
|
||||
|
||||
# Float (positive)
|
||||
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
|
||||
|
||||
# Float (with negatives and optional decimals)
|
||||
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
|
||||
|
||||
# Percentage (0-100)
|
||||
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
|
||||
|
||||
# Range (1-5 stars)
|
||||
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
|
||||
```
|
||||
|
||||
#### Text Constraints
|
||||
|
||||
```python
|
||||
# Alphabetic only
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
|
||||
|
||||
# Alphabetic with spaces
|
||||
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
|
||||
|
||||
# Alphanumeric
|
||||
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
|
||||
|
||||
# Capitalized words
|
||||
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
|
||||
|
||||
# Lowercase only
|
||||
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
|
||||
|
||||
# Specific length
|
||||
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
|
||||
```
|
||||
|
||||
#### Date and Time Constraints
|
||||
|
||||
```python
|
||||
# Date (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Date (MM/DD/YYYY)
|
||||
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
|
||||
|
||||
# Time (HH:MM)
|
||||
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
|
||||
|
||||
# Time (HH:MM:SS)
|
||||
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
|
||||
|
||||
# ISO 8601 datetime
|
||||
lm += "Timestamp: " + gen(
|
||||
"timestamp",
|
||||
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
|
||||
)
|
||||
|
||||
# Year (YYYY)
|
||||
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
|
||||
|
||||
# Month name
|
||||
lm += "Month: " + gen(
|
||||
"month",
|
||||
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Contact Information
|
||||
|
||||
```python
|
||||
# Email
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
)
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
# Phone (international format)
|
||||
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
|
||||
|
||||
# ZIP code (US)
|
||||
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
|
||||
|
||||
# Postal code (Canada)
|
||||
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
|
||||
|
||||
# URL
|
||||
lm += "URL: " + gen(
|
||||
"url",
|
||||
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Patterns
|
||||
|
||||
#### JSON Field Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# String field with quotes
|
||||
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
|
||||
|
||||
# Numeric field (no quotes)
|
||||
lm += '"age": ' + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Boolean field
|
||||
lm += '"active": ' + gen("active", regex=r"(true|false)")
|
||||
|
||||
# Null field
|
||||
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
|
||||
|
||||
# Array of strings
|
||||
lm += '"tags": [' + gen(
|
||||
"tags",
|
||||
regex=r'"[a-z]+"(, "[a-z]+")*'
|
||||
) + ']'
|
||||
|
||||
# Complete JSON object
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+") + """,
|
||||
"email": """ + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + """
|
||||
}"""
|
||||
```
|
||||
|
||||
#### Code Patterns
|
||||
|
||||
```python
|
||||
# Python variable name
|
||||
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Python function name
|
||||
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Hex color code
|
||||
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
|
||||
|
||||
# UUID
|
||||
lm += "UUID: " + gen(
|
||||
"uuid",
|
||||
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
)
|
||||
|
||||
# Git commit hash (short)
|
||||
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
|
||||
|
||||
# Semantic version
|
||||
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
|
||||
|
||||
# IP address (IPv4)
|
||||
lm += "IP: " + gen(
|
||||
"ip",
|
||||
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Domain-Specific Patterns
|
||||
|
||||
```python
|
||||
# Credit card number
|
||||
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
|
||||
|
||||
# Social Security Number (US)
|
||||
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
|
||||
|
||||
# ISBN-13
|
||||
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
|
||||
|
||||
# License plate (US)
|
||||
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
|
||||
|
||||
# Currency amount
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
|
||||
|
||||
# Percentage with decimal
|
||||
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
|
||||
```
|
||||
|
||||
## Grammar-Based Generation
|
||||
|
||||
### JSON Grammar
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_object(lm):
|
||||
"""Generate valid JSON object."""
|
||||
lm += "{\n"
|
||||
|
||||
# Name field (required)
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
|
||||
# Age field (required)
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
|
||||
# Email field (required)
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + ",\n"
|
||||
|
||||
# Active field (required, boolean)
|
||||
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_object(lm)
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def nested_json(lm):
|
||||
"""Generate nested JSON structure."""
|
||||
lm += "{\n"
|
||||
|
||||
# User object
|
||||
lm += ' "user": {\n'
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Address object
|
||||
lm += ' "address": {\n'
|
||||
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
|
||||
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
|
||||
lm += " }\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
```
|
||||
|
||||
### Array Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def json_array(lm, count=3):
|
||||
"""Generate JSON array with fixed count."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
```
|
||||
|
||||
### XML Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def xml_document(lm):
|
||||
"""Generate valid XML document."""
|
||||
lm += '<?xml version="1.0"?>\n'
|
||||
lm += "<person>\n"
|
||||
|
||||
# Name element
|
||||
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
|
||||
|
||||
# Age element
|
||||
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
|
||||
|
||||
# Email element
|
||||
lm += " <email>" + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
) + "</email>\n"
|
||||
|
||||
lm += "</person>"
|
||||
return lm
|
||||
```
|
||||
|
||||
### CSV Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def csv_row(lm):
|
||||
"""Generate CSV row."""
|
||||
lm += gen("name", regex=r"[A-Za-z ]+") + ","
|
||||
lm += gen("age", regex=r"[0-9]+") + ","
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def csv_document(lm, rows=5):
|
||||
"""Generate complete CSV."""
|
||||
# Header
|
||||
lm += "Name,Age,Email\n"
|
||||
|
||||
# Rows
|
||||
for i in range(rows):
|
||||
lm = csv_row(lm)
|
||||
if i < rows - 1:
|
||||
lm += "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Token Healing
|
||||
|
||||
### How Token Healing Works
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Example without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Tokenization: ["The", " capital", " of", " France", " is", " "]
|
||||
# Model sees last token: " "
|
||||
# First generated token might include leading space: " Paris"
|
||||
# Result: "The capital of France is Paris" (double space)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up and regenerates the last token.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
# Process:
|
||||
# 1. Back up to token before " is "
|
||||
# 2. Regenerate " is" + "capital" together
|
||||
# 3. Result: "The capital of France is Paris" (correct)
|
||||
```
|
||||
|
||||
### Token Healing Examples
|
||||
|
||||
#### Natural Continuations
|
||||
|
||||
```python
|
||||
# Before token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Might generate: "The function name is get User" (space before User)
|
||||
|
||||
# With token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Generates: "The function name is getUser" (correct camelCase)
|
||||
```
|
||||
|
||||
#### Code Generation
|
||||
|
||||
```python
|
||||
# Function name completion
|
||||
lm += "def calculate_" + gen("rest", stop="(")
|
||||
# Token healing ensures smooth connection: "calculate_total"
|
||||
|
||||
# Variable name completion
|
||||
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
|
||||
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
|
||||
```
|
||||
|
||||
#### Domain-Specific Terms
|
||||
|
||||
```python
|
||||
# Medical terms
|
||||
lm += "The patient has hyper" + gen("condition")
|
||||
# Token healing helps: "hypertension" (not "hyper tension")
|
||||
|
||||
# Technical terms
|
||||
lm += "Using micro" + gen("tech")
|
||||
# Token healing helps: "microservices" (not "micro services")
|
||||
```
|
||||
|
||||
### Disabling Token Healing
|
||||
|
||||
```python
|
||||
# Disable token healing if needed (rare)
|
||||
lm += gen("text", token_healing=False)
|
||||
```
|
||||
|
||||
## Selection Constraints
|
||||
|
||||
### Basic Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Simple selection
|
||||
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
|
||||
|
||||
# Boolean selection
|
||||
lm += "Approved: " + select(["Yes", "No"], name="approved")
|
||||
|
||||
# Multiple choice
|
||||
lm += "Answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
```
|
||||
|
||||
### Conditional Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen, guidance
|
||||
|
||||
@guidance
|
||||
def conditional_fields(lm):
|
||||
"""Generate fields conditionally based on type."""
|
||||
lm += "Type: " + select(["person", "company"], name="type")
|
||||
|
||||
if lm["type"] == "person":
|
||||
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
else:
|
||||
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
|
||||
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Repeated Selection
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def multiple_selections(lm):
|
||||
"""Select multiple items."""
|
||||
lm += "Select 3 colors:\n"
|
||||
|
||||
colors = ["red", "blue", "green", "yellow", "purple"]
|
||||
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Complex Patterns
|
||||
|
||||
### Pattern 1: Structured Forms
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def user_form(lm):
|
||||
"""Generate structured user form."""
|
||||
lm += "=== User Registration ===\n\n"
|
||||
|
||||
# Name (alphabetic only)
|
||||
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Age (numeric)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
|
||||
# Email (validated format)
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
|
||||
|
||||
# Account type (selection)
|
||||
lm += "Account Type: " + select(
|
||||
["Standard", "Premium", "Enterprise"],
|
||||
name="account_type"
|
||||
) + "\n"
|
||||
|
||||
# Active status (boolean)
|
||||
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 2: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entities with constraints."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Person name (alphabetic)
|
||||
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization (alphanumeric with spaces)
|
||||
lm += "Organization: " + gen(
|
||||
"organization",
|
||||
regex=r"[A-Za-z0-9 ]+",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Date (YYYY-MM-DD format)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
|
||||
|
||||
# Location (alphabetic with spaces)
|
||||
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Amount (currency)
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 3: Code Generation
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm):
|
||||
"""Generate Python function with constraints."""
|
||||
# Function name (valid Python identifier)
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
|
||||
# Parameter name
|
||||
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
|
||||
|
||||
# Function body (constrained to valid Python)
|
||||
lm += " return " + gen("return_value", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 4: Hierarchical Data
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def org_chart(lm):
|
||||
"""Generate organizational chart."""
|
||||
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
|
||||
|
||||
# CEO
|
||||
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
|
||||
|
||||
# Departments
|
||||
for dept in ["Engineering", "Sales", "Marketing"]:
|
||||
lm += f"\n{dept} Department:\n"
|
||||
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
|
||||
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Best Practices
|
||||
|
||||
#### 1. Use Specific Patterns
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific pattern
|
||||
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
|
||||
|
||||
# ❌ Bad: Overly broad pattern
|
||||
lm += gen("age", regex=r"[0-9]+") # Slower
|
||||
```
|
||||
|
||||
#### 2. Limit Max Tokens
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable limit
|
||||
lm += gen("name", max_tokens=30)
|
||||
|
||||
# ❌ Bad: No limit
|
||||
lm += gen("name") # May generate forever
|
||||
```
|
||||
|
||||
#### 3. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline
|
||||
lm += gen("line", stop="\n")
|
||||
|
||||
# ❌ Bad: Rely on max_tokens
|
||||
lm += gen("line", max_tokens=100)
|
||||
```
|
||||
|
||||
#### 4. Cache Compiled Grammars
|
||||
|
||||
```python
|
||||
# Grammars are cached automatically after first use
|
||||
# No manual caching needed
|
||||
@guidance
|
||||
def reusable_pattern(lm):
|
||||
"""This grammar is compiled once and cached."""
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
# First call: compiles grammar
|
||||
lm = reusable_pattern(lm)
|
||||
|
||||
# Subsequent calls: uses cached grammar (fast)
|
||||
lm = reusable_pattern(lm)
|
||||
```
|
||||
|
||||
#### 5. Avoid Overlapping Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear constraints
|
||||
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
|
||||
# ❌ Bad: Conflicting constraints
|
||||
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
|
||||
```
|
||||
|
||||
### Performance Benchmarks
|
||||
|
||||
**Regex vs Free Generation:**
|
||||
- Simple regex (digits): ~1.2x slower than free gen
|
||||
- Complex regex (email): ~1.5x slower than free gen
|
||||
- Grammar-based: ~2x slower than free gen
|
||||
|
||||
**But:**
|
||||
- 100% valid outputs (vs ~70% with free gen + validation)
|
||||
- No retry loops needed
|
||||
- Overall faster end-to-end for structured outputs
|
||||
|
||||
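The end-to-end claim follows from retry arithmetic: at ~70% per-call validity, a validate-and-retry loop needs 1/0.7 ≈ 1.43 expected calls, so a constraint costing 1.2-2x per call comes out ahead while guaranteeing validity. A minimal sketch of that comparison (the rates are the illustrative figures above, not measurements):

```python
def expected_calls(p_valid: float) -> float:
    """Expected LLM calls for a validate-and-retry loop with per-call success rate p."""
    return 1.0 / p_valid

free_retry_cost = expected_calls(0.70) * 1.0   # ~1.43x baseline, and still not guaranteed valid
constrained_cost = 1.0 * 1.5                   # one call at ~1.5x per-call cost, always valid
print(f"free gen + retry: {free_retry_cost:.2f}x, constrained: {constrained_cost:.2f}x")
```
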
**Optimization Tips:**
|
||||
- Use regex for critical fields only
|
||||
- Use `select()` for small fixed sets (fastest)
|
||||
- Use `stop` sequences when possible (faster than max_tokens)
|
||||
- Cache compiled grammars by reusing functions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance
|
||||
767  skills/mlops/guidance/references/examples.md  Normal file
@@ -0,0 +1,767 @@
|
||||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Guidance for structured generation, agents, and workflows.
|
||||
|
||||
## Table of Contents
|
||||
- JSON Generation
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Agent Systems
|
||||
- Multi-Step Workflows
|
||||
- Code Generation
|
||||
- Production Tips
|
||||
|
||||
## JSON Generation
|
||||
|
||||
### Basic JSON
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def generate_user(lm):
|
||||
"""Generate valid user JSON."""
|
||||
lm += "{\n"
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Use it
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += "Generate a user profile:\n"
|
||||
lm = generate_user(lm)
|
||||
|
||||
print(lm)
|
||||
# Output: Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_order(lm):
|
||||
"""Generate nested order JSON."""
|
||||
lm += "{\n"
|
||||
|
||||
# Customer info
|
||||
lm += ' "customer": {\n'
|
||||
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"customer_email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Order details
|
||||
lm += ' "order": {\n'
|
||||
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
|
||||
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
|
||||
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Status
|
||||
lm += ' "status": ' + gen(
|
||||
"status",
|
||||
regex=r'"(pending|processing|shipped|delivered)"'
|
||||
) + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_order(lm)
|
||||
```
|
||||
|
||||
### JSON Array
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_user_list(lm, count=3):
|
||||
"""Generate JSON array of users."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_user_list(lm, count=5)
|
||||
```
|
||||
|
||||
### Dynamic JSON Schema
|
||||
|
||||
```python
|
||||
import json
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_from_schema(lm, schema):
|
||||
"""Generate JSON matching a schema."""
|
||||
lm += "{\n"
|
||||
|
||||
fields = list(schema["properties"].items())
|
||||
for i, (field_name, field_schema) in enumerate(fields):
|
||||
lm += f' "{field_name}": '
|
||||
|
||||
# Handle different types
|
||||
if field_schema["type"] == "string":
|
||||
if "pattern" in field_schema:
|
||||
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
|
||||
else:
|
||||
lm += gen(field_name, regex=r'"[^"]+"')
|
||||
elif field_schema["type"] == "number":
|
||||
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
|
||||
elif field_schema["type"] == "integer":
|
||||
lm += gen(field_name, regex=r"[0-9]+")
|
||||
elif field_schema["type"] == "boolean":
|
||||
lm += gen(field_name, regex=r"(true|false)")
|
||||
|
||||
if i < len(fields) - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Define schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"score": {"type": "number"},
|
||||
"active": {"type": "boolean"}
|
||||
}
|
||||
}
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_from_schema(lm, schema)
|
||||
```
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Extract from Text
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance, system, user, assistant
|
||||
|
||||
@guidance
|
||||
def extract_person_info(lm, text):
|
||||
"""Extract structured info from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
with assistant():
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You extract structured information from text."
|
||||
|
||||
with user():
|
||||
lm = extract_person_info(lm, text)
|
||||
|
||||
print(f"Name: {lm['name']}")
|
||||
print(f"Age: {lm['age']}")
|
||||
print(f"Occupation: {lm['occupation']}")
|
||||
print(f"Email: {lm['email']}")
|
||||
```
|
||||
|
||||
### Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entity types."""
|
||||
lm += f"Analyze: {text}\n\n"
|
||||
|
||||
# Person entities
|
||||
lm += "People:\n"
|
||||
for i in range(3): # Up to 3 people
|
||||
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization entities
|
||||
lm += "\nOrganizations:\n"
|
||||
for i in range(2): # Up to 2 orgs
|
||||
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
|
||||
|
||||
# Dates
|
||||
lm += "\nDates:\n"
|
||||
for i in range(2): # Up to 2 dates
|
||||
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
|
||||
|
||||
# Locations
|
||||
lm += "\nLocations:\n"
|
||||
for i in range(2): # Up to 2 locations
|
||||
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = """
|
||||
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
|
||||
to discuss the collaboration between Apple and Microsoft. The meeting continued
|
||||
in Cupertino on 2024-09-20.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
```
|
||||
|
||||
### Batch Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def batch_extract(lm, texts):
|
||||
"""Extract from multiple texts."""
|
||||
lm += "Batch Extraction Results:\n\n"
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
lm += f"=== Item {i+1} ===\n"
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Sentiment: " + gen(
|
||||
f"sentiment_{i}",
|
||||
regex=r"(positive|negative|neutral)",
|
||||
stop="\n"
|
||||
) + "\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
texts = [
|
||||
"Alice is happy with the product",
|
||||
"Bob is disappointed with the service",
|
||||
"Carol has no strong feelings either way"
|
||||
]
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = batch_extract(lm, texts)
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is absolutely amazing! Best purchase ever."
|
||||
|
||||
lm += f"Text: {text}\n\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name="sentiment"
|
||||
)
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
print(f"Reasoning: {lm['reasoning']}")
|
||||
```
|
||||
|
||||
### Multi-Label Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_article(lm, text):
|
||||
"""Classify article with multiple labels."""
|
||||
lm += f"Article: {text}\n\n"
|
||||
|
||||
# Primary category
|
||||
lm += "Primary Category: " + select(
|
||||
["Technology", "Business", "Science", "Politics", "Entertainment"],
|
||||
name="primary_category"
|
||||
) + "\n"
|
||||
|
||||
# Secondary categories (up to 3)
|
||||
lm += "\nSecondary Categories:\n"
|
||||
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
|
||||
|
||||
# Tags
|
||||
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
|
||||
|
||||
# Target audience
|
||||
lm += "Target Audience: " + select(
|
||||
["General", "Expert", "Beginner"],
|
||||
name="audience"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
article = """
|
||||
Apple announced new AI features in iOS 18, leveraging machine learning to improve
|
||||
battery life and performance. The company's stock rose 5% following the announcement.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_article(lm, article)
|
||||
```
|
||||
|
||||
### Intent Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_intent(lm, message):
|
||||
"""Classify user intent."""
|
||||
lm += f"User Message: {message}\n\n"
|
||||
|
||||
# Intent
|
||||
lm += "Intent: " + select(
|
||||
["question", "complaint", "request", "feedback", "other"],
|
||||
name="intent"
|
||||
) + "\n"
|
||||
|
||||
# Urgency
|
||||
lm += "Urgency: " + select(
|
||||
["low", "medium", "high", "critical"],
|
||||
name="urgency"
|
||||
) + "\n"
|
||||
|
||||
# Department
|
||||
lm += "Route To: " + select(
|
||||
["support", "sales", "billing", "technical"],
|
||||
name="department"
|
||||
) + "\n"
|
||||
|
||||
# Sentiment
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "neutral", "negative"],
|
||||
name="sentiment"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
message = "My account was charged twice for the same order. Need help ASAP!"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_intent(lm, message)
|
||||
|
||||
print(f"Intent: {lm['intent']}")
|
||||
print(f"Urgency: {lm['urgency']}")
|
||||
print(f"Department: {lm['department']}")
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(
|
||||
list(tools.keys()) + ["answer"],
|
||||
name="action"
|
||||
)
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
try:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
except Exception as e:
|
||||
lm += f"Observation: Error - {str(e)}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
# Define tools
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for '{query}': [Mock results]",
|
||||
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
|
||||
}
|
||||
|
||||
# Use agent
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def coordinator_agent(lm, task):
|
||||
"""Coordinator that delegates to specialists."""
|
||||
lm += f"Task: {task}\n\n"
|
||||
|
||||
# Determine which specialist to use
|
||||
lm += "Specialist: " + select(
|
||||
["researcher", "writer", "coder", "analyst"],
|
||||
name="specialist"
|
||||
) + "\n"
|
||||
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def researcher_agent(lm, query):
|
||||
"""Research specialist."""
|
||||
lm += f"Research Query: {query}\n\n"
|
||||
lm += "Findings:\n"
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def writer_agent(lm, topic):
|
||||
"""Writing specialist."""
|
||||
lm += f"Topic: {topic}\n\n"
|
||||
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
|
||||
lm += "Content:\n" + gen("content", max_tokens=500)
|
||||
return lm
|
||||
|
||||
# Coordination workflow
|
||||
task = "Write an article about AI safety"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = coordinator_agent(lm, task)
|
||||
|
||||
specialist = lm["specialist"]
|
||||
if specialist == "researcher":
|
||||
lm = researcher_agent(lm, task)
|
||||
elif specialist == "writer":
|
||||
lm = writer_agent(lm, task)
|
||||
```
|
||||
|
||||
### Tool Use with Validation
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def validated_tool_agent(lm, question):
|
||||
"""Agent with validated tool calls."""
|
||||
tools = {
|
||||
"add": lambda a, b: float(a) + float(b),
|
||||
"multiply": lambda a, b: float(a) * float(b),
|
||||
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(5):
|
||||
# Select tool
|
||||
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
|
||||
|
||||
if lm["tool"] == "done":
|
||||
lm += "\nAnswer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Get validated numeric arguments
|
||||
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
|
||||
# Execute
|
||||
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
|
||||
lm += f"Result: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
|
||||
```
|
||||
|
||||
## Multi-Step Workflows
|
||||
|
||||
### Chain of Thought
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Multi-step reasoning with CoT."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate reasoning steps
|
||||
lm += "Let me think step by step:\n\n"
|
||||
for i in range(4):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Self-Consistency
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def self_consistency(lm, question, num_samples=3):
|
||||
"""Generate multiple reasoning paths and aggregate."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
answers = []
|
||||
for i in range(num_samples):
|
||||
lm += f"=== Attempt {i+1} ===\n"
|
||||
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
|
||||
answers.append(lm[f"answer_{i}"])
|
||||
|
||||
# Aggregate (simple majority vote)
|
||||
from collections import Counter
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
|
||||
lm += f"Final Answer (by majority): {most_common}\n"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = self_consistency(lm, "What is 15% of 200?")
|
||||
```
|
||||
|
||||
### Planning and Execution
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def plan_and_execute(lm, goal):
|
||||
"""Plan tasks then execute them."""
|
||||
lm += f"Goal: {goal}\n\n"
|
||||
|
||||
# Planning phase
|
||||
lm += "Plan:\n"
|
||||
num_steps = 4
|
||||
for i in range(num_steps):
|
||||
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execution phase
|
||||
lm += "\nExecution:\n\n"
|
||||
for i in range(num_steps):
|
||||
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
|
||||
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
|
||||
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
|
||||
|
||||
# Summary
|
||||
lm += "Summary: " + gen("summary", max_tokens=200)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm, description):
|
||||
"""Generate Python function from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# Function signature
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
|
||||
|
||||
# Function body
|
||||
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_python_function(lm, "Check if a number is prime")
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### SQL Query
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_sql(lm, description):
|
||||
"""Generate SQL query from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
lm += "SQL Query:\n"
|
||||
|
||||
# SELECT clause
|
||||
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
|
||||
|
||||
# FROM clause
|
||||
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
|
||||
|
||||
# WHERE clause (optional)
|
||||
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
|
||||
```
|
||||
|
||||
### API Endpoint
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_api_endpoint(lm, description):
|
||||
"""Generate REST API endpoint."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# HTTP method
|
||||
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
|
||||
|
||||
# Path
|
||||
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
|
||||
|
||||
# Request body (if POST/PUT)
|
||||
if lm["method"] in ["POST", "PUT"]:
|
||||
lm += "\nRequest Body:\n"
|
||||
lm += "{\n"
|
||||
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
|
||||
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
# Response
|
||||
lm += "\nResponse (200 OK):\n"
|
||||
lm += "{\n"
|
||||
lm += ' "status": "success",\n'
|
||||
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_api_endpoint(lm, "Create a new blog post")
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def safe_extraction(lm, text):
|
||||
"""Extract with fallback handling."""
|
||||
try:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
except Exception as e:
|
||||
# Fallback to less strict extraction
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def cached_generation(text):
|
||||
"""Cache LLM generations."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += f"Analyze: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
return lm["sentiment"]
|
||||
|
||||
# First call: hits LLM
|
||||
result1 = cached_generation("This is great!")
|
||||
|
||||
# Second call: returns cached result
|
||||
result2 = cached_generation("This is great!") # Instant!
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
@guidance
|
||||
def monitored_generation(lm, text):
|
||||
"""Track generation metrics."""
|
||||
start_time = time.time()
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Analysis: " + gen("analysis", max_tokens=100)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Log metrics
|
||||
print(f"Generation time: {elapsed:.2f}s")
|
||||
print(f"Output length: {len(lm['analysis'])} chars")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
def batch_process(texts, batch_size=10):
|
||||
"""Process texts in batches."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
results = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
|
||||
        for j, text in enumerate(batch):
            lm += f"Text: {text}\n"
            lm += "Sentiment: " + select(
                ["positive", "negative", "neutral"],
                name=f"sentiment_{i + j}"  # unique capture name per text
            ) + "\n\n"

        results.extend([lm[f"sentiment_{i + j}"] for j in range(len(batch))])
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions
|
||||
307  skills/mlops/llava/SKILL.md  Normal file
@@ -0,0 +1,307 @@
|
||||
---
|
||||
name: llava
|
||||
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers, torch, pillow]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
|
||||
|
||||
---
|
||||
|
||||
# LLaVA - Large Language and Vision Assistant
|
||||
|
||||
Open-source vision-language model for conversational image understanding.
|
||||
|
||||
## When to use LLaVA
|
||||
|
||||
**Use when:**
|
||||
- Building vision-language chatbots
|
||||
- Visual question answering (VQA)
|
||||
- Image description and captioning
|
||||
- Multi-turn image conversations
|
||||
- Visual instruction following
|
||||
- Document understanding with images
|
||||
|
||||
**Metrics**:
|
||||
- **23,000+ GitHub stars**
|
||||
- GPT-4V level capabilities (targeted)
|
||||
- Apache 2.0 License
|
||||
- Multiple model sizes (7B-34B params)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **GPT-4V**: Highest quality, API-based
|
||||
- **CLIP**: Simple zero-shot classification
|
||||
- **BLIP-2**: Better for captioning only
|
||||
- **Flamingo**: Research, not open-source
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/haotian-liu/LLaVA
|
||||
cd LLaVA
|
||||
|
||||
# Install
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from llava.conversation import conv_templates
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Load model
|
||||
model_path = "liuhaotian/llava-v1.5-7b"
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path)
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open("image.jpg")
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
# Create conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
# Generate response
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Parameters | VRAM | Quality |
|
||||
|-------|------------|------|---------|
|
||||
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
|
||||
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
|
||||
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
|
||||
|
||||
```python
|
||||
# Load different models
|
||||
model_7b = "liuhaotian/llava-v1.5-7b"
|
||||
model_13b = "liuhaotian/llava-v1.5-13b"
|
||||
model_34b = "liuhaotian/llava-v1.6-34b"
|
||||
|
||||
# 4-bit quantization for lower VRAM
|
||||
load_4bit = True # Reduces VRAM by ~4×
|
||||
```
## CLI usage

```bash
# Multi-turn conversation: type questions interactively at the prompt
python -m llava.serve.cli \
    --model-path liuhaotian/llava-v1.5-7b \
    --image-file image.jpg

# Single one-shot query (the scripted entry point lives in llava.eval)
python -m llava.eval.run_llava \
    --model-path liuhaotian/llava-v1.5-7b \
    --image-file image.jpg \
    --query "What is in this image?"
```

## Web UI (Gradio)

The Gradio demo runs as three processes, a controller, the web server, and a model worker, following the launch sequence from the upstream README:

```bash
# 1. Start the controller
python -m llava.serve.controller --host 0.0.0.0 --port 10000

# 2. Start the Gradio web server
python -m llava.serve.gradio_web_server \
    --controller http://localhost:10000 \
    --model-list-mode reload

# 3. Start a model worker (add --load-4bit to reduce VRAM)
python -m llava.serve.model_worker \
    --host 0.0.0.0 \
    --controller http://localhost:10000 \
    --port 40000 \
    --worker http://localhost:40000 \
    --model-path liuhaotian/llava-v1.5-7b

# Access the UI at http://localhost:7860
```

## Multi-turn conversations

```python
# `generate(conv, model, image)` below is a placeholder for the
# prompt-build + model.generate + decode steps from Basic usage.

# Initialize conversation
conv = conv_templates["llava_v1"].copy()

# Turn 1
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
response1 = generate(conv, model, image)  # e.g. "A dog playing in a park"

# Turn 2: record the previous reply, then ask a follow-up
conv.messages[-1][1] = response1
conv.append_message(conv.roles[0], "What breed is the dog?")
conv.append_message(conv.roles[1], None)
response2 = generate(conv, model, image)  # e.g. "Golden Retriever"

# Turn 3
conv.messages[-1][1] = response2
conv.append_message(conv.roles[0], "What time of day is it?")
conv.append_message(conv.roles[1], None)
response3 = generate(conv, model, image)
```

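The append/generate/record pattern above can be wrapped in a small loop. A minimal sketch, assuming the same `generate` placeholder as the previous block:

```python
def chat_loop(model, image, questions):
    """Illustrative multi-turn driver: feeds each question into one
    conversation, recording every answer so later turns keep context."""
    conv = conv_templates["llava_v1"].copy()
    responses = []
    for i, question in enumerate(questions):
        # Only the first turn carries the image token
        text = DEFAULT_IMAGE_TOKEN + "\n" + question if i == 0 else question
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        answer = generate(conv, model, image)  # placeholder from above
        conv.messages[-1][1] = answer          # record reply for context
        responses.append(answer)
    return responses

# Usage:
# chat_loop(model, image, ["What is in this image?", "What breed is the dog?"])
```
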
## Common tasks

These snippets use the `ask()` helper sketched in the Basic usage section.

### Image captioning

```python
question = "Describe this image in detail."
response = ask(model, image, question)
```

### Visual question answering

```python
question = "How many people are in the image?"
response = ask(model, image, question)
```

### Object detection (textual)

```python
question = "List all the objects you can see in this image."
response = ask(model, image, question)
```

### Scene understanding

```python
question = "What is happening in this scene?"
response = ask(model, image, question)
```

### Document understanding

```python
question = "What is the main topic of this document?"
response = ask(model, document_image, question)
```

## Training custom model

Training follows the two-stage recipe from the paper:

```bash
# Stage 1: feature alignment pretraining (558K image-caption pairs)
bash scripts/v1_5/pretrain.sh

# Stage 2: visual instruction tuning (~665K mixed instruction examples for v1.5)
bash scripts/v1_5/finetune.sh
```

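Custom instruction data for the finetuning stage is a JSON list in LLaVA's conversation format. A minimal example record (fields follow the upstream data format; the id, paths, and text are illustrative):

```python
import json

# One record in LLaVA's instruction-tuning format. "image" is a path
# relative to the image folder passed to the training script, and the
# "<image>" token marks where the image is injected into the first turn.
record = {
    "id": "unique-example-0001",
    "image": "images/example.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is shown in this photo?"},
        {"from": "gpt", "value": "A golden retriever playing in a park."},
    ],
}

with open("custom_train.json", "w") as f:
    json.dump([record], f, indent=2)
```
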
## Quantization (reduce VRAM)

```python
# 4-bit quantization (~4x lower VRAM)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-13b",
    model_base=None,
    model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
    load_4bit=True
)

# For 8-bit quantization (~2x lower VRAM), pass load_8bit=True instead
```

## Best practices

1. **Start with the 7B model** - Good quality with manageable VRAM
2. **Use 4-bit quantization** - Cuts VRAM roughly 4x
3. **Use a GPU** - CPU inference is extremely slow
4. **Write clear prompts** - Specific questions get better answers
5. **Use multi-turn conversations** - Record each reply to maintain context
6. **Temperature 0.2-0.7** - Balances consistency against creativity
7. **max_new_tokens 512-1024** - Enough headroom for detailed responses
8. **Batch processing** - Process multiple images sequentially (see the sketch below)

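A minimal sequential batch-processing sketch using the `ask()` helper from earlier (folder name and question are illustrative):

```python
from pathlib import Path
from PIL import Image

def caption_folder(model, folder="photos", question="Describe this image in detail."):
    """Illustrative sequential batcher: one image at a time keeps peak
    VRAM flat, since activations are freed between generations."""
    results = {}
    for path in sorted(Path(folder).glob("*.jpg")):
        image = Image.open(path).convert("RGB")
        results[path.name] = ask(model, image, question)  # helper from above
    return results
```
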
## Performance

| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|-------|-------------|--------------|------------------|
| 7B | ~14 GB | ~4 GB | ~20 |
| 13B | ~28 GB | ~8 GB | ~12 |
| 34B | ~70 GB | ~18 GB | ~5 |

*On an A100 GPU*

## Benchmarks

LLaVA achieves competitive scores on standard vision-language benchmarks:

- **VQAv2**: 78.5%
- **GQA**: 62.0%
- **MM-Vet**: 35.4%
- **MMBench**: 64.3%

## Limitations

1. **Hallucinations** - May describe things that are not in the image
2. **Spatial reasoning** - Struggles with precise object locations
3. **Small text** - Has difficulty reading fine print
4. **Object counting** - Imprecise when many objects are present
5. **VRAM requirements** - Needs a powerful GPU
6. **Inference speed** - Much slower than lightweight models like CLIP

## Integration with frameworks

### LangChain

A minimal custom-LLM wrapper sketch (LangChain requires a `_llm_type` property in addition to `_call`; the inference body is left as a placeholder):

```python
from typing import List, Optional
from langchain.llms.base import LLM

class LLaVALLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "llava"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Run LLaVA inference here (e.g. the Basic usage generate code)
        response = ...
        return response

llm = LLaVALLM()
```

### Gradio App

```python
import gradio as gr

# gr.ChatInterface calls the function as fn(message, history, *additional_inputs),
# so the image arrives after the chat arguments. ask_llava is a placeholder
# for the generation code above.
def chat(message, history, image):
    response = ask_llava(model, image, message)
    return response

demo = gr.ChatInterface(
    chat,
    additional_inputs=[gr.Image(type="pil")],
    title="LLaVA Chat"
)
demo.launch()
```

## Resources

- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
- **Paper**: https://arxiv.org/abs/2304.08485
- **Demo**: https://llava.hliu.cc
- **Models**: https://huggingface.co/liuhaotian
- **License**: Apache 2.0