Compare commits

...

212 Commits

Author SHA1 Message Date
dmahan93
5f9c02bb37 fix: skip tests when atroposlib/minisweagent unavailable in CI
- test_agent_loop_tool_calling.py: import atroposlib at module level
  to trigger skip (environments.agent_loop is now importable without
  atroposlib due to __init__.py graceful fallback)
- test_modal_sandbox_fixes.py: skip TestToolResolution tests when
  minisweagent not installed
2026-03-09 23:37:32 -05:00
dmahan93
3dbeaea3dc fix: guard all atroposlib imports for CI without atropos installed
- environments/__init__.py: try/except on atroposlib imports so
  submodules like tool_call_parsers remain importable standalone
- test_agent_loop.py, test_tool_call_parsers.py,
  test_managed_server_tool_support.py: skip at module level when
  atroposlib is missing
2026-03-09 23:33:24 -05:00
dmahan93
26d9b5af29 test: skip atropos-dependent tests when atroposlib not installed
Guard all test files that import from environments/ or atroposlib
with try/except + pytest.skip(allow_module_level=True) so they
gracefully skip instead of crashing when deps aren't available.
2026-03-09 23:14:53 -05:00
dmahan93
ef8cb9afd2 add a local vllm instance 2026-03-09 23:02:13 -05:00
dmahan93
407a1e24b2 fix: use ManagedServer for vLLM in TBLite eval + local_vllm config
TBLite eval was bypassing ManagedServer and calling ServerManager
directly, which uses /v1/chat/completions — not available on the
atropos vllm_api_server (/generate only).

Now uses _use_managed_server() to detect vLLM/SGLang backends and
route through ManagedServer (Phase 2) with proper tool_parser and
/generate endpoint. Falls back to Phase 1 for OpenAI endpoints.

Also adds local_vllm.yaml config for running against a local vLLM
server with Docker sandboxes.
2026-03-09 21:32:23 -05:00
dmahan93
e1e69dfd32 fix: handle dict and object tool_calls in agent loop
vLLM's ToolCallTranslator returns tool_calls as dicts, while
OpenAI API returns them as objects with .id, .function.name etc.
Normalize both formats in the agent loop.
2026-03-09 21:21:49 -05:00
dmahan93
003b6e49df test: 5 vLLM integration tests + fallback tool call parser
Tests hit a real vLLM server (Qwen/Qwen3-4B-Thinking-2507) via
ManagedServer Phase 2. Auto-skip if server isn't running.

Tests verify:
- Single tool call through full agent loop
- Multi-tool calls across turns
- ManagedServer produces SequenceNodes with tokens/logprobs
- Direct response without tools
- Thinking model produces <think> blocks

Also adds fallback parser in agent_loop.py: when ManagedServer's
ToolCallTranslator can't parse (vLLM not installed), hermes-agent's
standalone parsers extract <tool_call> tags from raw content.
2026-03-09 21:18:42 -05:00
dmahan93
dab2cfe566 add eval output to gitignore 2026-03-09 21:01:36 -05:00
dmahan93
c87bd5dd87 refactor: update to new atropos tool-calling API
Migrate from old tool_call_parser (instance) to new ToolCallTranslator
pattern from atropos add-openai-endpoint-for-managed-server branch:

- Set tool_parser on ServerManager (string name, e.g. 'hermes')
- Use managed_server(tokenizer=..., preserve_think_blocks=...)
  instead of managed_server(tokenizer=..., tool_call_parser=instance)
- ManagedServer now handles tool call translation internally via
  ToolCallTranslator (bidirectional raw text <-> OpenAI tool_calls)
- Remove old parser loading code (get_parser/KeyError fallback)

The hermes-agent tool_call_parsers/ directory is preserved as a
standalone fallback for environments that don't use vLLM's parsers.
2026-03-09 20:49:18 -05:00
dmahan93
2a67e4fa57 test: 9 agent loop tool-calling integration tests
Real LLM calls via OpenRouter using stepfun/step-3.5-flash:free (zero cost).
Falls back to paid models if free model is unavailable.

Tests: single tool call, multi-tool single turn, multi-turn chains,
unknown tool rejection, max_turns limit, direct response (no tools),
tool error handling, AgentResult structure, conversation history.
2026-03-09 20:37:55 -05:00
dmahan93
136a64942d feat: add eval_concurrency limit + Docker local config for TBLite
- Add eval_concurrency config field with asyncio.Semaphore
- Add local.yaml config using Docker backend (sandboxed, no cloud costs)
- Register docker_image alongside modal_image for backend flexibility
- Default: 8 parallel tasks for local runs
2026-03-09 20:28:28 -05:00
dmahan93
9f74d1f2ec test: 13 tests for Modal sandbox infra fixes 2026-03-09 20:26:09 -05:00
dmahan93
11ad4173de fix: Modal sandbox eval infra (9 fixes for TBLite baseline)
Fixes discovered while running TBLite baseline evaluation:

1. ephemeral_disk param not supported in modal 1.3.5 - check before passing
2. Modal legacy image builder requires working pip - add ensurepip fix via
   setup_dockerfile_commands to handle task images with broken pip
3. Host cwd leaked into Modal sandbox - add /home/ to host prefix check
4. Tilde ~ not expanded by subprocess.run(cwd=) in sandboxes - use /root
5. install_pipx must stay True for swerex-remote to be available

Dependencies also needed (not in this commit):
- git submodule update --init mini-swe-agent
- uv pip install swe-rex boto3
2026-03-09 18:36:28 -05:00
dmahan93
92cb77eaa7 Add tests for atropos tool calling integration
- test_tool_call_parsers.py: 16 tests for parser registry, hermes parser
  (single/multiple/truncated/malformed), and ParseResult contract validation
- test_agent_loop.py: 21 tests for HermesAgentLoop with mock servers
  (text responses, tool calls, max turns, unknown tools, API errors,
  extra_body forwarding, managed state, blocked tools, reasoning extraction)
- test_managed_server_tool_support.py: 9 tests validating API compatibility
  between hermes-agent and atroposlib's ManagedServer tool_call_parser support
  (gracefully skips on baseline atroposlib, passes on tool_call_support branch)
2026-03-09 15:42:16 -05:00
Teknium
c5e8166c8b Merge pull request #720 from NousResearch/feat/session-naming
feat: Session naming with unique titles, auto-lineage & rich listing
2026-03-08 16:32:13 -07:00
teknium1
2b88568653 docs: add session naming documentation across all doc files
- website/docs/user-guide/sessions.md: New 'Session Naming' section
  with /title usage, title rules, auto-lineage, gateway support.
  Updated 'Resume by Name' section, 'Rename a Session' subsection,
  updated sessions list output format, updated DB schema description.
- website/docs/reference/cli-commands.md: Added -c "name" and
  --resume by title to Core Commands, sessions rename to Sessions
  table, /title to slash commands.
- website/docs/user-guide/cli.md: Added -c "name" and --resume by
  title to resume options.
- AGENTS.md: Added -c, --resume, sessions list/rename to CLI commands
  table. Added hermes_state.py to project structure.
- CONTRIBUTING.md: Updated hermes_state.py and session persistence
  descriptions to mention titles.
- hermes_cli/main.py: Fixed sessions help string to include 'rename'.
2026-03-08 16:09:31 -07:00
teknium1
34b4fe495e fix: add title validation — sanitize, length limit, control char stripping
- Add SessionDB.sanitize_title() static method:
  - Strips ASCII control chars (null, bell, ESC, etc.) except whitespace
  - Strips problematic Unicode controls (zero-width, RTL override, BOM)
  - Collapses whitespace runs, strips edges
  - Normalizes empty/whitespace-only to None
  - Enforces 100 char max length (raises ValueError)
- set_session_title() now calls sanitize_title() internally,
  so all call sites (CLI, gateway, auto-lineage) are protected
- CLI /title handler sanitizes early to show correct feedback
- Gateway /title handler sanitizes early to show correct feedback
- 24 new tests: sanitize_title (17 cases covering control chars,
  zero-width, RTL, BOM, emoji, CJK, length, integration),
  gateway validation (too long, control chars, only-control-chars)
2026-03-08 15:54:51 -07:00
teknium1
4fdd6c0dac fix: harden session title system + add /title to gateway
- Empty string titles normalized to None (prevents uncaught IntegrityError
  when two sessions both get empty-string titles via the unique index)
- Escape SQL LIKE wildcards (%, _) in resolve_session_by_title and
  get_next_title_in_lineage to prevent false matches on titles like
  'test_project' matching 'testXproject #2'
- Optimize list_sessions_rich from N+2 queries to a single query with
  correlated subqueries (preview + last_active computed in SQL)
- Add /title slash command to gateway (Telegram, Discord, Slack, WhatsApp)
  with set and show modes, uniqueness conflict handling
- Add /title to gateway /help text and _known_commands
- 12 new tests: empty string normalization, multi-empty-title safety,
  SQL wildcard edge cases, gateway /title set/show/conflict/cross-platform
2026-03-08 15:48:09 -07:00
teknium1
60b6abefd9 feat: session naming with unique titles, auto-lineage, rich listing, resume by name
- Schema v4: unique title index, migration from v2/v3
- set/get/resolve session titles with uniqueness enforcement
- Auto-lineage: context compression auto-numbers titles (Task -> Task #2 -> Task #3)
- resolve_session_by_title: auto-latest finds most recent continuation
- list_sessions_rich: preview (first 60 chars) + last_active timestamp
- CLI: -c accepts optional name arg (hermes -c 'my project')
- CLI: /title command with deferred mode (set before session exists)
- CLI: sessions list shows Title, Preview, Last Active, ID
- 27 new tests (1844 total passing)
2026-03-08 15:20:29 -07:00
teknium1
4d53b7ccaa Add OpenRouter app attribution headers to skills_guard and trajectory_compressor
These two files were creating bare OpenAI clients pointing at OpenRouter
without the HTTP-Referer / X-OpenRouter-Title / X-OpenRouter-Categories
headers that the rest of the codebase sends for app attribution.

- skills_guard.py: LLM audit client (always OpenRouter)
- trajectory_compressor.py: sync + async summarization clients
  (guarded with 'openrouter' in base_url check since the endpoint
  is user-configurable)
2026-03-08 14:23:18 -07:00
teknium1
cd77c7100c Merge PR #648: test: add regression coverage for compressor tool-call boundaries
Authored by intertwine. Related to #647.
2026-03-08 06:46:50 -07:00
teknium1
cf810c2950 fix: pre-process CLI clipboard images through vision tool instead of raw embedding
Images pasted in the CLI were embedded as raw base64 image_url content
parts in the conversation history, which only works with vision-capable
models. If the main model (e.g. Nous API) doesn't support vision, this
breaks the request and poisons all subsequent messages.

Now the CLI uses the same approach as the messaging gateway: images are
pre-processed through the auxiliary vision model (Gemini Flash via
OpenRouter or Nous Portal) and converted to text descriptions. The
local file path is included so the agent can re-examine via
vision_analyze if needed. Works with any model.

Fixes #638.
2026-03-08 06:22:00 -07:00
teknium1
a23bcb81ce fix: improve /model user feedback + update docs
User messaging improvements:
- Rejection: '(>_<) Error: not a valid model' instead of '(^_^) Warning: Error:'
- Rejection: shows 'Model unchanged' + tip about /model and /provider
- Session-only: explains 'this session only' with reason and 'will revert on restart'
- Saved: clear '(saved to config)' confirmation

Docs updated:
- cli-commands.md, cli.md, messaging/index.md: /model now shows
  provider:model syntax, /provider command added to tables

Test fixes: deduplicated test names, assertions match new messages.
2026-03-08 06:13:12 -07:00
stablegenius49
d07d867718 Fix empty tool selection persistence 2026-03-08 06:11:18 -07:00
teknium1
666f2dd486 feat: /provider command + fix gateway bugs + harden parse_model_input
/provider command (CLI + gateway):
  Shows all providers with auth status (✓/✗), aliases, and active marker.
  Users can now discover what provider names work with provider:model syntax.

Gateway bugs fixed:
  - Config was saved even when validation.persist=False (told user 'session
    only' but actually persisted the unvalidated model)
  - HERMES_INFERENCE_PROVIDER env var not set on provider switch, causing
    the switch to be silently overridden if that env var was already set

parse_model_input hardened:
  - Colon only treated as provider delimiter if left side is a recognized
    provider name or alias. 'anthropic/claude-3.5-sonnet:beta' now passes
    through as a model name instead of trying provider='anthropic/claude-3.5-sonnet'.
  - HTTP URLs, random colons no longer misinterpreted.

56 tests passing across model validation, CLI commands, and integration.
2026-03-08 06:09:36 -07:00
teknium1
34792dd907 fix: resolve 'auto' provider properly via credential detection
'auto' doesn't always mean openrouter — it could be nous, zai,
kimi-coding, etc. depending on configured credentials. Reverted the
hardcoded mapping and now both CLI and gateway call
resolve_provider() to detect the actual active provider when 'auto'
is set. Falls back to openrouter only if resolution fails.
2026-03-08 05:58:45 -07:00
teknium1
7ad6fc8a40 fix: gateway /model also needs normalize_provider for 'auto' resolution 2026-03-08 05:56:43 -07:00
teknium1
f824c10429 feat: enhance config migration with new environment variable tracking
Added a system to track environment variables introduced in each config version, allowing migration prompts to only mention new variables since the user's last version. Updated the interactive configuration process to offer users the option to set these new optional keys during migration.
2026-03-08 05:55:32 -07:00
teknium1
132e5ec179 fix: resolve 'auto' provider in /model display + update gateway handler
- normalize_provider('auto') now returns 'openrouter' (the default)
  so /model shows the curated model list instead of nothing
- CLI /model display uses normalize_provider before looking up labels
- Gateway /model handler now uses the same validation logic as CLI:
  live API probe, provider:model syntax, curated model list display
2026-03-08 05:54:52 -07:00
teknium1
66d3e6a0c2 feat: provider switching via /model + enhanced model display
Add provider:model syntax to /model command for runtime provider switching:
  /model zai:glm-5           → switch to Z.AI provider with glm-5
  /model nous:hermes-3       → switch to Nous Portal with hermes-3
  /model openrouter:anthropic/claude-sonnet-4.5  → explicit OpenRouter

When switching providers, credentials are resolved via resolve_runtime_provider
and validated before committing. Both model and provider are saved to config.
Provider aliases work (glm: → zai, kimi: → kimi-coding, etc.).

Enhanced /model (no args) display now shows:
  - Current model and provider
  - Curated model list for the current provider with ← marker
  - Usage examples including provider:model syntax

39 tests covering parse_model_input, curated_models_for_provider,
provider switching (success + credential failure), and display output.
2026-03-08 05:45:59 -07:00
teknium1
4a09ae2985 chore: remove dead module stubs from test_cli_init.py
The 200 lines of prompt_toolkit/rich/fire stubs added in PR #650 were
guarded by 'if module in sys.modules: return' and never activated since
those dependencies are always installed. Removed to keep the test file
lean. Also removed unused MagicMock and pytest imports.
2026-03-08 05:35:02 -07:00
teknium1
8c734f2f27 fix: remove OpenRouter '/' format enforcement — let API probe be the authority
Not all providers require 'provider/model' format. Removing the rigid
format check lets the live API probe handle all validation uniformly.
If someone types 'gpt-5.4' on OpenRouter, the probe won't find it and
will suggest 'openai/gpt-5.4' — better UX than a format rejection.
2026-03-08 05:31:41 -07:00
teknium1
245d174359 feat: validate /model against live API instead of hardcoded lists
Replace the static catalog-based model validation with a live API probe.
The /model command now hits the provider's /models endpoint to check if
the requested model actually exists:

- Model found in API → accepted + saved to config
- Model NOT found in API → rejected with 'Error: not a valid model'
  and fuzzy-match suggestions from the live model list
- API unreachable → graceful fallback to hardcoded catalog (session-only
  for unrecognized models)
- Format errors (empty, spaces, missing '/') still caught instantly
  without a network call

The API probe takes ~0.2s for OpenRouter (346 models) and works with any
OpenAI-compatible endpoint (Ollama, vLLM, custom, etc.).

32 tests covering all paths: format checks, API found, API not found,
API unreachable fallback, CLI integration.
2026-03-08 05:22:20 -07:00
stablegenius49
77f47768dd fix: improve /history message display 2026-03-08 05:08:57 -07:00
teknium1
90fa9e54ca fix: guard validate_requested_model + expand test coverage (PR #649 follow-up)
- Wrap validate_requested_model in try/except so /model doesn't crash
  if validation itself fails (falls back to old accept+save behavior)
- Remove unnecessary sys.path.insert from both test files
- Expand test_model_validation.py: 4 → 23 tests covering normalize_provider,
  provider_model_ids, empty/whitespace/spaces rejection, OpenRouter format
  validation, custom endpoints, nous provider, provider aliases, unknown
  providers, fuzzy suggestions
- Expand test_cli_model_command.py: 2 → 5 tests adding known-model save,
  validation crash fallback, and /model with no argument
2026-03-08 04:47:35 -07:00
stablegenius49
9d3a44e0e8 fix: validate /model values before saving 2026-03-08 04:47:35 -07:00
teknium1
932d596466 feat: enhance systemd unit and install script for browser dependencies
Updated the systemd unit generation to include the virtual environment and node modules in the PATH, improving the execution context for the hermes CLI. Additionally, added support for installing Playwright and its dependencies on Arch/Manjaro systems in the install script, ensuring a smoother setup process for browser tools.
2026-03-08 04:36:23 -07:00
teknium1
d518f40e8b fix: improve browser command environment setup
Enhanced the environment setup for browser commands by ensuring the PATH variable includes standard directories, addressing potential issues with minimal PATH in systemd services. Additionally, updated the logging of stderr to use a warning level on failure for better visibility of errors. This change improves the robustness of subprocess execution in the browser tool.
2026-03-08 04:08:44 -07:00
Teknium
f016cfca46 Merge pull request #685 from NousResearch/revert-659-feat/skill-prerequisites
Revert "feat: skill prerequisites — hide skills with unmet runtime dependencies"
2026-03-08 03:58:41 -07:00
Teknium
b8120df860 Revert "feat: skill prerequisites — hide skills with unmet runtime dependencies" 2026-03-08 03:58:13 -07:00
teknium1
0df7df52f3 test: expand slash command autocomplete coverage (PR #645 follow-up)
- Fix failing test: use display_text/display_meta_text instead of str()
  on prompt_toolkit FormattedText objects
- Add regression guard: EXPECTED_COMMANDS set ensures no command
  silently disappears from the shared dict
- Add edge case tests: non-slash input, empty input, partial vs exact
  match trailing space, builtin display_meta content
- Add skill provider tests: None provider, exception swallowing,
  description truncation at 50 chars, missing description fallback,
  exact-match trailing space on skill commands
- Total: 15 tests (up from 4)
2026-03-08 03:53:22 -07:00
stablegenius49
bfa27d0a68 fix(cli): unify slash command autocomplete registry 2026-03-08 03:53:22 -07:00
teknium1
5a20c486e3 Merge PR #659: feat: skill prerequisites — hide skills with unmet runtime dependencies
Authored by kshitijk4poor. Fixes #630.
2026-03-08 03:12:35 -07:00
teknium1
78e19ebc95 chore: update .gitignore to include .worktrees directory
Added .worktrees to the .gitignore file to prevent tracking of worktree-specific files, ensuring a cleaner repository.
2026-03-08 03:01:46 -07:00
teknium1
b383cafc44 refactor: rename and enhance shell detection in local environment
Renamed _find_shell to _find_bash to clarify its purpose of specifically locating bash. Improved the shell detection logic to prioritize bash over the user's $SHELL, ensuring compatibility with the fence wrapper's syntax requirements. Added a backward compatibility alias for _find_shell to maintain existing imports in process_registry.py.
2026-03-08 03:00:05 -07:00
teknium1
b10ff83566 fix: enhance PATH handling in local environment
Updated the LocalEnvironment class to ensure the PATH variable includes standard directories. This change addresses issues with systemd services and terminal multiplexers that inherit a minimal PATH, improving the execution environment for subprocesses.
2026-03-08 01:50:38 -08:00
teknium1
daa1f542f9 fix: enhance shell detection in local environment configuration
Updated the _find_shell function to improve shell detection on non-Windows systems. The function now checks for the existence of /usr/bin/bash and /bin/bash before falling back to /bin/sh, ensuring a more robust shell resolution process.
2026-03-08 01:43:00 -08:00
teknium1
d507f593d0 fix: respect config.yaml cwd in gateway, add sandbox_dir config option
Two fixes:

1. Gateway CWD override: TERMINAL_CWD from config.yaml was being
   unconditionally overwritten by the messaging_cwd fallback (line 114).
   Now explicit paths in config.yaml are respected — only '.' / 'auto' /
   'cwd' (or unset) fall back to MESSAGING_CWD or home directory.

2. sandbox_dir config: Added terminal.sandbox_dir to config.yaml bridge
   in gateway/run.py, cli.py, and hermes_cli/config.py. Maps to
   TERMINAL_SANDBOX_DIR env var, which get_sandbox_dir() reads to
   determine where Docker/Singularity sandbox data is stored (default:
   ~/.hermes/sandboxes/). Users can now set:
     hermes config set terminal.sandbox_dir /data/hermes-sandboxes
2026-03-08 01:33:46 -08:00
kshitij
f210510276 feat: add prerequisites field to skill spec — hide skills with unmet dependencies
Skills can now declare runtime prerequisites (env vars, CLI binaries) via
YAML frontmatter. Skills with unmet prerequisites are excluded from the
system prompt so the agent never claims capabilities it can't deliver, and
skill_view() warns the agent about what's missing.

Three layers of defense:
- build_skills_system_prompt() filters out unavailable skills
- _find_all_skills() flags unmet prerequisites in metadata
- skill_view() returns prerequisites_warning with actionable details

Tagged 12 bundled skills that have hard runtime dependencies:
gif-search (TENOR_API_KEY), notion (NOTION_API_KEY), himalaya, imessage,
apple-notes, apple-reminders, openhue, duckduckgo-search, codebase-inspection,
blogwatcher, songsee, mcporter.

Closes #658
Fixes #630
2026-03-08 13:19:32 +05:30
teknium1
19b6f81ee7 fix: allow Anthropic API URLs as custom OpenAI-compatible endpoints
Removed the hard block on base_url containing 'api.anthropic.com'.
Anthropic now offers an OpenAI-compatible /chat/completions endpoint,
so blocking their URL prevents legitimate use. If the endpoint isn't
compatible, the API call will fail with a proper error anyway.

Removed from: run_agent.py, mini_swe_runner.py
Updated test to verify Anthropic URLs are accepted.
2026-03-07 23:36:35 -08:00
Teknium
76545ab365 Merge pull request #657 from NousResearch/feat/browser-screenshot-sharing
feat: browser screenshot sharing via MEDIA: on all messaging platforms
2026-03-07 22:57:42 -08:00
teknium1
b8c3bc7841 feat: browser screenshot sharing via MEDIA: on all messaging platforms
browser_vision now saves screenshots persistently to ~/.hermes/browser_screenshots/
and returns the screenshot_path in its JSON response. The model can include
MEDIA:<path> in its response to share screenshots as native photos.

Changes:
- browser_tool.py: Save screenshots persistently, return screenshot_path,
  auto-cleanup files older than 24 hours, mkdir moved inside try/except
- telegram.py: Add send_image_file() — sends local images via bot.send_photo()
- discord.py: Add send_image_file() — sends local images via discord.File
- slack.py: Add send_image_file() — sends local images via files_upload_v2()
  (WhatsApp already had send_image_file — no changes needed)
- prompt_builder.py: Updated Telegram hint to list image extensions,
  added Discord and Slack MEDIA: platform hints
- browser.md: Document screenshot sharing and 24h cleanup
- send_file_integration_map.md: Updated to reflect send_image_file is now
  implemented on Telegram/Discord/Slack
- test_send_image_file.py: 19 tests covering MEDIA: .png extraction,
  send_image_file on all platforms, and screenshot cleanup

Partially addresses #466 (Phase 0: platform adapter gaps for send_image_file).
2026-03-07 22:57:05 -08:00
teknium1
a680367568 fix tmux menus 2026-03-07 22:14:21 -08:00
teknium1
dfd37a4b31 Merge PR #635: fix: add Kimi Code API support (api.kimi.com/coding/v1)
Authored by christomitov. Auto-detects sk-kimi- key prefix and routes
to api.kimi.com/coding/v1. Adds User-Agent header for Kimi Code API
compatibility. Legacy Moonshot keys continue to work unchanged.
2026-03-07 21:45:27 -08:00
teknium1
5ee9b67d9b Merge PR #654: feat: git worktree isolation for parallel CLI sessions (--worktree / -w)
Adds --worktree (-w) flag to hermes CLI for isolated git worktree sessions.
Multiple agents can work on the same repo concurrently without collisions.

Closes #652
2026-03-07 21:38:42 -08:00
teknium1
542faf225f Fix Telegram image delivery for large (>5MB) images
Telegram's send_photo via URL has a ~5MB limit. Upscaled images from
fal.ai's Clarity Upscaler often exceed this, causing 'Wrong type of
web page content' or 'Failed to get http url content' errors.

Fix: Add download-and-upload fallback in Telegram's send_image().
When URL-based send_photo fails, download the image via httpx and
re-upload as bytes (supports up to 10MB file uploads).

Also: convert print() to logger.warning/error in image sending path
for proper log visibility (print goes to socket, invisible in logs).
2026-03-07 21:29:45 -08:00
teknium1
5684c68121 Add logger.info/error for image extraction and delivery debugging 2026-03-07 21:24:47 -08:00
teknium1
4be783446a fix: wire worktree flag into hermes CLI entry point + docs + tests
Critical fixes:
- Add --worktree/-w to hermes_cli/main.py argparse (both chat
  subcommand and top-level parser) so 'hermes -w' works via the
  actual CLI entry point, not just 'python cli.py -w'
- Pass worktree flag through cmd_chat() kwargs to cli_main()
- Handle worktree attr in bare 'hermes' and --resume/--continue paths

Bug fixes in cli.py:
- Skip worktree creation for --list-tools/--list-toolsets (wasteful)
- Wrap git worktree subprocess.run in try/except (crash on timeout)
- Add stale worktree pruning on startup (_prune_stale_worktrees):
  removes clean worktrees older than 24h left by crashed/killed sessions

Documentation updates:
- AGENTS.md: add --worktree to CLI commands table
- cli-config.yaml.example: add worktree config section
- website/docs/reference/cli-commands.md: add to core commands
- website/docs/user-guide/cli.md: add usage examples
- website/docs/user-guide/configuration.md: add config docs

Test improvements (17 → 31 tests):
- Stale worktree pruning (prune old clean, keep recent, keep dirty)
- Directory symlink via .worktreeinclude
- Edge cases (no commits, not a repo, pre-existing .worktrees/)
- CLI flag/config OR logic
- TERMINAL_CWD integration
- System prompt injection format
2026-03-07 21:05:40 -08:00
teknium1
8d719b180a feat: git worktree isolation for parallel CLI sessions (--worktree / -w)
Add a --worktree (-w) flag to the hermes CLI that creates an isolated
git worktree for the session. This allows running multiple hermes-agent
instances concurrently on the same repo without file collisions.

How it works:
- On startup with -w: detects git repo, creates .worktrees/<session>/
  with its own branch (hermes/<session-id>), sets TERMINAL_CWD to it
- Each agent works in complete isolation — independent HEAD, index,
  and working tree, shared git object store
- On exit: auto-removes worktree and branch if clean, warns and
  keeps if there are uncommitted changes
- .worktreeinclude file support: list gitignored files (.env, .venv/)
  to auto-copy/symlink into new worktrees
- .worktrees/ is auto-added to .gitignore
- Agent gets a system prompt note about the worktree context
- Config support: set worktree: true in config.yaml to always enable

Usage:
  hermes -w                      # Interactive mode in worktree
  hermes -w -q "Fix issue #123"  # Single query in worktree
  # Or in config.yaml:
  worktree: true

Includes 17 tests covering: repo detection, worktree creation,
independence verification, cleanup (clean/dirty), .worktreeinclude,
.gitignore management, and 10 concurrent worktrees.

Closes #652
2026-03-07 20:51:08 -08:00
teknium1
bf048c8aec feat: add qmd optional skill — local knowledge base search
Add official optional skill for qmd (tobi/qmd), a local on-device
search engine for personal knowledge bases, notes, docs, and meeting
transcripts.

Covers:
- Installation and setup for macOS and Linux
- Collection management and context annotations
- All search modes: BM25, vector, hybrid with reranking
- MCP integration (stdio and HTTP daemon modes)
- Structured query patterns and best practices
- systemd/launchd service configs for daemon persistence

Placed in optional-skills/ due to heavyweight requirements
(Node >= 22, ~2GB local models).
2026-03-07 20:39:05 -08:00
teknium1
c5a9d1ef9d Merge branch 'main' into pr-635 2026-03-07 20:36:42 -08:00
teknium1
c7b6f423c7 feat: auto-compress pathologically large gateway sessions (#628)
Long-lived gateway sessions can accumulate enough history that every new
message rehydrates an oversized transcript, causing repeated truncation
failures (finish_reason=length).

Add a session hygiene check in _handle_message that runs right after
loading the transcript and before invoking the agent:

1. Estimate message count and rough token count of the transcript
2. If above configurable thresholds (default: 200 msgs or 100K tokens),
   auto-compress the transcript proactively
3. Notify the user about the compression with before/after stats
4. If still above warn threshold (default: 200K tokens) after
   compression, suggest /reset
5. If compression fails on a dangerously large session, warn the user
   to use /compress or /reset manually

Thresholds are configurable via config.yaml:

  session_hygiene:
    auto_compress_tokens: 100000
    auto_compress_messages: 200
    warn_tokens: 200000

This complements the agent's existing preflight compression (which
runs inside run_conversation) by catching pathological sessions at
the gateway layer before the agent is even created.

Includes 12 tests for threshold detection and token estimation.
2026-03-07 20:09:48 -08:00
teknium1
6d34207167 Merge PR #620: fix: restore missing MIT license file
Authored by stablegenius49. Fixes #619.
2026-03-07 20:00:33 -08:00
Bryan Young
fcde9be10d fix: keep tool-call output runs intact during compression 2026-03-08 03:13:14 +00:00
teknium1
3830bbda41 fix: include url in web_extract trimmed results & fix docs
The web_extract_tool was stripping the 'url' key during its output
trimming step, but documentation in 3 places claimed it was present.
This caused KeyError when accessing result['url'] in execute_code
scripts, especially when extracting from multiple URLs.

Changes:
- web_tools.py: Add 'url' back to trimmed_results output
- code_execution_tool.py: Add 'title' to _TOOL_STUBS docstring and
  _TOOL_DOC_LINES so docs match actual {url, title, content, error}
  response format
2026-03-07 18:07:36 -08:00
Christo Mitov
4447e7d71a fix: add Kimi Code API support (api.kimi.com/coding/v1)
Kimi Code (platform.kimi.ai) issues API keys prefixed sk-kimi- that require:
1. A different base URL: api.kimi.com/coding/v1 (not api.moonshot.ai/v1)
2. A User-Agent header identifying a recognized coding agent

Without this fix, sk-kimi- keys fail with 401 (wrong endpoint) or 403
('only available for Coding Agents') errors.

Changes:
- Auto-detect sk-kimi- key prefix and route to api.kimi.com/coding/v1
- Send User-Agent: KimiCLI/1.0 header for Kimi Code endpoints
- Legacy Moonshot keys (api.moonshot.ai) continue to work unchanged
- KIMI_BASE_URL env var override still takes priority over auto-detection
- Updated .env.example with correct docs and all endpoint options
- Fixed doctor.py health check for Kimi Code keys

Reference: https://github.com/MoonshotAI/kimi-cli (platforms.py)
2026-03-07 21:00:12 -05:00
teknium1
7bccd904c7 Merge PR #629: feat: add Polymarket prediction market skill (read-only)
Adds market-data/polymarket skill — read-only access to Polymarket's public
prediction market APIs. Zero dependencies, zero auth required.
Addresses #589.
2026-03-07 17:28:03 -08:00
teknium1
313d522b61 feat: add Polymarket prediction market skill (read-only)
Adds a new market-data/polymarket skill for querying Polymarket's public
prediction market APIs. Pure read-only, zero authentication required,
zero external dependencies (stdlib only).

Includes:
- SKILL.md: Agent instructions with key concepts and workflow
- references/api-endpoints.md: Full API reference (Gamma, CLOB, Data APIs)
- scripts/polymarket.py: CLI helper for search, trending, prices, orderbooks,
  price history, and recent trades

Addresses #589.
2026-03-07 17:27:29 -08:00
teknium1
9ee4fe41fe Fix image_generate 'Event loop is closed' in gateway
Root cause: fal_client.AsyncClient uses @cached_property for its
httpx.AsyncClient, creating it once and caching forever. In the gateway,
the agent runs in a thread pool where _run_async() calls asyncio.run()
which creates a temporary event loop. The first call works, but
asyncio.run() closes that loop. On the next call, a new loop is created
but the cached httpx.AsyncClient still references the old closed loop,
causing 'Event loop is closed'.

Fix: Switch from async fal_client API (submit_async/handler.get with
await) to sync API (submit/handler.get). The sync API uses httpx.Client
which has no event loop dependency. Since the tool already runs in a
thread pool via the gateway, async adds no benefit here.

Changes:
- image_generate_tool: async def -> def
- _upscale_image: async def -> def
- fal_client.submit_async -> fal_client.submit
- await handler.get() -> handler.get()
- is_async=True -> is_async=False in registry
- Remove unused asyncio import
2026-03-07 16:56:49 -08:00
teknium1
39ee3512cb Merge PR #614: fix: resolve systemd restart loop with --replace flag
Authored by voidborne-d. Fixes #576.

Adds --replace flag to 'hermes gateway run' that terminates any existing
gateway instance (SIGTERM with SIGKILL fallback) before starting.
Updated systemd unit template with --replace, ExecStop, KillMode, and
TimeoutStopSec for robust service management.
2026-03-07 16:33:27 -08:00
teknium1
42673556af Merge PR #575: fix(setup): prevent OpenRouter model list fallback for Nous provider
Authored by PercyDikec. Fixes #574.

# Conflicts:
#	hermes_cli/setup.py
2026-03-07 16:22:13 -08:00
teknium1
faab73ad58 Merge PR #573: fix(doctor): detect OpenAI custom endpoint env settings
Authored by stablegenius49. Fixes #572.
2026-03-07 16:16:08 -08:00
teknium1
7e36468511 fix: /clear command broken inside TUI (patch_stdout interference)
The /clear command was using Rich's console.clear() and console.print()
which write directly to stdout. Inside the TUI, prompt_toolkit's
patch_stdout intercepts stdout via StdoutProxy, which doesn't interpret
screen-clearing escape sequences and mangles Rich's ANSI output,
resulting in raw escape codes dumped to the terminal.

Fix:
- Use prompt_toolkit's output.erase_screen() + cursor_goto() to clear
  the terminal directly (bypasses patch_stdout's StdoutProxy)
- Render the banner through ChatConsole (which routes Rich output
  through prompt_toolkit's native print_formatted_text/ANSI renderer)
- Use _cprint for the status message (prompt_toolkit-compatible)
- Fall back to the old behavior when not inside the TUI (e.g. startup)
2026-03-07 16:09:23 -08:00
stablegenius49
9ba5d399e5 fix: restore missing MIT license file 2026-03-07 13:43:08 -08:00
teknium1
306d92a9d7 refactor(context_compressor): improve summary generation logic and error handling
Updated the _generate_summary method to attempt summary generation using the auxiliary model first, with a fallback to the main model. If both attempts fail, the method now returns None instead of a placeholder, allowing the caller to handle missing summaries appropriately. This change enhances the robustness of context compression and improves logging for failure scenarios.
2026-03-07 11:54:51 -08:00
teknium1
5baae0df88 feat(scheduler): enhance job configuration with reasoning effort, prefill messages, and provider routing
Added support for loading reasoning configuration, prefill messages, and provider routing from environment variables or config.yaml in the run_job function. This improves flexibility and customization for job execution, allowing for better control over agent behavior and message handling.
2026-03-07 11:37:16 -08:00
teknium1
24f6a193e7 fix: remove stale 'model' assertion from delegate_task schema test
The 'model' property was removed from DELEGATE_TASK_SCHEMA but the
test still asserted its presence, causing CI to fail.
2026-03-07 11:29:55 -08:00
teknium1
8c0f8baf32 feat(delegate_tool): add additional parameters for child agent configuration
Enhanced the _run_single_child function by introducing max_tokens, reasoning_config, and prefill_messages parameters from the parent agent. This allows for more flexible configuration of child agents, improving their operational capabilities.
2026-03-07 11:29:17 -08:00
teknium1
d80c30cc92 feat(gateway): proactive async memory flush on session expiry
Previously, when a session expired (idle/daily reset), the memory flush
ran synchronously inside get_or_create_session — blocking the user's
message for 10-60s while an LLM call saved memories.

Now a background watcher task (_session_expiry_watcher) runs every 5 min,
detects expired sessions, and flushes memories proactively in a thread
pool.  By the time the user sends their next message, memories are
already saved and the response is immediate.

Changes:
- Add _is_session_expired(entry) to SessionStore — works from entry
  alone without needing a SessionSource
- Add _pre_flushed_sessions set to track already-flushed sessions
- Remove sync _on_auto_reset callback from get_or_create_session
- Refactor flush into _flush_memories_for_session (sync worker) +
  _async_flush_memories (thread pool wrapper)
- Add _session_expiry_watcher background task, started in start()
- Simplify /reset command to use shared fire-and-forget flush
- Add 10 tests for expiry detection, callback removal, tracking
2026-03-07 11:27:50 -08:00
teknium1
e64d646bad Critical: fix bug in new subagent tool call budget to not be session-level but tool call loop level 2026-03-07 10:32:51 -08:00
teknium1
b84f9e410c feat: default reasoning effort from xhigh to medium
Reduces token usage and latency for most tasks by defaulting to
medium reasoning effort instead of xhigh. Users can still override
via config or CLI flag. Updates code, tests, example config, and docs.
2026-03-07 10:14:19 -08:00
d 🔹
ee5daba061 fix: resolve systemd restart loop with --replace flag (#576)
When running under systemd, the gateway could enter restart loops in two
scenarios:

1. The previous gateway process hasn't fully exited when systemd starts
   a new one, causing 'Gateway already running (PID ...)' → exit 1 →
   restart → same error → infinite loop.

2. The interactive CLI exits immediately in non-TTY mode, and systemd
   keeps restarting it.

Changes:

- Add --replace flag to 'hermes gateway run' that gracefully kills any
  existing gateway instance (SIGTERM → wait 10s → SIGKILL) before
  starting, preventing the PID-lock deadlock.

- Update the generated systemd unit template to use --replace by default,
  add ExecStop for clean shutdown, set KillMode=mixed and
  TimeoutStopSec=15 for proper process management.

- Existing behavior (without --replace) is unchanged: still prints the
  error message and exits, now also mentioning the --replace option.

Fixes #576
2026-03-07 18:08:12 +00:00
teknium1
23e84de830 refactor: remove model parameter from AIAgent initialization
Eliminated the model parameter from the AIAgent class initialization, streamlining the constructor and ensuring consistent behavior across agent instances. This change aligns with recent updates to the task delegation logic.
2026-03-07 09:48:19 -08:00
teknium1
48e0dc8791 feat: implement Z.AI endpoint detection for API key validation
Added functionality to detect the appropriate Z.AI endpoint based on the provided API key, accommodating different billing plans and regions. The setup process now probes available endpoints and updates the configuration accordingly, enhancing user experience and reducing potential billing errors. Updated the setup model provider function to integrate this new detection logic.
2026-03-07 09:43:37 -08:00
teknium1
fb0f579b16 refactor: remove model parameter from delegate_task function
Eliminated the model parameter from the delegate_task function and its associated schema, defaulting to None for subagent calls. This change simplifies the function signature and enforces consistent behavior across task delegation.
2026-03-07 09:20:27 -08:00
teknium1
5a711f32b1 fix: enhance payload and context compression handling
Added logic to manage multiple compression attempts for large payloads and context length errors. Introduced limits on compression attempts to prevent infinite retries, with appropriate logging and error handling. This ensures better resilience and user feedback when facing compression issues during API calls.
2026-03-07 09:19:07 -08:00
teknium1
4d34427cc7 fix: update model version in agent configurations
Updated the default model version from "anthropic/claude-sonnet-4-20250514" to "anthropic/claude-sonnet-4.6" across multiple files including AGENTS.md, batch_runner.py, mini_swe_runner.py, and run_agent.py for consistency and to reflect the latest model improvements.
2026-03-07 09:06:37 -08:00
teknium1
41877183bc Merge PR #604: fix(tests): isolate max_turns tests from CI env and update default to 90
Authored by 0xbyt4. Fixes test assertions broken by 0a82396 (60→90 default).
2026-03-07 08:57:36 -08:00
0xbyt4
451a007fb1 fix(tests): isolate max_turns tests from CI env and update default to 90
_make_cli() did not clear HERMES_MAX_ITERATIONS env var, so tests
failed in CI where the var was set externally. Also, default max_turns
changed from 60 to 90 in 0a82396 but tests were not updated.

- Clear HERMES_MAX_ITERATIONS in _make_cli() for proper isolation
- Add env_overrides parameter for tests that need specific env values
- Update hardcoded 60 assertions to 90 to match new default
- Simplify test_env_var_max_turns using env_overrides
2026-03-07 19:43:20 +03:00
teknium1
0a82396718 feat: shared iteration budget across parent + subagents
Subagent tool calls now count toward the same session-wide iteration
limit as the parent agent. Previously, each subagent had its own
independent counter, so a parent with max_iterations=60 could spawn
3 subagents each doing 50 calls = 150 total tool calls unmetered.

Changes:
- IterationBudget: thread-safe shared counter (run_agent.py)
  - consume(): try to use one iteration, returns False if exhausted
  - refund(): give back one iteration (for execute_code turns)
  - Thread-safe via Lock (subagents run in ThreadPoolExecutor)
- Parent creates the budget, children inherit it via delegate_tool.py
- execute_code turns are refunded (don't count against budget)
- Default raised from 60 → 90 to account for shared consumption
- Per-child cap (50) still applies as a safety valve

The per-child max_iterations (default 50) remains as a per-child
ceiling, but the shared budget is the hard session-wide limit.
A child stops at whichever comes first.
2026-03-07 08:16:37 -08:00
teknium1
5da55ea1e3 fix: sanitize orphaned tool-call/result pairs in message compression
Enhance message compression by adding a method to clean up orphaned tool-call and tool-result pairs. This ensures that the API receives well-formed messages, preventing errors related to mismatched IDs. The new functionality includes removing orphaned results and adding stub results for missing calls, improving overall message integrity during compression.
2026-03-07 08:08:00 -08:00
teknium1
064c009deb feat: show update-available notice in CLI banner
Check how many commits behind origin/main the local repo is and
display a warning in the welcome banner:

  ⚠ 12 commits behind — run hermes update to update

- git fetch cached for 6 hours (avoids repeated network calls)
- Falls back gracefully if offline or not a git repo
- Never breaks the banner — all errors silently caught
2026-03-07 07:35:36 -08:00
teknium1
caab1cf453 fix: update setup/config UI for local browser mode
- tools_config.py: Add 'Local Browser' as first provider option
  (no API keys needed, same npm install for agent-browser)
- setup.py: Show 'Browser Automation (local)' when agent-browser
  CLI is found but no Browserbase key is set
- config.py: Mark BROWSERBASE_* descriptions as optional
- status.py: Note that local browser works without Browserbase
2026-03-07 01:23:27 -08:00
teknium1
55c70f3508 fix: strip MarkdownV2 escapes from Telegram plaintext fallback
When Telegram's MarkdownV2 parser rejects a message, the send() fallback
was sending the already-escaped text as plain text. This caused users to
see raw backslashes before every special character (periods, dashes,
parentheses, etc.) — e.g. 'sentence\.' or '\-\-auto\-approve'.

Changes:
- Add _strip_mdv2() to reverse MarkdownV2 escaping for clean plaintext
- Use stripped text in the send() fallback path instead of raw escaped chunk
- Add logging when the MDV2 fallback is triggered for diagnostics
- Add logger to telegram.py (was missing)

The edit_message() fallback already correctly used the original content;
this brings send() in line with that behavior.
2026-03-07 01:23:18 -08:00
teknium1
d29249b8fa feat: local browser backend — zero-cost headless Chromium via agent-browser
Add local browser mode as an automatic fallback when Browserbase
credentials are not configured. Uses the same agent-browser CLI with
--session (local Chromium) instead of --cdp (cloud Browserbase).

The agent-facing API is completely unchanged — all 10 browser_* tools
produce identical output in both modes. Auto-detection:
  - BROWSERBASE_API_KEY set → cloud mode (existing behavior)
  - No key → local mode (new, free, headless Chromium)

Changes:
- _is_local_mode(): auto-detect based on env vars
- _create_local_session(): lightweight session (no API call)
- _get_session_info(): branches on local vs cloud
- _run_browser_command(): --session in local, --cdp in cloud
- check_browser_requirements(): only needs agent-browser CLI in local mode
- _emergency_cleanup: CLI close in local, API release in cloud
- cleanup_browser/browser_close: skip BB API calls in local mode
- Registry: removed requires_env — check_fn handles both modes

Setup for local mode:
  npm install -g agent-browser
  agent-browser install              # downloads Chromium
  agent-browser install --with-deps  # also installs system libs (Docker/Debian)

Closes #374 (Phase 1)
2026-03-07 01:14:57 -08:00
teknium1
f668e9fc75 feat: platform-conditional skill loading + Apple/macOS skills
Add a 'platforms' field to SKILL.md frontmatter that restricts skills
to specific operating systems. Skills with platforms: [macos] only
appear in the system prompt, skills_list(), and slash commands on macOS.
Skills without the field load everywhere (backward compatible).

Implementation:
- skill_matches_platform() in tools/skills_tool.py — core filter
- Wired into all 3 discovery paths: prompt_builder.py, skills_tool.py,
  skill_commands.py
- 28 new tests across 3 test files

New bundled Apple/macOS skills (all platforms: [macos]):
- imessage — Send/receive iMessages via imsg CLI
- apple-reminders — Manage Reminders via remindctl CLI
- apple-notes — Manage Notes via memo CLI
- findmy — Track devices/AirTags via AppleScript + screen capture

Docs updated: CONTRIBUTING.md, AGENTS.md, creating-skills.md,
skills.md (user guide)
2026-03-07 00:47:54 -08:00
teknium1
74fe1e2254 chore: remove TODO.md — all items tracked as issues
All remaining TODO items have covering issues:
- Local Browser via CDP: #374, #493
- Signal Integration: #405
- Plugin/Extension System: #359
- MCP Client Improvements: #581 (new)
- Filesystem Checkpointing: #452

Completed items (MCP core support) already shipped in PR #301.
2026-03-07 00:07:14 -08:00
teknium1
348936752a fix: simplify timezone migration to use os.getenv directly
The previous 'get_env_value' in dir() check always evaluated to False
(dir() returns local scope, not module scope), making the left branch
dead code. Simplified to just os.getenv() which was the fallback anyway.
2026-03-07 00:05:05 -08:00
teknium1
69a36a3361 Merge PR #309: fix(timezone): timezone-aware now() for prompt, cron, and execute_code
Authored by areu01or00. Adds timezone support via hermes_time.now() helper
with IANA timezone resolution (HERMES_TIMEZONE env → config.yaml → server-local).
Updates system prompt timestamp, cron scheduling, and execute_code sandbox TZ
injection. Includes config migration (v4→v5) and comprehensive test coverage.
2026-03-07 00:04:41 -08:00
Teknium
8712dd6d1c Merge pull request #308 from batuhankocyigit/patch-2
fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
2026-03-06 23:43:09 -08:00
teknium1
55a21fe37b docs: add Environments, Benchmarks & Data Generation guide
Comprehensive developer guide covering:
- Architecture (BaseEnv → HermesAgentBaseEnv → concrete envs)
- All three benchmarks (TerminalBench2, TBLite, YC-Bench)
- Training environments (TerminalTestEnv, HermesSweEnv)
- Core components (AgentLoop, ToolContext, Tool Call Parsers)
- Two-phase operation (Phase 1 OpenAI, Phase 2 VLLM)
- Running environments (evaluate, process, serve modes)
- Creating new environments (training + eval-only)
- Configuration reference and prerequisites

Also updates environments/README.md directory tree to include
TBLite and YC-Bench benchmarks.
2026-03-06 23:31:45 -08:00
teknium1
f55f625277 chore: reorder terminal backends in setup wizard
Local, Docker, Modal, SSH, Daytona, Singularity (Linux-only, last).
2026-03-06 22:21:57 -08:00
teknium1
9dac85b069 fix: uv pip install fails outside venv in setup wizard
uv pip install requires a virtual environment by default. When hermes
is installed system-wide or via pipx, the setup wizard's SDK installs
(daytona, swe-rex[modal], tinker-atropos) fail with 'No virtual
environment found'. Fix by passing --python sys.executable to uv,
which targets the correct Python regardless of venv state.

Also show the actual error message on install failure so users can
debug.
2026-03-06 21:55:33 -08:00
teknium1
99bd69baa8 Merge feat/modular-setup-wizard: modular setup wizard with section subcommands and tool-first UX
- 5 standalone sections: hermes setup [model|terminal|gateway|tools|agent]
- Returning user menu with section shortcuts
- Tool-first UX: category -> provider -> API key flow
- Unified hermes tools / hermes setup tools
- Fixed dict-format model config display bug

Closes #567
2026-03-06 21:12:30 -08:00
teknium1
a62a137a4f fix: handle dict-format model config in setup wizard display
config['model'] can be a dict (old format: {default, base_url, provider})
or a string (new format). The setup wizard was showing the raw dict in
'Keep current' and 'Model set to' messages. Now extracts the model name
from either format.
2026-03-06 21:11:40 -08:00
teknium1
82b18e8ac2 feat: unify hermes tools and hermes setup tools into single flow
Both 'hermes tools' and 'hermes setup tools' now use the same unified
flow in tools_config.py:

1. Select platform (CLI, Telegram, Discord, etc.)
2. Toggle all 18 toolsets on/off in checklist
3. Newly enabled tools that need API keys → provider-aware config
   (e.g., TTS shows Edge/OpenAI/ElevenLabs picker)
4. Already-configured tools that stay enabled → silent, no prompts
5. Menu option: 'Reconfigure an existing tool' for updating
   providers or API keys on tools that are already set up

Key changes:
- Move TOOL_CATEGORIES, provider config, and post-setup hooks from
  setup.py to tools_config.py
- Replace flat _check_and_prompt_requirements() with provider-aware
  _configure_toolset() that uses TOOL_CATEGORIES
- Add _reconfigure_tool() flow for updating existing configs
- setup.py's setup_tools() now delegates to tools_command()
- tools_command() menu adds 'Reconfigure' option alongside platforms
- Only prompt for API keys on tools that are NEWLY toggled on AND
  don't already have keys configured

No breaking changes. All 2013 tests pass.
2026-03-06 21:02:00 -08:00
teknium1
0111c9848d fix: remove ANSI codes and em dashes from menu labels
simple_term_menu miscalculates string widths when labels contain
ANSI escape codes (from color()) or em dashes, causing duplicated
and garbled lines on arrow key navigation.

Replace color() status indicators with plain text [configured]/[active]
and em dashes with regular dashes in all prompt_choice/prompt_checklist
labels.
2026-03-06 21:02:00 -08:00
teknium1
ab9cadfeee feat: modular setup wizard with section subcommands and tool-first UX
Restructure the monolithic hermes setup wizard into independently-runnable
sections with a category-first tool configuration experience.

Changes:
- Break setup into 5 sections: model, terminal, gateway, tools, agent
- Each section is a standalone function, runnable individually via
  'hermes setup model', 'hermes setup terminal', etc.
- Returning users get a menu: Quick Setup / Full Setup / individual sections
- First-time users get a guided walkthrough of all sections

Tool Configuration UX overhaul:
- Replace flat API key checklist with category-first approach
- Show tool types (TTS, Web Search, Image Gen, etc.) as top-level items
- Within each category, let users pick a provider:
  - TTS: Microsoft Edge (Free), OpenAI, ElevenLabs
  - Web: Firecrawl Cloud, Firecrawl Self-Hosted
  - Image Gen: FAL.ai
  - Browser: Browserbase
  - Smart Home: Home Assistant
  - RL Training: Tinker/Atropos
  - GitHub: Personal Access Token
- Shows configured status on each tool and provider
- Only prompts for API keys after provider selection

Also:
- Add section argument to setup argparse parser in main.py
- Update summary to show new section commands
- Add self-hosted Firecrawl and Home Assistant to tool setup
- All 2013 tests pass
2026-03-06 21:02:00 -08:00
PercyDikec
8bf28e1441 fix(setup): prevent OpenRouter model list fallback for Nous provider
When `fetch_nous_models()` fails silently during setup, the model
selection falls through to the OpenRouter static list. Users then pick
models in OpenRouter format (e.g. `anthropic/claude-opus-4.6`) which
the Nous inference API rejects with a 400 "missing model" error.

Add an explicit `elif selected_provider == "nous"` branch that prompts
for manual model entry instead of falling through to the generic
OpenRouter fallback.
2026-03-07 07:16:22 +03:00
teknium1
ce28f847ce fix: update OpenRouter model names for yc-bench config
Use anthropic/claude-sonnet-4.6 (OpenRouter format) instead of
anthropic/claude-sonnet-4-20250514 (direct API format).
2026-03-06 19:58:56 -08:00
stablegenius49
5609117882 fix(doctor): recognize OPENAI_API_KEY custom endpoint config 2026-03-06 19:47:09 -08:00
teknium1
b4fbb6fe10 feat: add YC-Bench long-horizon agent benchmark environment
Adds eval-only benchmark for YC-Bench (collinear-ai/yc-bench), a
deterministic long-horizon benchmark where the agent acts as CEO of an
AI startup over a simulated 1-3 year run.

Key design decisions verified against the official yc-bench repo:
- Uses 'sim init' (NOT 'yc-bench run') to avoid starting a competing
  built-in agent loop
- Correct DB table names: 'companies' and 'sim_events'
- Correct 4 domains: research, inference, data_environment, training
- Penalty values are preset-dependent (not hardcoded in system prompt)
- Sequential evaluation (each run is 100-500 turns)
- Follows TerminalBench2 patterns: KeyboardInterrupt handling,
  cleanup_all_environments(), tqdm logging handler, streaming JSONL

yc-bench added as optional dependency: pip install hermes-agent[yc-bench]

Closes #340
2026-03-06 19:25:56 -08:00
teknium1
82d7e9429e chore: add GLM/Kimi/MiniMax models to insights pricing (zero cost)
These direct providers don't return cost in API responses and their
per-token pricing isn't readily available externally. Treat as local
models with zero cost so they appear in /insights without fake estimates.
2026-03-06 19:12:14 -08:00
teknium1
e2821effb5 feat: add direct API-key providers as auxiliary client fallbacks
When the user only has a z.ai/Kimi/MiniMax API key (no OpenRouter key),
auxiliary tasks (context compression, web summarization, session search)
now fall back to the configured direct provider instead of returning None.

Resolution chain: OpenRouter -> Nous -> Custom endpoint -> Codex OAuth
-> direct API-key providers -> None.

Uses cheap/fast models for auxiliary tasks:
- zai: glm-4.5-flash
- kimi-coding: kimi-k2-turbo-preview
- minimax/minimax-cn: MiniMax-M2.5-highspeed

Vision auxiliary intentionally NOT modified — vision needs multimodal
models (Gemini) that these providers don't serve.
2026-03-06 19:08:54 -08:00
teknium1
9742f11fda chore: add context lengths for Kimi and MiniMax models
Adds DEFAULT_CONTEXT_LENGTHS entries for kimi-k2.5 (262144), kimi-k2-thinking
(262144), kimi-k2-turbo-preview (262144), kimi-k2-0905-preview (131072),
MiniMax-M2.5/M2.5-highspeed/M2.1 (204800), and glm-4.5/4.5-flash (131072).

Avoids unnecessary 2M-token probe on first use with direct providers.
2026-03-06 19:01:38 -08:00
teknium1
388dd4789c feat: add z.ai/GLM, Kimi/Moonshot, MiniMax as first-class providers
Adds 4 new direct API-key providers (zai, kimi-coding, minimax, minimax-cn)
to the inference provider system. All use standard OpenAI-compatible
chat/completions endpoints with Bearer token auth.

Core changes:
- auth.py: Extended ProviderConfig with api_key_env_vars and base_url_env_var
  fields. Added providers to PROVIDER_REGISTRY. Added provider aliases
  (glm, z-ai, zhipu, kimi, moonshot). Added auto-detection of API-key
  providers in resolve_provider(). Added resolve_api_key_provider_credentials()
  and get_api_key_provider_status() helpers.
- runtime_provider.py: Added generic API-key provider branch in
  resolve_runtime_provider() — any provider with auth_type='api_key'
  is automatically handled.
- main.py: Added providers to hermes model menu with generic
  _model_flow_api_key_provider() flow. Updated _has_any_provider_configured()
  to check all provider env vars. Updated argparse --provider choices.
- setup.py: Added providers to setup wizard with API key prompts and
  curated model lists.
- config.py: Added env vars (GLM_API_KEY, KIMI_API_KEY, MINIMAX_API_KEY,
  etc.) to OPTIONAL_ENV_VARS.
- status.py: Added API key display and provider status section.
- doctor.py: Added connectivity checks for each provider endpoint.
- cli.py: Updated provider docstrings.

Docs: Updated README.md, .env.example, cli-config.yaml.example,
cli-commands.md, environment-variables.md, configuration.md.

Tests: 50 new tests covering registry, aliases, resolution, auto-detection,
credential resolution, and runtime provider dispatch.

Inspired by PR #33 (numman-ali) which proposed a provider registry approach.
Credit to tars90percent (PR #473) and manuelschipper (PR #420) for related
provider improvements merged earlier in this changeset.
2026-03-06 18:55:18 -08:00
Teknium
fdebca4573 Merge pull request #571 from NousResearch/rewbs/nous-key-remint-attempt-on-401
fix: implement Nous credential refresh on 401 error for retry logic
2026-03-06 18:52:01 -08:00
teknium1
479dfc096a Merge PR #473: Update model id in OpenRouter from minimax-m2.1 to minimax-m2.5
Authored by tars90percent. Updates remaining minimax-m2.1 references to
minimax-m2.5 in rl_training_tool.py and docs.
2026-03-06 18:43:18 -08:00
teknium1
3c6c11b7c9 Merge PR #420: fix: respect OPENAI_BASE_URL when resolving API key priority
Authored by manuelschipper. Adds GLM-4.7 and GLM-5 context lengths (202752)
to model_metadata.py. The key priority fix (prefer OPENAI_API_KEY for
non-OpenRouter endpoints) was already applied in PR #295; merged the Z.ai
mention into the comment.
2026-03-06 18:43:13 -08:00
Robin Fernandes
bc091eb7ef fix: implement Nous credential refresh on 401 error for retry logic 2026-03-07 13:34:23 +11:00
teknium1
f75b1d21b4 fix: execute_code and delegate_task now respect disabled toolsets
When a user disables the web toolset via 'hermes tools', the execute_code
schema description still hardcoded web_search/web_extract as available,
causing the model to keep trying to use them. Similarly, delegate_task
always defaulted to ['terminal', 'file', 'web'] for subagents regardless
of the parent's config.

Changes:
- execute_code schema is now built dynamically via build_execute_code_schema()
  based on which sandbox tools are actually enabled
- model_tools.py rebuilds the execute_code schema at definition time using
  the intersection of sandbox-allowed and session-enabled tools
- delegate_task now inherits the parent agent's enabled_toolsets instead of
  hardcoding DEFAULT_TOOLSETS when no explicit toolsets are specified
- delegate_task description updated to say 'inherits your enabled toolsets'

Reported by kotyKD on Discord.
2026-03-06 17:36:14 -08:00
teknium1
94053d75a6 fix: custom endpoint no longer leaks OPENROUTER_API_KEY (#560)
API key selection is now base_url-aware: when the resolved base_url
targets OpenRouter, OPENROUTER_API_KEY takes priority (preserving the
#289 fix). When hitting any other endpoint (Z.ai, vLLM, custom, etc.),
OPENAI_API_KEY takes priority so the OpenRouter key doesn't leak.

Applied in both the runtime provider resolver (the real code path) and
the CLI initial default (for consistency).

Fixes #560.
2026-03-06 17:16:14 -08:00
teknium1
2a68099675 fix(tests): isolate tests from user ~/.hermes/ config and SOUL.md
_make_cli() now patches CLI_CONFIG with clean defaults so
test_cli_init tests don't depend on the developer's local config.yaml.
test_empty_dir_returns_empty now mocks Path.home() so it doesn't pick
up a global SOUL.md.

Credit to teyrebaz33 for identifying and fixing these in PR #557.
Fixes #555.
2026-03-06 17:10:35 -08:00
teknium1
6cd3bc6640 Merge PR #563: fix: prevent data loss in skills sync on copy/update failure
Authored by 0xbyt4. Two bugs fixed:
1. Failed copytree no longer poisons the manifest (skill gets retried)
2. Failed update no longer destroys user's copy (backup + restore)
2026-03-06 17:01:30 -08:00
0xbyt4
211b55815e fix: prevent data loss in skills sync on copy/update failure
Two bugs in sync_skills():

1. Failed copytree poisons manifest: when shutil.copytree fails (disk
   full, permission error), the skill is still recorded in the manifest.
   On the next sync, the skill appears as "in manifest but not on disk"
   which is interpreted as "user deliberately deleted it" — the skill
   is never retried.  Fix: only write to manifest on successful copy.

2. Failed update destroys user copy: rmtree deletes the existing skill
   directory before copytree runs. If copytree then fails, the user's
   skill is gone with no way to recover.  Fix: move to .bak before
   copying, restore from backup if copytree fails.

Both bugs are proven by new regression tests that fail on the old code
and pass on the fix.
2026-03-07 03:58:32 +03:00
teknium1
8ae4a6f824 fix: improve handling of empty responses after tool calls
- Added fallback mechanism to utilize previous content when the model generates an empty response after tool calls, reducing unnecessary API retries.
- Enhanced logging to indicate when prior content is used as a final response.
- Updated logic to ensure that genuine empty responses are retried appropriately, maintaining user experience.
2026-03-06 16:54:31 -08:00
teknium1
b98301677a docs: add /insights to all help menus and documentation
- website/docs/reference/cli-commands.md: Added 'hermes insights' terminal
  command section with --days and --source flags, plus /insights slash command
  in the Conversation section
- website/docs/user-guide/cli.md: Added /insights to slash commands table
- website/docs/user-guide/messaging/index.md: Added /insights to gateway
  chat commands table
- website/docs/user-guide/sessions.md: Added cross-reference to hermes
  insights from the sessions stats section
2026-03-06 16:48:58 -08:00
teknium1
f2fdde5ba4 fix: show user-modified skills count in hermes update output 2026-03-06 16:14:43 -08:00
teknium1
4f56e31dc7 fix: track origin hashes in skills manifest to preserve user modifications
Upgrade skills_sync manifest to v2 format (name:origin_hash). The origin
hash records the MD5 of the bundled skill at the time it was last synced.

On update, the user's copy is compared against the origin hash:
- User copy == origin hash → unmodified → safe to update from bundled
- User copy != origin hash → user customized → skip (preserve changes)

v1 manifests (plain names) are auto-migrated: the user's current hash
becomes the baseline, so future syncs can detect modifications.

Output now shows user-modified skills:
  ~ whisper (user-modified, skipping)

27 tests covering all scenarios including v1→v2 migration, user
modification detection, update after migration, and origin hash tracking.
2009 tests pass.
2026-03-06 16:13:58 -08:00
Teknium
6d3804770c Merge pull request #552 from NousResearch/feat/insights
feat: /insights command — usage analytics, cost estimation & activity patterns
2026-03-06 16:00:28 -08:00
teknium1
ab0f4126cf fix: restore all removed bundled skills + fix skills sync system
- Restored 21 skills removed in commits 757d012 and 740dd92:
  accelerate, audiocraft, code-review, faiss, flash-attention, gguf,
  grpo-rl-training, guidance, llava, nemo-curator, obliteratus, peft,
  pytorch-fsdp, pytorch-lightning, simpo, slime, stable-diffusion,
  tensorrt-llm, torchtitan, trl-fine-tuning, whisper

- Rewrote sync_skills() with proper update semantics:
  * New skills (not in manifest): copied to user dir
  * Existing skills (in manifest + on disk): updated via hash comparison
  * User-deleted skills (in manifest, not on disk): respected, not re-added
  * Stale manifest entries (removed from bundled): cleaned from manifest

- Added sync_skills() to CLI startup (cmd_chat) and gateway startup
  (start_gateway) — previously only ran during 'hermes update'

- Updated cmd_update output to show new/updated/cleaned counts

- Rewrote tests: 20 tests covering manifest CRUD, dir hashing, fresh
  install, user deletion respect, update detection, stale cleanup, and
  name collision handling

75 bundled skills total. 2002 tests pass.
2026-03-06 15:57:30 -08:00
teknium1
585f8528b2 fix: deep review — prefix matching, tool_calls extraction, query perf, serialization
Issues found and fixed during deep code path review:

1. CRITICAL: Prefix matching returned wrong prices for dated model names
   - 'gpt-4o-mini-2024-07-18' matched gpt-4o ($2.50) instead of gpt-4o-mini ($0.15)
   - Same for o3-mini→o3 (9x), gpt-4.1-mini→gpt-4.1 (5x), gpt-4.1-nano→gpt-4.1 (20x)
   - Fix: use longest-match-wins strategy instead of first-match
   - Removed dangerous key.startswith(bare) reverse matching

2. CRITICAL: Top Tools section was empty for CLI sessions
   - run_agent.py doesn't set tool_name on tool response messages (pre-existing)
   - Insights now also extracts tool names from tool_calls JSON on assistant
     messages, which IS populated for all sessions
   - Uses max() merge strategy to avoid double-counting between sources

3. SELECT * replaced with explicit column list
   - Skips system_prompt and model_config blobs (can be thousands of chars)
   - Reduces memory and I/O for large session counts

4. Sets in overview dict converted to sorted lists
   - models_with_pricing / models_without_pricing were Python sets
   - Sets aren't JSON-serializable — would crash json.dumps()

5. Negative duration guard
   - end > start check prevents negative durations from clock drift

6. Model breakdown sort fallback
   - When all tokens are 0, now sorts by session count instead of arbitrary order

7. Removed unused timedelta import

Added 6 new tests: dated model pricing (4), tool_calls JSON extraction,
JSON serialization safety. Total: 69 tests.
2026-03-06 14:50:57 -08:00
teknium1
75f523f5c0 fix: unknown/custom models get zero cost instead of fake estimates
Custom OAI endpoints, self-hosted models, and local inference should NOT
show fabricated cost estimates. Changed default pricing from $3/$12 per
million tokens to $0/$0 for unrecognized models.

- Added _has_known_pricing() to distinguish commercial vs custom models
- Models with known pricing show $ amounts; unknown models show 'N/A'
- Overview shows asterisk + note when some models lack pricing data
- Gateway format adds '(excludes custom/self-hosted models)' note
- Added 7 new tests for custom model cost handling
2026-03-06 14:18:19 -08:00
teknium1
68fbae5692 docs: add Custom & Self-Hosted LLM Providers guide
Comprehensive guide for using Hermes Agent with alternative LLM backends:
- Ollama (local models, zero config)
- vLLM (high-performance GPU inference)
- SGLang (RadixAttention, prefix caching)
- llama.cpp / llama-server (CPU & Metal inference)
- LiteLLM Proxy (multi-provider gateway)
- ClawRouter (cost-optimized routing with complexity scoring)
- 10+ other compatible providers table (Together, Groq, DeepSeek, etc.)
- Choosing the Right Setup decision table
- General custom endpoint setup instructions

All of these work via the existing OPENAI_BASE_URL + OPENAI_API_KEY
custom endpoint support — no code changes needed.
2026-03-06 14:16:06 -08:00
teknium1
80f1dd8d37 docs: add Custom & Self-Hosted LLM Providers guide
Comprehensive guide for using Hermes Agent with alternative LLM backends:
- Ollama (local models, zero config)
- vLLM (high-performance GPU inference)
- SGLang (RadixAttention, prefix caching)
- llama.cpp / llama-server (CPU & Metal inference)
- LiteLLM Proxy (multi-provider gateway)
- ClawRouter (cost-optimized routing with complexity scoring)
- 10+ other compatible providers table (Together, Groq, DeepSeek, etc.)
- Choosing the Right Setup decision table
- General custom endpoint setup instructions

All of these work via the existing OPENAI_BASE_URL + OPENAI_API_KEY
custom endpoint support — no code changes needed.
2026-03-06 14:15:57 -08:00
teknium1
b52b37ae64 feat: add /insights command with usage analytics and cost estimation
Inspired by Claude Code's /insights, adapted for Hermes Agent's multi-platform
architecture. Analyzes session history from state.db to produce comprehensive
usage insights.

Features:
- Overview stats: sessions, messages, tokens, estimated cost, active time
- Model breakdown: per-model sessions, tokens, and cost estimation
- Platform breakdown: CLI vs Telegram vs Discord etc. (unique to Hermes)
- Tool usage ranking: most-used tools with percentages
- Activity patterns: day-of-week chart, peak hours, streaks
- Notable sessions: longest, most messages, most tokens, most tool calls
- Cost estimation: real pricing data for 25+ models (OpenAI, Anthropic,
  DeepSeek, Google, Meta) with fuzzy model name matching
- Configurable time window: --days flag (default 30)
- Source filtering: --source flag to filter by platform

Three entry points:
- /insights slash command in CLI (supports --days and --source flags)
- /insights slash command in gateway (compact markdown format)
- hermes insights CLI subcommand (standalone)

Includes 56 tests covering pricing helpers, format helpers, empty DB,
populated DB with multi-platform data, filtering, formatting, and edge cases.
2026-03-06 14:04:59 -08:00
teknium1
d63b363cde refactor: extract atomic_json_write helper, add 24 checkpoint tests
Extract the duplicated temp-file + fsync + os.replace pattern from
batch_runner.py (1 instance) and process_registry.py (2 instances) into
a shared utils.atomic_json_write() function.

Add 12 tests for atomic_json_write covering: valid JSON, parent dir
creation, overwrite, crash safety (original preserved on error), no temp
file leaks, string paths, unicode, custom indent, concurrent writes.

Add 12 tests for batch_runner checkpoint behavior covering:
_save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock,
parent dirs, no temp leaks), _load_checkpoint (missing file, existing
data, corrupt JSON), and resume logic (preserves prior progress,
different run_name starts fresh).
2026-03-06 05:50:12 -08:00
teknium1
c05c60665e Merge PR #298: Make process_registry checkpoint writes atomic
Authored by aydnOktay. Companion to PR #297 (batch_runner). Applies the
same atomic write pattern (temp file + fsync + os.replace) to both
_write_checkpoint() and recover_from_checkpoint() in process_registry.py.
Prevents checkpoint corruption on gateway crashes. Also improves error
handling: bare 'pass' replaced with logger.debug(..., exc_info=True)
for better debugging.
2026-03-06 05:32:35 -08:00
teknium1
b4873a5de7 fix(setup): Escape skips instead of exiting, add control hints to all prompts
Previously pressing Escape in any setup wizard menu called sys.exit(1),
killing the entire wizard with no way to recover. Now:

- prompt_choice: Escape keeps the current default and moves on (prints
  'Skipped (keeping current)'). Shows '↑/↓ Navigate  Enter Select
  Esc Skip  Ctrl+C Exit' hint.
- prompt_checklist: Escape returns pre-selected items instead of empty
  list. Shows 'SPACE Toggle  ENTER Confirm  ESC Skip  Ctrl+C Exit'.
- prompt_yes_no: now catches KeyboardInterrupt/EOFError properly.
- Fallback number prompts also show control hints.

Ctrl+C still exits the wizard cleanly.
2026-03-06 05:27:11 -08:00
teknium1
913f8ce0a5 Merge PR #297: Make batch_runner checkpoint incremental and atomic
Authored by aydnOktay. Three improvements to batch_runner fault tolerance:
1) Atomic checkpoint writes (temp file + fsync + os.replace) to prevent
   corruption on crashes — same pattern as auth.py's _save_auth_store().
2) Incremental checkpoints after each batch result instead of only at end,
   so interrupted runs can resume with minimal progress loss.
3) Resume loads existing checkpoint state instead of initializing empty,
   preventing clobber of prior progress.

Conflict resolved: kept both the incremental checkpoint logic (PR) and
the batch worker error handling (HEAD) in the imap_unordered loop.
2026-03-06 05:16:31 -08:00
teknium1
4a63737227 Merge PR #433: fix(whatsapp): replace Linux-only fuser with cross-platform port cleanup
Authored by Farukest. Fixes #432. Extracts _kill_port_process() helper
that uses netstat+taskkill on Windows and fuser on Linux. Previously,
fuser calls were inline with bare except-pass, so on Windows orphaned
bridge processes were never cleaned up — causing 'address already in use'
errors on reconnect. Includes 5 tests covering both platforms, port
matching edge cases, and exception suppression.
2026-03-06 04:52:25 -08:00
teknium1
3e93db16bd Merge PR #436: fix: use _max_tokens_param in max-iterations retry path
Authored by Farukest. Fixes #435. The retry summary in
_handle_max_iterations() hardcoded max_tokens instead of using
_max_tokens_param(), which returns max_completion_tokens for direct
OpenAI API (required by gpt-4o, o-series). The first attempt already
used _max_tokens_param correctly — only the retry path was wrong.
Includes 4 tests for _max_tokens_param provider detection.
2026-03-06 04:46:24 -08:00
teknium1
f863a42351 Merge PR #441: fix(gateway): return response from /retry handler instead of discarding it
Authored by PercyDikec. Fixes #440. _handle_retry_command called
_handle_message(retry_event) but discarded the return value, returning
None instead. Since only _process_message_background sends the response
via adapter.send(), this meant the agent would run (tool progress was
visible) but the final answer was silently dropped on all platforms.
2026-03-06 04:42:54 -08:00
teknium1
dc55f493be fix: add missing re.DOTALL to DeepSeek V3.1 parser (same bug as V3)
The V3.1 parser had the same issue — .*? without re.DOTALL fails to
match multi-line JSON arguments. Found during review of PR #444.
2026-03-06 04:41:47 -08:00
teknium1
936fda3f9e Merge PR #444: fix: add missing re.DOTALL flag to DeepSeek V3 tool call parser
Authored by PercyDikec. Fixes #443. Without re.DOTALL, the regex .*
doesn't match newlines, so multi-line JSON arguments (the normal case)
silently fail to parse. Every other parser in the codebase that matches
across lines already uses re.DOTALL.
2026-03-06 04:39:53 -08:00
teknium1
ecb8148a9f Merge PR #448: fix(cli): use correct dict key for codex auth file path in status output
Authored by PercyDikec. Fixes #447. The status display used
codex_status.get('auth_file') but get_codex_auth_status() in auth.py
returns the path under 'auth_store' (line 1220). This one-char key
mismatch silently dropped the auth file path from 'hermes status'.
2026-03-06 04:34:46 -08:00
teknium1
2dbbedc05a docs: rebrand messaging — 'the self-improving AI agent'
- Lead with the learning loop: autonomous skill creation, skill
  self-improvement, memory nudges, FTS5 session search, Honcho
  dialectic user modeling
- 'Runs anywhere' angle: 6 backends, serverless persistence with
  Daytona/Modal, not tied to your laptop
- 'Built by model trainers' replaces 'model-agnostic'
- Updated README tagline, feature table, subtitle
- Updated docs landing page hero, description, key features
- Updated docusaurus tagline and pyproject.toml description
2026-03-06 04:34:06 -08:00
teknium1
c30967806c test: add 26 tests for set_config_value secret routing
Verifies explicit allowlist keys, catch-all _API_KEY/_TOKEN patterns,
case insensitivity, TERMINAL_SSH prefix, and config.yaml routing for
non-secret keys. Covers the fix from PR #469.
2026-03-06 04:26:18 -08:00
teknium1
145f719d30 Merge PR #469: fix(config): route API keys and tokens to .env instead of config.yaml
Authored by ygd58. Fixes #465. Adds missing keys to allowlist and
catch-all patterns (_API_KEY, _TOKEN suffixes) for future-proofing.
2026-03-06 04:23:49 -08:00
teknium1
b89eb29174 fix: correct mock tool name 'search' → 'search_files' in test_code_execution
The mock handler checked for function_name == 'search' but the RPC
sends 'search_files'. Any test exercising search_files through the
mock would get 'Unknown tool' instead of the canned response.
2026-03-06 03:53:43 -08:00
teknium1
3670089a42 docs: add Daytona to batch_runner, process_registry, agent_loop, tool_context
Add daytona_image to batch_runner per-prompt container image overrides
so batch processing works with the Daytona backend. Update inline
comments in RL environment files (agent_loop, tool_context) and
process_registry docstrings to include Daytona in backend lists.
2026-03-06 03:49:59 -08:00
teknium1
3982fcf095 fix: sync execute_code sandbox stubs with real tool schemas
The _TOOL_STUBS dict in code_execution_tool.py was out of sync with the
actual tool schemas, causing TypeErrors when the LLM used parameters it
sees in its system prompt but the sandbox stubs didn't accept:

search_files:
  - Added missing params: context, offset, output_mode
  - Fixed target default: 'grep' → 'content' (old value was obsolete)

patch:
  - Added missing params: mode, patch (V4A multi-file patch support)

Also added 4 drift-detection tests (TestStubSchemaDrift) that will
catch future divergence between stubs and real schemas:
  - test_stubs_cover_all_schema_params: every schema param in stub
  - test_stubs_pass_all_params_to_rpc: every stub param sent over RPC
  - test_search_files_target_uses_current_values: no obsolete values
  - test_generated_module_accepts_all_params: generated code compiles

All 28 tests pass.
2026-03-06 03:40:06 -08:00
teknium1
8481fdcf08 docs: complete Daytona backend documentation coverage
Update all remaining files that enumerate terminal backends to include
Daytona. Covers security docs (bypass info, backend comparison table),
environment variables reference (DAYTONA_API_KEY, TERMINAL_DAYTONA_IMAGE,
container resources header), AGENTS.md (architecture tree, config keys),
environments/README.md, hermes_base_env.py field description, and various
module docstrings.

Follow-up to PR #451 merge.
2026-03-06 03:37:05 -08:00
teknium1
39299e2de4 Merge PR #451: feat: Add Daytona environment backend
Authored by rovle. Adds Daytona as the sixth terminal execution backend
with cloud sandboxes, persistent workspaces, and full CLI/gateway integration.
Includes 24 unit tests and 8 integration tests.
2026-03-06 03:32:40 -08:00
teknium1
efec4fcaab feat(execute_code): add json_parse, shell_quote, retry helpers to sandbox
The execute_code sandbox generates a hermes_tools.py stub module for LLM
scripts. Three common failure modes keep tripping up scripts:

1. json.loads(strict=True) rejects control chars in terminal() output
   (e.g., GitHub issue bodies with literal tabs/newlines)
2. Shell backtick/quote interpretation when interpolating dynamic content
   into terminal() commands (markdown with backticks gets eaten by bash)
3. No retry logic for transient network failures (API timeouts, rate limits)

Adds three convenience helpers to the generated hermes_tools module:

- json_parse(text) — json.loads with strict=False for tolerant parsing
- shell_quote(s) — shlex.quote() for safe shell interpolation
- retry(fn, max_attempts=3, delay=2) — exponential backoff wrapper

Also updates the EXECUTE_CODE_SCHEMA description to document these helpers
so LLMs know they're available without importing anything extra.

Includes 7 new tests (unit + integration) covering all three helpers.
2026-03-06 01:52:46 -08:00
teknium1
5ce2c47d60 docs: update all docs for optional-skills and browse command
Update 7 documentation files to reflect:
- optional-skills/ directory in all project structure trees
- 'hermes skills browse' in all CLI command listings
- '/skills browse' in all slash command references
- Three-tier skill placement (bundled → optional → hub)
- 'official' trust level in trust level tables
- Updated /skills description from 'Search, install...' to 'Browse, search...'

Files updated:
- CONTRIBUTING.md (skill classification, project tree, section title)
- AGENTS.md (project tree, Skills Hub description, source adapters list)
- website/docs/reference/cli-commands.md (CLI table, slash command table)
- website/docs/developer-guide/creating-skills.md (structure, classification, trust)
- website/docs/user-guide/features/skills.md (hub commands, trust table, slash commands)
- website/docs/user-guide/cli.md (slash command description)
- website/docs/developer-guide/architecture.md (project tree)
2026-03-06 01:46:34 -08:00
teknium1
f6f3d1de9b fix: review fixes — path traversal guard, trust_style consistency, edge cases
Address code review findings:

Security (Medium):
- Path traversal guard in OptionalSkillSource.fetch() — resolve() and
  validate that the path stays within optional-skills/ before reading

Bug fixes (Medium):
- Add 'builtin' to trust_style dicts in do_inspect() and
  _resolve_short_name() — official skills now show bright_cyan 'official'
  label consistently across all display functions (5/5 dicts fixed)

Edge cases (Low):
- Clamp page_size to [1, 100] in do_browse() to prevent ZeroDivisionError
- Update SkillMeta.source docstring to include 'official'
- Add browse command to optional-skills/DESCRIPTION.md
2026-03-06 01:40:01 -08:00
teknium1
ec0fe3242a feat: 'hermes skills browse' — paginated browsing of all hub skills
Add a browse command that shows all available skills across all registries,
paginated and sorted with official skills first.

Usage:
  hermes skills browse                    # all sources, page 1
  hermes skills browse --source official  # only official optional skills
  hermes skills browse --page 2           # page 2
  hermes skills browse --size 30          # 30 per page
  /skills browse                          # slash command in chat

Features:
- Official optional skills always appear first (★ marker, cyan styling)
- Per-source limits prevent overloading (100 official/github, 50 others)
- Deduplication by name preferring higher trust
- Sorted: official > trusted > community, then alphabetical
- Page navigation hints at bottom
- Source counts summary
- Works in both CLI and /skills chat interface
- Added 'official' as source filter option for search command too
2026-03-06 01:29:45 -08:00
teknium1
f2e24faaca feat: optional skills — official skills shipped but not activated by default
Add 'optional-skills/' directory for official skills that ship with the repo
but are not copied to ~/.hermes/skills/ during setup. They are:
- NOT shown to the model in the system prompt
- NOT copied during hermes setup/update
- Discoverable via 'hermes skills search' labeled as 'official'
- Installable via 'hermes skills install' with builtin trust (no third-party warning)
- Auto-categorized on install based on directory structure

Implementation:
- OptionalSkillSource adapter in tools/skills_hub.py (search/fetch/inspect)
- Added to create_source_router() as first source (highest priority)
- Trust level 'builtin' for official skills in skills_guard.py
- Friendly install message for official skills (no third-party warning)
- 'official' label in cyan in search results and skill list

First optional skill: Blackbox CLI (autonomous-ai-agents/blackbox)
- Multi-model coding agent with built-in judge/Chairman pattern
- Delegates to Claude, Codex, Gemini, and Blackbox models
- Open-source CLI (GPL-3.0, TypeScript, forked from Gemini CLI)
- Requires paid Blackbox AI API key

Refs: #475
2026-03-06 01:24:11 -08:00
teknium1
8c80b96318 chore: update OpenRouter model list
- Remove opus-4.5 and gpt-5.2
- Reorder GPT: 5.4-pro, 5.4, 5.3-codex
- Add qwen/qwen3.5-plus-02-15 and qwen/qwen3.5-35b-a3b
- Update z-ai/glm-4.7 → glm-5
- Update minimax/minimax-m2.1 → minimax-m2.5
2026-03-06 00:52:45 -08:00
teknium1
2387465dcc chore: add openai/gpt-5.4-pro and stepfun/step-3.5-flash to OpenRouter models 2026-03-06 00:49:25 -08:00
tars90percent
32636ecf8a Update MiniMax model ID from m2.1 to m2.5 2026-03-06 16:47:48 +08:00
ygd58
6055adbe1b fix(config): route API keys and tokens to .env instead of config.yaml 2026-03-06 08:55:36 +01:00
teknium1
ffd2f8dc50 docs: add Vision & Image Paste guide with platform compatibility
New docs page covering clipboard image paste across all platforms:
- Platform compatibility table (macOS, Linux X11/Wayland, WSL2, VSCode, SSH)
- Setup instructions per platform (xclip, wl-paste, powershell.exe)
- Explanation of terminal paste limitations and why /paste exists
- SSH workarounds (file upload, URLs, X11 forwarding, messaging)
- Keybinding reference (Alt+V, Ctrl+V, /paste) with when each works

Also updates CLI commands reference with /paste command and
Alt+V keybinding documentation.
2026-03-05 23:51:46 -08:00
teknium1
e93b4d1dcd feat: Alt+V keybinding for clipboard image paste
Alt key combos pass through all terminal emulators (sent as ESC + key),
unlike Ctrl+V which terminals intercept for text paste. This is the
reliable way to attach clipboard images on WSL2, Windows Terminal,
VSCode, and SSH sessions where Ctrl+V never reaches the application
for image-only clipboard content.

Also adds 'Paste image: Alt+V (or /paste)' hint to /help output.
2026-03-05 22:48:39 -08:00
teknium1
014a5b712d fix: prevent duplicate gateway instances from running simultaneously
start_gateway() now checks for an existing running instance via PID file
before starting. If another gateway is already running under the same
HERMES_HOME, it refuses to start with a clear error message directing the
user to 'hermes gateway restart' or 'hermes gateway stop'.

Also fixes gateway/status.py to respect the HERMES_HOME env var instead of
hardcoding ~/.hermes. This scopes the PID file per HERMES_HOME directory,
which lays the groundwork for future multi-profile support where distinct
HERMES_HOME directories can run concurrent gateway instances independently.
2026-03-05 20:35:33 -08:00
teknium1
2317d115cd fix: clipboard image paste on WSL2, Wayland, and VSCode terminal
The original implementation only supported xclip (X11), which silently
fails on WSL2 (can't access Windows clipboard for images), Wayland
desktops (xclip is X11-only), and VSCode terminal on WSL2.

Clipboard backend changes (hermes_cli/clipboard.py):
- WSL2: detect via /proc/version, use powershell.exe with .NET
  System.Windows.Forms.Clipboard to extract images as base64 PNG
- Wayland: use wl-paste with MIME type detection, auto-convert BMP
  to PNG for WSLg environments (via Pillow or ImageMagick)
- Dispatch order: WSL → Wayland → X11 (xclip), with fallthrough
- New has_clipboard_image() for lightweight clipboard checks
- Cache WSL detection result per-process

CLI changes (cli.py):
- /paste command: explicit clipboard image check for terminals where
  BracketedPaste doesn't fire (image-only clipboard in VSCode/WinTerm)
- Ctrl+V keybinding: fallback for Linux terminals where Ctrl+V sends
  raw byte instead of triggering bracketed paste

Tests: 80 tests (up from 37) covering WSL, Wayland, X11 dispatch,
BMP conversion, has_clipboard_image, and /paste command.
2026-03-05 20:22:44 -08:00
teknium1
8253b54be9 test: strengthen assertions in skill_manager + memory_tool (batch 3)
test_skill_manager_tool.py (20 weak → 0):
  - Validation error messages verified against exact strings
  - Name validation: checks specific invalid name echoed in error
  - Frontmatter validation: exact error text for missing fields,
    unclosed markers, empty content, invalid YAML
  - File path validation: traversal, disallowed dirs, root-level

test_memory_tool.py (13 weak → 0):
  - Security scan tests verify both 'Blocked' prefix AND specific
    threat pattern ID (prompt_injection, exfil_curl, etc.)
  - Invisible unicode tests verify exact codepoint strings
  - Snapshot test verifies type, header, content, and isolation
2026-03-05 18:51:43 -08:00
teknium1
5c867fd79f test: strengthen assertions across 3 more test files (batch 2)
test_run_agent.py (2 weak → 0, +13 assertions):
  - Session ID validated against actual YYYYMMDD_HHMMSS_hex format
  - API failure verifies error message propagation
  - Invalid JSON args verifies empty dict fallback + message structure
  - Context compression verifies final_response + completed flag
  - Invalid tool name retry verifies api_calls count
  - Invalid response verifies completed/failed/error structure

test_model_tools.py (3 weak → 0):
  - Unknown tool error includes tool name in message
  - Exception returns dict with 'error' key + non-empty message
  - get_all_tool_names verifies both web_search AND terminal present

test_approval.py (1 weak → 0, assert ratio 1.1 → 2.2):
  - Dangerous commands verify description content (delete, shell, drop, etc.)
  - Safe commands explicitly assert key AND desc are None
  - Pre/post condition checks for state management
2026-03-05 18:46:30 -08:00
teknium1
a44e041acf test: strengthen assertions across 7 test files (batch 1)
Replaced weak 'is not None' / '> 0' / 'len >= 1' assertions with
concrete value checks across the most flagged test files:

gateway/test_pairing.py (11 weak → 0):
  - Code assertions verify isinstance + len == CODE_LENGTH
  - Approval results verify dict structure + specific user_id/user_name
  - Added code2 != code1 check in rate_limit_expires

test_hermes_state.py (6 weak → 0):
  - ended_at verified as float timestamp
  - Search result counts exact (== 2, not >= 1)
  - Context verified as non-empty list
  - Export verified as dict, session ID verified

test_cli_init.py (4 weak → 0):
  - max_turns asserts exact value (60)
  - model asserts string with provider/name format

gateway/test_hooks.py (2 zero-assert tests → fixed):
  - test_no_handlers_for_event: verifies no handler registered
  - test_handler_error_does_not_propagate: verifies handler count + return

gateway/test_platform_base.py (9 weak image tests → fixed):
  - extract_images tests now verify actual URL and alt_text
  - truncate_message verifies content preservation after splitting

cron/test_scheduler.py (1 weak → 0):
  - resolve_origin verifies dict equality, not just existence

cron/test_jobs.py (2 weak → 0 + 4 new tests):
  - Schedule parsing verifies ISO timestamp type
  - Cron expression verifies result is valid datetime string
  - NEW: 4 tests for update_job() (was completely untested)
2026-03-05 18:39:37 -08:00
teknium1
e9f05b3524 test: comprehensive tests for model metadata + firecrawl config
model_metadata tests (61 tests, was 39):
  - Token estimation: concrete value assertions, unicode, tool_call messages,
    vision multimodal content, additive verification
  - Context length resolution: cache-over-API priority, no-base_url skips cache,
    missing context_length key in API response
  - API metadata fetch: canonical_slug aliasing, TTL expiry with time mock,
    stale cache fallback on API failure, malformed JSON resilience
  - Probe tiers: above-max returns 2M, zero returns None
  - Error parsing: Anthropic format ('X > Y maximum'), LM Studio, empty string,
    unreasonably large numbers — also fixed parser to handle Anthropic format
  - Cache: corruption resilience (garbage YAML, wrong structure), value updates,
    special chars in model names

Firecrawl config tests (8 tests, was 4):
  - Singleton caching (core purpose — verified constructor called once)
  - Constructor failure recovery (retry after exception)
  - Return value actually asserted (not just constructor args)
  - Empty string env vars treated as absent
  - Proper setup/teardown for env var isolation
2026-03-05 18:22:39 -08:00
teknium1
e2a834578d refactor: extract clipboard methods + comprehensive tests (37 tests)
Refactored image paste internals for testability:
- Extracted _try_attach_clipboard_image() method (clipboard → state)
- Extracted _build_multimodal_content() method (images → OpenAI format)
- chat() now delegates to these instead of inline logic

Tests organized in 4 levels:
  Level 1 (19 tests): Clipboard module — every platform path with
    realistic subprocess simulation (tools writing files, timeouts,
    empty files, cleanup on failure)
  Level 2 (8 tests): _build_multimodal_content — base64 encoding,
    MIME types (png/jpg/webp/unknown), missing files, multiple images,
    default question for empty text
  Level 3 (5 tests): _try_attach_clipboard_image — state management,
    counter increment/rollback, naming convention, mixed success/failure
  Level 4 (5 tests): Queue routing — tuple unpacking, command detection,
    images-only payloads, text-only payloads
2026-03-05 18:07:53 -08:00
teknium1
ffc752a79e test: improve clipboard tests with realistic scenarios and multimodal coverage
Rewrote clipboard tests from 11 shallow mocks to 21 realistic tests:
- Success paths now simulate tools actually writing files (not pre-created)
- osascript: success with PNG, success with TIFF, extraction-fail cases
- pngpaste: empty file rejection edge case
- Linux: extraction failure cleanup verification
- New TestMultimodalConversion class: base64 encoding, MIME types,
  multiple images, missing file handling, default question fallback
2026-03-05 17:58:06 -08:00
teknium1
399562a7d1 feat: clipboard image paste in CLI (Cmd+V / Ctrl+V)
Copy an image to clipboard (screenshot, browser, etc.) and paste into
the Hermes CLI. The image is saved to ~/.hermes/images/, shown as a
badge above the input ([📎 Image #1]), and sent to the model as a
base64-encoded OpenAI vision multimodal content block.

Implementation:
- hermes_cli/clipboard.py: clean module with platform-specific extraction
  - macOS: pngpaste (if installed) → osascript fallback (always available)
  - Linux: xclip (apt install xclip)
- cli.py: BracketedPaste key handler checks clipboard on every paste,
  image bar widget shows attached images, chat() converts to multimodal
  content format, Ctrl+C clears attachments

Inspired by @m0at's fork (https://github.com/m0at/hermes-agent) which
implemented image paste support for local vision models. Reimplemented
cleanly as a separate module with tests.
2026-03-05 17:55:41 -08:00
teknium1
fec8a0da72 Merge PR #296: fix(cron): close lock_fd on failed flock to prevent fd leak
Authored by alireza78a. When flock() raises on a concurrent tick, the
file descriptor was leaked because the except clause returned without
closing it. Adds lock_fd=None init and close in the except path.
2026-03-05 17:05:06 -08:00
teknium1
9f4542b3db fix: require Python 3.11+ in pyproject.toml
Was incorrectly set to >=3.10. Hermes uses tomllib and other 3.11+
features. CONTRIBUTING.md and README already say 3.11+.
2026-03-05 17:04:08 -08:00
teknium1
363633e2ba fix: allow self-hosted Firecrawl without API key + add self-hosting docs
On top of PR #460: self-hosted Firecrawl instances don't require an API
key (USE_DB_AUTHENTICATION=false), so don't force users to set a dummy
FIRECRAWL_API_KEY when FIRECRAWL_API_URL is set. Also adds a proper
self-hosting section to the configuration docs explaining what you get,
what you lose, and how to set it up (Docker stack, tradeoffs vs cloud).

Added 2 more tests (URL-only without key, neither-set raises).
2026-03-05 16:44:21 -08:00
teknium1
a41ba57a7a Merge PR #460: feat(tools): add support for self-hosted firecrawl
Authored by caentzminger. Adds optional FIRECRAWL_API_URL env var to point
the Firecrawl client at a self-hosted instance instead of the cloud API.
2026-03-05 16:41:30 -08:00
teknium1
884c8ea70a chore: add openai/gpt-5.4 to OpenRouter preferred models list 2026-03-05 16:13:45 -08:00
teknium1
c886333d32 feat: smart context length probing with persistent caching + banner display
Replaces the unsafe 128K fallback for unknown models with a descending
probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a
context-length error occurs, the agent steps down tiers and retries.
The discovered limit is cached per model+provider combo in
~/.hermes/context_length_cache.yaml so subsequent sessions skip probing.

Also parses API error messages to extract the actual context limit
(e.g. 'maximum context length is 32768 tokens') for instant resolution.

The CLI banner now displays the context window size next to the model
name (e.g. 'claude-opus-4 · 200K context · Nous Research').

Changes:
- agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache
  (save/load/get), parse_context_limit_from_error(), get_next_probe_tier()
- agent/context_compressor.py: accepts base_url, passes to metadata
- run_agent.py: step-down logic in context error handler, caches on success
- cli.py + hermes_cli/banner.py: context length in welcome banner
- tests: 22 new tests for probing, parsing, and caching

Addresses #132. PR #319's approach (8K default) rejected — too conservative.
2026-03-05 16:09:57 -08:00
teknium1
55b173dd03 refactor: move shutil import to module level
Cleanup on top of PR #305 — replace two inline 'import shutil as _shutil'
with a single module-level import.
2026-03-05 15:57:05 -08:00
dmahan93
9079a27814 fix: prompt box and response box span full terminal width on wide screens
- Replace hardcoded '─' * 200 horizontal rules with Window(char='─')
  so prompt_toolkit fills the entire terminal width automatically
- Use shutil.get_terminal_size().columns instead of Rich Console.width
  for response box, separator line, and input height calculation
  (more reliable inside patch_stdout context)
2026-03-05 15:57:05 -08:00
caentzminger
d7d10b14cd feat(tools): add support for self-hosted firecrawl
Adds optional FIRECRAWL_API_URL environment variable to support
self-hosted Firecrawl deployments alongside the cloud service.

- Add FIRECRAWL_API_URL to optional env vars in hermes_cli/config.py
- Update _get_firecrawl_client() in tools/web_tools.py to accept custom API URL
- Add tests for client initialization with/without URL
- Document new env var in installation and config guides
2026-03-05 16:16:18 -06:00
rovle
a6499b6107 fix(daytona): use shell timeout wrapper instead of broken SDK exec timeout
The Daytona SDK's process.exec(timeout=N) parameter is not enforced —
the server-side timeout never fires and the SDK has no client-side
fallback, causing commands to hang indefinitely.

Fix: wrap commands with timeout N sh -c '...' (coreutils) which
reliably kills the process and returns exit code 124. Added
shlex.quote for proper shell escaping and a secondary deadline (timeout + 10s) that force-stops the sandbox if the shell timeout somehow fails.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 13:12:41 -08:00
rovle
74a36b0729 docs: add Daytona to backend lists in docs
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:55:41 -08:00
rovle
efc7a7b957 fix(daytona): don't guess /root on cwd probe failure, keep constructor default; update tests to reflect this
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:49:35 -08:00
rovle
4f1464b3af fix(daytona): default disk to 10GB to match platform limit
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:37:30 -08:00
rovle
3a41079fac fix(daytona): add optional dependency group to pyproject.toml
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:13:12 -08:00
rovle
5279540bb4 fix(daytona): add missing config mappings in gateway, CLI defaults, and config display
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:12:50 -08:00
rovle
577da79a47 fix(daytona): make disk cap visible and use SDK enum for sandbox
state

- Replace logger.warning with warnings.warn for the disk cap so users
  actually see it (logger was suppressed by CLI's log level config)
- Use SandboxState enum instead of string literals in
_ensure_sandbox_ready

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:03:39 -08:00
rovle
1faa9648d3 chore(daytona): cap the disk size to current maximum on daytona sandboxes
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:43:41 -08:00
PercyDikec
ad57bf1e4b fix(cli): use correct dict key for codex auth file path in status output 2026-03-05 21:27:12 +03:00
rovle
d5efb82c7c test(daytona): add unit and integration tests for Daytona backend
Unit tests cover cwd resolution, sandbox persistence/resume, cleanup,
command execution, resource conversion, interrupt handling, retry
exhaustion, and sandbox readiness checks. Integration tests verify
basic commands, filesystem ops, session persistence, and task
isolation against a live Daytona API.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:26:22 -08:00
rovle
ea2f7ef2f6 docs(config): add Daytona disk limit hint and fix default cwd in example
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:22 -08:00
rovle
435530018b fix(daytona): resolve cwd by detecting home directory inside the sandbox 2026-03-05 10:02:22 -08:00
rovle
df61054a84 feat(cli): add Daytona to setup wizard, doctor, and status display
Add Daytona as a backend choice in the interactive setup wizard with
SDK installation and API key prompts. Show Daytona image in status
output and validate API key + SDK in doctor checks. Add OPTION 6
example in cli-config.yaml.example.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:22 -08:00
rovle
690b8bb563 feat(cli): add Daytona config mapping and env var sync
Wire TERMINAL_DAYTONA_IMAGE through cli.py env_mappings and
hermes_cli/config.py so `hermes config set` propagates correctly.
2026-03-05 10:02:21 -08:00
rovle
c43451a50b feat(terminal): integrate Daytona backend into tool pipeline
Add Daytona to image selection, container_config guards, environment
factory, requirements check, and diagnostics in terminal_tool.py and
file_tools.py. Also add to sandboxed-backend approval bypass.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:21 -08:00
rovle
1e312c6582 feat(environments): add Daytona cloud sandbox backend
New execution backend using the Daytona Python SDK. Supports persistent
sandboxes via stop/start lifecycle, interrupt handling, and automatic
retry on transient errors.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:21 -08:00
PercyDikec
e36c8cd49a fix: add missing re.DOTALL flag to DeepSeek V3 tool call parser 2026-03-05 20:32:38 +03:00
PercyDikec
16cb6d1a6e fix(gateway): return response from /retry handler instead of discarding it 2026-03-05 19:59:54 +03:00
Farukest
e25ad79d5d fix: use _max_tokens_param in max-iterations retry path
The retry summary in _handle_max_iterations hardcodes max_tokens instead
of calling _max_tokens_param(). For direct OpenAI API users (gpt-4o,
o-series), the correct parameter name is max_completion_tokens. The first
attempt at line 2697 already uses _max_tokens_param correctly but the
retry path at line 2743 was missed.
2026-03-05 17:49:37 +03:00
Farukest
82cb1752d9 fix(whatsapp): replace Linux-only fuser with cross-platform port cleanup
fuser command does not exist on Windows, causing orphaned bridge processes
to never be cleaned up. On crash recovery, the port stays occupied and the
next connect() fails with address-already-in-use.

Add _kill_port_process() helper that uses netstat+taskkill on Windows and
fuser on Linux/macOS. Replace both call sites in connect() and disconnect().
2026-03-05 17:13:14 +03:00
Dev User
3221818b6e fix: respect OPENAI_BASE_URL when resolving API key priority
When base_url points to a non-OpenRouter endpoint (e.g. Z.ai),
OPENROUTER_API_KEY incorrectly takes priority over OPENAI_API_KEY,
sending the wrong credentials. This causes 401 errors on the main
inference path and forces users to comment out OPENROUTER_API_KEY,
which then breaks auxiliary clients (compression, vision).

Fix: check whether base_url contains "openrouter" and swap the key
priority accordingly. Also adds GLM-4.7 and GLM-5 context lengths
to DEFAULT_CONTEXT_LENGTHS.
2026-03-05 08:25:16 +00:00
areu01or00
a1c25046a9 fix(timezone): add timezone-aware clock across agent, cron, and execute_code 2026-03-03 18:23:40 +05:30
BathreeNode
d10108f8ca fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:21:28 +03:00
BathreeNode
8b520f9848 fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:20:47 +03:00
BathreeNode
a718aed1be fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:20:07 +03:00
BathreeNode
5f29e7b63c fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:17:13 +03:00
aydnOktay
5fa3e24b76 Make process_registry checkpoint writes atomic 2026-03-03 02:44:01 +03:00
aydnOktay
ac6d747fa6 Make batch_runner checkpoint incremental and atomic 2026-03-03 01:43:07 +03:00
alireza78a
ee541c84f1 fix(cron): close lock_fd on failed flock to prevent fd leak 2026-03-03 02:09:56 +03:30
240 changed files with 50213 additions and 2217 deletions

View File

@@ -13,6 +13,38 @@ OPENROUTER_API_KEY=
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
LLM_MODEL=anthropic/claude-opus-4.6
# =============================================================================
# LLM PROVIDER (z.ai / GLM)
# =============================================================================
# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
# Get your key at: https://z.ai or https://open.bigmodel.cn
GLM_API_KEY=
# GLM_BASE_URL=https://api.z.ai/api/paas/v4 # Override default base URL
# =============================================================================
# LLM PROVIDER (Kimi / Moonshot)
# =============================================================================
# Kimi Code provides access to Moonshot AI coding models (kimi-k2.5, etc.)
# Get your key at: https://platform.kimi.ai (Kimi Code console)
# Keys prefixed sk-kimi- use the Kimi Code API (api.kimi.com) by default.
# Legacy keys from platform.moonshot.ai need KIMI_BASE_URL override below.
KIMI_API_KEY=
# KIMI_BASE_URL=https://api.kimi.com/coding/v1 # Default for sk-kimi- keys
# KIMI_BASE_URL=https://api.moonshot.ai/v1 # For legacy Moonshot keys
# KIMI_BASE_URL=https://api.moonshot.cn/v1 # For Moonshot China keys
# =============================================================================
# LLM PROVIDER (MiniMax)
# =============================================================================
# MiniMax provides access to MiniMax models (global endpoint)
# Get your key at: https://www.minimax.io
MINIMAX_API_KEY=
# MINIMAX_BASE_URL=https://api.minimax.io/v1 # Override default base URL
# MiniMax China endpoint (for users in mainland China)
MINIMAX_CN_API_KEY=
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
# =============================================================================
# TOOL API KEYS
# =============================================================================

4
.gitignore vendored
View File

@@ -47,4 +47,6 @@ cli-config.yaml
# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
skills/.hub/
ignored/
ignored/
.worktrees/
environments/benchmarks/evals/

View File

@@ -44,7 +44,8 @@ hermes-agent/
│ │ ├── docker.py # Docker container execution
│ │ ├── ssh.py # SSH remote execution
│ │ ├── singularity.py # Singularity/Apptainer + SIF management
│ │ ── modal.py # Modal cloud execution
│ │ ── modal.py # Modal cloud execution
│ │ └── daytona.py # Daytona cloud sandboxes
│ ├── terminal_tool.py # Terminal orchestration (sudo, lifecycle, factory)
│ ├── todo_tool.py # Planning & task management
│ ├── process_registry.py # Background process management
@@ -55,7 +56,9 @@ hermes-agent/
├── cron/ # Scheduler implementation
├── environments/ # RL training environments (Atropos integration)
├── skills/ # Bundled skill sources
├── optional-skills/ # Official optional skills (not activated by default)
├── cli.py # Interactive CLI orchestrator (HermesCLI class)
├── hermes_state.py # SessionDB — SQLite session store (schema, titles, FTS5 search)
├── run_agent.py # AIAgent class (core conversation loop)
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
├── toolsets.py # Tool groupings
@@ -96,7 +99,7 @@ The main agent is implemented in `run_agent.py`:
class AIAgent:
def __init__(
self,
model: str = "anthropic/claude-sonnet-4",
model: str = "anthropic/claude-sonnet-4.6",
api_key: str = None,
base_url: str = "https://openrouter.ai/api/v1",
max_iterations: int = 60, # Max tool-calling loops
@@ -202,7 +205,7 @@ Every installed skill in `~/.hermes/skills/` is automatically registered as a sl
The skill name (from frontmatter or folder name) becomes the command: `axolotl``/axolotl`.
Implementation (`agent/skill_commands.py`, shared between CLI and gateway):
1. `scan_skill_commands()` scans all SKILL.md files at startup
1. `scan_skill_commands()` scans all SKILL.md files at startup, filtering out skills incompatible with the current OS platform (via the `platforms` frontmatter field)
2. `build_skill_invocation_message()` loads the SKILL.md content and builds a user-turn message
3. The message includes the full skill content, a list of supporting files (not loaded), and the user's instruction
4. Supporting files can be loaded on demand via the `skill_view` tool
@@ -224,6 +227,10 @@ The unified `hermes` command provides all functionality:
|---------|-------------|
| `hermes` | Interactive chat (default) |
| `hermes chat -q "..."` | Single query mode |
| `hermes -c` / `hermes --continue` | Resume the most recent session |
| `hermes -c "my project"` | Resume a session by name (latest in lineage) |
| `hermes --resume <session_id>` | Resume a specific session by ID or title |
| `hermes -w` / `hermes --worktree` | Start in isolated git worktree (for parallel agents) |
| `hermes setup` | Configure API keys and settings |
| `hermes config` | View current configuration |
| `hermes config edit` | Open config in editor |
@@ -237,6 +244,8 @@ The unified `hermes` command provides all functionality:
| `hermes gateway` | Start gateway (messaging + cron scheduler) |
| `hermes gateway setup` | Configure messaging platforms interactively |
| `hermes gateway install` | Install gateway as system service |
| `hermes sessions list` | List past sessions (title, preview, last active) |
| `hermes sessions rename <id> <title>` | Rename/title a session |
| `hermes cron list` | View scheduled jobs |
| `hermes cron status` | Check if cron scheduler is running |
| `hermes version` | Show version info |
@@ -421,16 +430,19 @@ The system uses `_config_version` to detect outdated configs:
API keys are loaded from `~/.hermes/.env`:
- `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
- `FIRECRAWL_API_KEY` - Web search/extract tools
- `FIRECRAWL_API_URL` - Self-hosted Firecrawl endpoint (optional)
- `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
- `FAL_KEY` - Image generation (FLUX model)
- `NOUS_API_KEY` - Vision and Mixture-of-Agents tools
Terminal tool configuration (in `~/.hermes/config.yaml`):
- `terminal.backend` - Backend: local, docker, singularity, modal, or ssh
- `terminal.backend` - Backend: local, docker, singularity, modal, daytona, or ssh
- `terminal.cwd` - Working directory ("." = host CWD for local only; for remote backends set an absolute path inside the target, or omit to use the backend's default)
- `terminal.docker_image` - Image for Docker backend
- `terminal.singularity_image` - Image for Singularity backend
- `terminal.modal_image` - Image for Modal backend
- `terminal.daytona_image` - Image for Daytona backend
- `DAYTONA_API_KEY` - API key for Daytona backend (in .env)
- SSH: `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` in .env
Agent behavior (in `~/.hermes/.env`):
@@ -494,7 +506,7 @@ terminal(command="pytest -v tests/", background=true)
- `process(action="submit", session_id="proc_abc123", data="yes")` -- send + Enter
**Key behaviors:**
- Background processes execute through the configured terminal backend (local/Docker/Modal/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
- Background processes execute through the configured terminal backend (local/Docker/Modal/Daytona/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
- The `wait` action blocks the tool call until the process finishes, times out, or is interrupted by a new user message
- PTY mode (`pty=true` on terminal) enables interactive CLI tools (Codex, Claude Code)
- In RL training, background processes are auto-killed when the episode ends (`tool_context.cleanup()`)
@@ -652,6 +664,7 @@ SKILL.md files use YAML frontmatter (agentskills.io format):
name: skill-name
description: Brief description for listing
version: 1.0.0
platforms: [macos] # Optional — restrict to specific OS (macos/linux/windows)
metadata:
hermes:
tags: [tag1, tag2]
@@ -660,16 +673,40 @@ metadata:
# Skill Content...
```
**Skills Hub** — user-driven skill search/install from online registries (GitHub, ClawHub, Claude marketplaces, LobeHub). Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills ...` CLI commands or the `/skills` slash command in chat.
**Platform filtering** — Skills with a `platforms` field are automatically excluded from the system prompt index, `skills_list()`, and slash commands on incompatible platforms. Skills without the field load everywhere (backward compatible). See `skills/apple/` for macOS-only examples (iMessage, Reminders, Notes, FindMy).
**Skills Hub** — user-driven skill search/install from online registries and official optional skills. Sources: official optional skills (shipped with repo, labeled "official"), GitHub (openai/skills, anthropics/skills, custom taps), ClawHub, Claude marketplace, LobeHub. Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills browse/search/install` CLI commands or the `/skills` slash command in chat.
Key files:
- `tools/skills_tool.py` — Agent-facing skill list/view (progressive disclosure)
- `tools/skills_guard.py` — Security scanner (regex + LLM audit, trust-aware install policy)
- `tools/skills_hub.py` — Source adapters (GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
- `tools/skills_hub.py` — Source adapters (OptionalSkillSource, GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
- `hermes_cli/skills_hub.py` — CLI subcommands + `/skills` slash command handler
---
## Known Pitfalls
### DO NOT use `simple_term_menu` for interactive menus
`simple_term_menu` has rendering bugs in tmux, iTerm2, and other non-standard terminals. When the user scrolls with arrow keys, previously highlighted items "ghost" — duplicating upward and corrupting the display. This happens because the library uses ANSI cursor-up codes to redraw in place, and tmux/iTerm miscalculate positions when the menu is near the bottom of the viewport.
**Rule:** All interactive menus in `hermes_cli/` must use `curses` (Python stdlib) instead. See `tools_config.py` for the pattern — both `_prompt_choice()` (single-select) and `_prompt_toolset_checklist()` (multi-select with space toggle) use `curses.wrapper()`. The numbered-input fallback handles Windows where curses isn't available.
### DO NOT use `\033[K` (ANSI erase-to-EOL) in spinner/display code
The ANSI escape `\033[K` leaks as literal `?[K` text when `prompt_toolkit`'s `patch_stdout` is active. Use space-padding instead to clear lines: `f"\r{line}{' ' * pad}"`. See `agent/display.py` `KawaiiSpinner`.
### `_last_resolved_tool_names` is a process-global in `model_tools.py`
The `execute_code` sandbox uses `_last_resolved_tool_names` (set by `get_tool_definitions()`) to decide which tool stubs to generate. When subagents run with restricted toolsets, they overwrite this global. After delegation returns to the parent, `execute_code` may see the child's restricted list instead of the parent's full list. This is a known bug — `execute_code` calls after delegation may fail with `ImportError: cannot import name 'patch' from 'hermes_tools'`.
### Tests must not write to `~/.hermes/`
The `autouse` fixture `_isolate_hermes_home` in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Every test runs in isolation. If you add a test that creates `AIAgent` instances or writes session logs, the fixture handles cleanup automatically. Never hardcode `~/.hermes/` paths in tests.
---
## Testing Changes
After making changes:

View File

@@ -43,7 +43,9 @@ Bundled skills (in `skills/`) ship with every Hermes install. They should be **b
- Document handling, web research, common dev workflows, system administration
- Used regularly by a wide range of people
If your skill is specialized (a niche engineering tool, a specific SaaS integration, a game), it's better suited for a **Skills Hub**upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
If your skill is official and useful but not universally needed (e.g., a paid service integration, a heavyweight dependency), put it in **`optional-skills/`** — it ships with the repo but isn't activated by default. Users can discover it via `hermes skills browse` (labeled "official") and install it with `hermes skills install` (no third-party warning, builtin trust).
If your skill is specialized, community-contributed, or niche, it's better suited for a **Skills Hub** — upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
---
@@ -116,7 +118,7 @@ hermes-agent/
├── cli.py # HermesCLI class — interactive TUI, prompt_toolkit integration
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
├── toolsets.py # Tool groupings and presets (hermes-cli, hermes-telegram, etc.)
├── hermes_state.py # SQLite session database with FTS5 full-text search
├── hermes_state.py # SQLite session database with FTS5 full-text search, session titles
├── batch_runner.py # Parallel batch processing for trajectory generation
├── agent/ # Agent internals (extracted modules)
@@ -153,7 +155,7 @@ hermes-agent/
│ ├── skill_tools.py # Skill search, load, manage
│ └── environments/ # Terminal execution backends
│ ├── base.py # BaseEnvironment ABC
│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py
│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py, daytona.py
├── gateway/ # Messaging gateway
│ ├── run.py # GatewayRunner — platform lifecycle, message routing, cron
@@ -168,6 +170,7 @@ hermes-agent/
│ └── whatsapp-bridge/ # Node.js WhatsApp bridge (Baileys)
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
├── environments/ # RL training environments (Atropos integration)
├── tests/ # Test suite
├── website/ # Documentation site (hermes-agent.nousresearch.com)
@@ -215,7 +218,7 @@ User message → AIAgent._run_agent_loop()
- **Self-registering tools**: Each tool file calls `registry.register()` at import time. `model_tools.py` triggers discovery by importing all tool modules.
- **Toolset grouping**: Tools are grouped into toolsets (`web`, `terminal`, `file`, `browser`, etc.) that can be enabled/disabled per platform.
- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search. JSON logs go to `~/.hermes/sessions/`.
- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search and unique session titles. JSON logs go to `~/.hermes/sessions/`.
- **Ephemeral injection**: System prompts and prefill messages are injected at API call time, never persisted to the database or logs.
- **Provider abstraction**: The agent works with any OpenAI-compatible API. Provider resolution happens at init time (Nous Portal OAuth, OpenRouter API key, or custom endpoint).
- **Provider routing**: When using OpenRouter, `provider_routing` in config.yaml controls provider selection (sort by throughput/latency/price, allow/ignore specific providers, data retention policies). These are injected as `extra_body.provider` in API requests.
@@ -294,9 +297,9 @@ If it's a new toolset, add it to `toolsets.py` and to the relevant platform pres
---
## Adding a Bundled Skill
## Adding a Skill
Bundled skills live in `skills/` organized by category:
Bundled skills live in `skills/` organized by category. Official optional skills use the same structure in `optional-skills/`:
```
skills/
@@ -322,6 +325,9 @@ description: Brief description (shown in skill search results)
version: 1.0.0
author: Your Name
license: MIT
platforms: [macos, linux] # Optional — restrict to specific OS platforms
# Valid: macos, linux, windows
# Omit to load on all platforms (default)
metadata:
hermes:
tags: [Category, Subcategory, Keywords]
@@ -348,6 +354,18 @@ Known failure modes and how to handle them.
How the agent confirms it worked.
```
### Platform-specific skills
Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.
```yaml
platforms: [macos] # macOS only (e.g., iMessage, Apple Reminders)
platforms: [macos, linux] # macOS and Linux
platforms: [windows] # Windows only
```
If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
### Skill guidelines
- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Nous Research
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -11,17 +11,17 @@
<a href="https://nousresearch.com"><img src="https://img.shields.io/badge/Built%20by-Nous%20Research-blueviolet?style=for-the-badge" alt="Built by Nous Research"></a>
</p>
**The fully open-source AI agent that grows with you.** Install it on a machine, give it your messaging accounts, and it becomes a persistent personal agent — learning your projects, building its own skills, running tasks on a schedule, and reaching you wherever you are.
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI Codex, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
<table>
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
<tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
<tr><td><b>Grows the longer it runs</b></td><td>Persistent memory across sessions. When it solves a hard problem, it writes a skill document for next time. Skills are searchable, shareable, and compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
<tr><td><b>A closed learning loop</b></td><td>Agent-curated memory with periodic nudges. Autonomous skill creation after complex tasks. Skills self-improve during use. FTS5 session search with LLM summarization for cross-session recall. <a href="https://github.com/plastic-labs/honcho">Honcho</a> dialectic user modeling. Compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
<tr><td><b>Real sandboxing</b></td><td>Five terminal backends — local, Docker, SSH, Singularity, and Modal — with persistent workspaces and container security hardening.</td></tr>
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Six terminal backends — local, Docker, SSH, Daytona, Singularity, and Modal. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
</table>

129
TODO.md
View File

@@ -1,129 +0,0 @@
# Hermes Agent - Future Improvements
---
## 3. Local Browser Control via CDP 🌐
**Status:** Not started (currently Browserbase cloud only)
**Priority:** Medium
Support local Chrome/Chromium via Chrome DevTools Protocol alongside existing Browserbase cloud backend.
**What other agents do:**
- **OpenClaw**: Full CDP-based Chrome control with snapshots, actions, uploads, profiles, file chooser, PDF save, console messages, tab management. Uses local Chrome for persistent login sessions.
- **Cline**: Headless browser with Computer Use (click, type, scroll, screenshot, console logs)
**Our approach:**
- Add a `local` backend option to `browser_tool.py` using Playwright or raw CDP
- Config toggle: `browser.backend: local | browserbase | auto`
- `auto` mode: try local first, fall back to Browserbase
- Local advantages: free, persistent login sessions, no API key needed
- Local disadvantages: no CAPTCHA solving, no stealth mode, requires Chrome installed
- Reuse the same 10-tool interface -- just swap the backend
- Later: Chrome profile management for persistent sessions across restarts
---
## 4. Signal Integration 📡
**Status:** Not started
**Priority:** Low
New platform adapter using signal-cli daemon (JSON-RPC HTTP + SSE). Requires Java runtime and phone number registration.
**Reference:** OpenClaw has Signal support via signal-cli.
---
## 5. Plugin/Extension System 🔌
**Status:** Partially implemented (event hooks exist in `gateway/hooks.py`)
**Priority:** Medium
Full Python plugin interface that goes beyond the current hook system.
**What other agents do:**
- **OpenClaw**: Plugin SDK with tool-send capabilities, lifecycle phase hooks (before-agent-start, after-tool-call, model-override), plugin registry with install/uninstall.
- **Pi**: Extensions are TypeScript modules that can register tools, commands, keyboard shortcuts, custom UI widgets, overlays, status lines, dialogs, compaction hooks, raw terminal input listeners. Extremely comprehensive.
- **OpenCode**: MCP client support (stdio, SSE, StreamableHTTP), OAuth auth for MCP servers. Also has Copilot/Codex plugins.
- **Codex**: Full MCP integration with skill dependencies.
- **Cline**: MCP integration + lifecycle hooks with cancellation support.
**Our approach (phased):**
### Phase 1: Enhanced hooks
- Expand the existing `gateway/hooks.py` to support more events: `before-tool-call`, `after-tool-call`, `before-response`, `context-compress`, `session-end`
- Allow hooks to modify tool results (e.g., filter sensitive output)
### Phase 2: Plugin interface
- `~/.hermes/plugins/<name>/plugin.yaml` + `handler.py`
- Plugins can: register new tools, add CLI commands, subscribe to events, inject system prompt sections
- `hermes plugin list|install|uninstall|create` CLI commands
- Plugin discovery and validation on startup
### Phase 3: MCP support (industry standard) ✅ DONE
- ✅ MCP client that connects to external MCP servers (stdio + HTTP/StreamableHTTP)
- ✅ Config: `mcp_servers` in config.yaml with connection details
- ✅ Each MCP server's tools auto-registered as a dynamic toolset
- Future: Resources, Prompts, Progress notifications, `hermes mcp` CLI command
---
## 6. MCP (Model Context Protocol) Support 🔗 ✅ DONE
**Status:** Implemented (PR #301)
**Priority:** Complete
Native MCP client support with stdio and HTTP/StreamableHTTP transports, auto-discovery, reconnection with exponential backoff, env var filtering, and credential stripping. See `docs/mcp.md` for full documentation.
**Still TODO:**
- `hermes mcp` CLI subcommand (list/test/status)
- `hermes tools` UI integration for MCP toolsets
- MCP Resources and Prompts support
- OAuth authentication for remote servers
- Progress notifications for long-running tools
---
## 8. Filesystem Checkpointing / Rollback 🔄
**Status:** Not started
**Priority:** Low-Medium
Automatic filesystem snapshots after each agent loop iteration so the user can roll back destructive changes to their project.
**What other agents do:**
- **Cline**: Workspace checkpoints at each step with Compare/Restore UI
- **OpenCode**: Git-backed workspace snapshots per step, with weekly gc
- **Codex**: Sandboxed execution with commit-per-step, rollback on failure
**Our approach:**
- After each tool call (or batch of tool calls in a single turn) that modifies files, create a lightweight checkpoint of the affected files
- Git-based when the project is a repo: auto-commit to a detached/temporary branch (`hermes/checkpoints/<session>`) after each agent turn, squash or discard on session end
- Non-git fallback: tar snapshots of changed files in `~/.hermes/checkpoints/<session_id>/`
- `hermes rollback` CLI command to restore to a previous checkpoint
- Agent-accessible via a `checkpoint` tool: `list` (show available restore points), `restore` (roll back to a named point), `diff` (show what changed since a checkpoint)
- Configurable: off by default (opt-in via `config.yaml`), since auto-committing can be surprising
- Cleanup: checkpoints expire after session ends (or configurable retention period)
- Integration with the terminal backend: works with local, SSH, and Docker backends (snapshots happen on the execution host)
---
## Implementation Priority Order
### Tier 1: Next Up
1. ~~MCP Support -- #6~~ ✅ Done (PR #301)
### Tier 2: Quality of Life
3. Local Browser Control via CDP -- #3
4. Plugin/Extension System -- #5
### Tier 3: Nice to Have
5. Session Branching / Checkpoints -- #7
6. Filesystem Checkpointing / Rollback -- #8
7. Signal Integration -- #4

View File

@@ -10,7 +10,9 @@ Resolution order for text tasks:
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
wrapped to look like a chat.completions client)
5. None
5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
— checked via PROVIDER_REGISTRY entries with auth_type='api_key'
6. None
Resolution order for vision/multimodal tasks:
1. OpenRouter
@@ -31,6 +33,14 @@ from hermes_constants import OPENROUTER_BASE_URL
logger = logging.getLogger(__name__)
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
"zai": "glm-4.5-flash",
"kimi-coding": "kimi-k2-turbo-preview",
"minimax": "MiniMax-M2.5-highspeed",
"minimax-cn": "MiniMax-M2.5-highspeed",
}
# OpenRouter app attribution headers
_OR_HEADERS = {
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
@@ -282,12 +292,58 @@ def _read_codex_access_token() -> Optional[str]:
return None
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Try each API-key provider in PROVIDER_REGISTRY order.
Returns (client, model) for the first provider whose env var is set,
or (None, None) if none are configured.
"""
try:
from hermes_cli.auth import PROVIDER_REGISTRY
except ImportError:
logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
return None, None
for provider_id, pconfig in PROVIDER_REGISTRY.items():
if pconfig.auth_type != "api_key":
continue
# Check if any of the provider's env vars are set
api_key = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
break
if not api_key:
continue
# Resolve base URL (with optional env-var override)
# Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if env_url:
base_url = env_url.rstrip("/")
elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
base_url = "https://api.kimi.com/coding/v1"
else:
base_url = pconfig.inference_base_url
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
extra = {}
if "api.kimi.com" in base_url.lower():
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
return None, None
# ── Public API ──────────────────────────────────────────────────────────────
def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, model_slug) for text-only auxiliary tasks.
Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth -> (None, None).
Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth
-> direct API-key providers -> (None, None).
"""
# 1. OpenRouter
or_key = os.getenv("OPENROUTER_API_KEY")
@@ -323,7 +379,12 @@ def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
# 5. Nothing available
# 5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, etc.)
api_client, api_model = _resolve_api_key_provider()
if api_client is not None:
return api_client, api_model
# 6. Nothing available
logger.debug("Auxiliary text client: none available")
return None, None
@@ -350,6 +411,8 @@ def get_async_text_auxiliary_client():
}
if "openrouter" in str(sync_client.base_url).lower():
async_kwargs["default_headers"] = dict(_OR_HEADERS)
elif "api.kimi.com" in str(sync_client.base_url).lower():
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
return AsyncOpenAI(**async_kwargs), model

View File

@@ -7,7 +7,7 @@ protecting head and tail context.
import logging
import os
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from agent.auxiliary_client import get_text_auxiliary_client
from agent.model_metadata import (
@@ -34,17 +34,20 @@ class ContextCompressor:
summary_target_tokens: int = 2500,
quiet_mode: bool = False,
summary_model_override: str = None,
base_url: str = "",
):
self.model = model
self.base_url = base_url
self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n
self.summary_target_tokens = summary_target_tokens
self.quiet_mode = quiet_mode
self.context_length = get_model_context_length(model)
self.context_length = get_model_context_length(model, base_url=base_url)
self.threshold_tokens = int(self.context_length * threshold_percent)
self.compression_count = 0
self._context_probed = False # True after a step-down from context error
self.last_prompt_tokens = 0
self.last_completion_tokens = 0
@@ -79,11 +82,14 @@ class ContextCompressor:
"compression_count": self.compression_count,
}
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> str:
"""Generate a concise summary of conversation turns using a fast model."""
if not self.client:
return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed to save space. The assistant performed various actions and received responses."
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
"""Generate a concise summary of conversation turns.
Tries the auxiliary model first, then falls back to the user's main
model. Returns None if all attempts fail — the caller should drop
the middle turns without a summary rather than inject a useless
placeholder.
"""
parts = []
for msg in turns_to_summarize:
role = msg.get("role", "unknown")
@@ -114,28 +120,28 @@ TURNS TO SUMMARIZE:
Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
try:
return self._call_summary_model(self.client, self.summary_model, prompt)
except Exception as e:
logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
# 1. Try the auxiliary model (cheap/fast)
if self.client:
try:
return self._call_summary_model(self.client, self.summary_model, prompt)
except Exception as e:
logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
# Fallback: try the main model's endpoint. This handles the common
# case where the user switched providers (e.g. OpenRouter → local LLM)
# but a stale API key causes the auxiliary client to pick the old
# provider which then fails (402, auth error, etc.).
fallback_client, fallback_model = self._get_fallback_client()
if fallback_client is not None:
try:
logger.info("Retrying context summary with fallback client (%s)", fallback_model)
summary = self._call_summary_model(fallback_client, fallback_model, prompt)
# Success — swap in the working client for future compressions
self.client = fallback_client
self.summary_model = fallback_model
return summary
except Exception as fallback_err:
logging.warning(f"Fallback summary model also failed: {fallback_err}")
# 2. Fallback: try the user's main model endpoint
fallback_client, fallback_model = self._get_fallback_client()
if fallback_client is not None:
try:
logger.info("Retrying context summary with main model (%s)", fallback_model)
summary = self._call_summary_model(fallback_client, fallback_model, prompt)
self.client = fallback_client
self.summary_model = fallback_model
return summary
except Exception as fallback_err:
logging.warning(f"Main model summary also failed: {fallback_err}")
return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed. The assistant performed tool calls and received responses."
# 3. All models failed — return None so the caller drops turns without a summary
logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
return None
def _call_summary_model(self, client, model: str, prompt: str) -> str:
"""Make the actual LLM call to generate a summary. Raises on failure."""
@@ -193,10 +199,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
logger.debug("Could not build fallback auxiliary client: %s", exc)
return None, None
# ------------------------------------------------------------------
# Tool-call / tool-result pair integrity helpers
# ------------------------------------------------------------------
@staticmethod
def _get_tool_call_id(tc) -> str:
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
if isinstance(tc, dict):
return tc.get("id", "")
return getattr(tc, "id", "") or ""
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs after compression.
Two failure modes:
1. A tool *result* references a call_id whose assistant tool_call was
removed (summarized/truncated). The API rejects this with
"No tool call found for function call output with call_id ...".
2. An assistant message has tool_calls whose results were dropped.
The API rejects this because every tool_call must be followed by
a tool result with the matching call_id.
This method removes orphaned results and inserts stub results for
orphaned calls so the message list is always well-formed.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Remove tool results whose call_id has no matching assistant tool_call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
if not self.quiet_mode:
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
# 2. Add stub results for assistant tool_calls whose results were dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result from earlier conversation — see context summary above]",
"tool_call_id": cid,
})
messages = patched
if not self.quiet_mode:
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
return messages
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Push a compress-start boundary forward past any orphan tool results.
If ``messages[idx]`` is a tool result, slide forward until we hit a
non-tool message so we don't start the summarised region mid-group.
"""
while idx < len(messages) and messages[idx].get("role") == "tool":
idx += 1
return idx
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Pull a compress-end boundary backward to avoid splitting a
tool_call / result group.
If the message just before ``idx`` is an assistant message with
tool_calls, those tool results will start at ``idx`` and would be
separated from their parent. Move backwards to include the whole
group in the summarised region.
"""
if idx <= 0 or idx >= len(messages):
return idx
prev = messages[idx - 1]
if prev.get("role") == "assistant" and prev.get("tool_calls"):
# The results for this assistant turn sit at idx..idx+k.
# Include the assistant message in the summarised region too.
idx -= 1
return idx
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
"""Compress conversation messages by summarizing middle turns.
Keeps first N + last N turns, summarizes everything in between.
After compression, orphaned tool_call / tool_result pairs are cleaned
up so the API never receives mismatched IDs.
"""
n_messages = len(messages)
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
@@ -209,6 +316,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
if compress_start >= compress_end:
return messages
# Adjust boundaries to avoid splitting tool_call/result groups.
compress_start = self._align_boundary_forward(messages, compress_start)
compress_end = self._align_boundary_backward(messages, compress_end)
if compress_start >= compress_end:
return messages
turns_to_summarize = messages[compress_start:compress_end]
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
@@ -216,24 +329,6 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
# Truncation fallback when no auxiliary model is available
if self.client is None:
print("⚠️ Context compression: no auxiliary model available. Falling back to message truncation.")
# Keep system message(s) at the front and the protected tail;
# simply drop the oldest non-system messages until under threshold.
kept = []
for msg in messages:
if msg.get("role") == "system":
kept.append(msg.copy())
else:
break
tail = messages[-self.protect_last_n:]
kept.extend(m.copy() for m in tail)
self.compression_count += 1
if not self.quiet_mode:
print(f" ✂️ Truncated: {len(messages)}{len(kept)} messages (dropped middle turns)")
return kept
if not self.quiet_mode:
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
@@ -246,13 +341,19 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
compressed.append(msg)
compressed.append({"role": "user", "content": summary})
if summary:
compressed.append({"role": "user", "content": summary})
else:
if not self.quiet_mode:
print(" ⚠️ No summary model available — middle turns dropped without summary")
for i in range(compress_end, n_messages):
compressed.append(messages[i].copy())
self.compression_count += 1
compressed = self._sanitize_tool_pairs(compressed)
if not self.quiet_mode:
new_estimate = estimate_messages_tokens_rough(compressed)
saved_estimate = display_tokens - new_estimate

818
agent/insights.py Normal file
View File

@@ -0,0 +1,818 @@
"""
Session Insights Engine for Hermes Agent.
Analyzes historical session data from the SQLite state database to produce
comprehensive usage insights — token consumption, cost estimates, tool usage
patterns, activity trends, model/platform breakdowns, and session metrics.
Inspired by Claude Code's /insights command, adapted for Hermes Agent's
multi-platform architecture with additional cost estimation and platform
breakdown capabilities.
Usage:
from agent.insights import InsightsEngine
engine = InsightsEngine(db)
report = engine.generate(days=30)
print(engine.format_terminal(report))
"""
import json
import time
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional
# =========================================================================
# Model pricing (USD per million tokens) — approximate as of early 2026
# =========================================================================
MODEL_PRICING = {
# OpenAI
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gpt-4.1": {"input": 2.00, "output": 8.00},
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
"gpt-5": {"input": 10.00, "output": 30.00},
"gpt-5.4": {"input": 10.00, "output": 30.00},
"o3": {"input": 10.00, "output": 40.00},
"o3-mini": {"input": 1.10, "output": 4.40},
"o4-mini": {"input": 1.10, "output": 4.40},
# Anthropic
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
# DeepSeek
"deepseek-chat": {"input": 0.14, "output": 0.28},
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
# Google
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Meta (via providers)
"llama-4-maverick": {"input": 0.50, "output": 0.70},
"llama-4-scout": {"input": 0.20, "output": 0.30},
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
"glm-5": {"input": 0.0, "output": 0.0},
"glm-4.7": {"input": 0.0, "output": 0.0},
"glm-4.5": {"input": 0.0, "output": 0.0},
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
"kimi-k2.5": {"input": 0.0, "output": 0.0},
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
# MiniMax (direct provider — pricing not published externally, treat as local)
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
}
# Fallback: unknown/custom models get zero cost (we can't assume pricing
# for self-hosted models, custom OAI endpoints, local inference, etc.)
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
def _has_known_pricing(model_name: str) -> bool:
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
return _get_pricing(model_name) is not _DEFAULT_PRICING
def _get_pricing(model_name: str) -> Dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
we can't assume costs for self-hosted endpoints, local inference, etc.
"""
if not model_name:
return _DEFAULT_PRICING
# Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
bare = model_name.split("/")[-1].lower()
# Exact match first
if bare in MODEL_PRICING:
return MODEL_PRICING[bare]
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
best_match = None
best_len = 0
for key, price in MODEL_PRICING.items():
if bare.startswith(key) and len(key) > best_len:
best_match = price
best_len = len(key)
if best_match:
return best_match
# Keyword heuristics (checked in most-specific-first order)
if "opus" in bare:
return {"input": 15.00, "output": 75.00}
if "sonnet" in bare:
return {"input": 3.00, "output": 15.00}
if "haiku" in bare:
return {"input": 0.80, "output": 4.00}
if "gpt-4o-mini" in bare:
return {"input": 0.15, "output": 0.60}
if "gpt-4o" in bare:
return {"input": 2.50, "output": 10.00}
if "gpt-5" in bare:
return {"input": 10.00, "output": 30.00}
if "deepseek" in bare:
return {"input": 0.14, "output": 0.28}
if "gemini" in bare:
return {"input": 0.15, "output": 0.60}
return _DEFAULT_PRICING
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
"""Estimate the USD cost for a given model and token counts."""
pricing = _get_pricing(model)
return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
def _format_duration(seconds: float) -> str:
"""Format seconds into a human-readable duration string."""
if seconds < 60:
return f"{seconds:.0f}s"
minutes = seconds / 60
if minutes < 60:
return f"{minutes:.0f}m"
hours = minutes / 60
if hours < 24:
remaining_min = int(minutes % 60)
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
days = hours / 24
return f"{days:.1f}d"
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
"""Create simple horizontal bar chart strings from values."""
peak = max(values) if values else 1
if peak == 0:
return ["" for _ in values]
return ["" * max(1, int(v / peak * max_width)) if v > 0 else "" for v in values]
class InsightsEngine:
"""
Analyzes session history and produces usage insights.
Works directly with a SessionDB instance (or raw sqlite3 connection)
to query session and message data.
"""
def __init__(self, db):
"""
Initialize with a SessionDB instance.
Args:
db: A SessionDB instance (from hermes_state.py)
"""
self.db = db
self._conn = db._conn
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
"""
Generate a complete insights report.
Args:
days: Number of days to look back (default: 30)
source: Optional filter by source platform
Returns:
Dict with all computed insights
"""
cutoff = time.time() - (days * 86400)
# Gather raw data
sessions = self._get_sessions(cutoff, source)
tool_usage = self._get_tool_usage(cutoff, source)
message_stats = self._get_message_stats(cutoff, source)
if not sessions:
return {
"days": days,
"source_filter": source,
"empty": True,
"overview": {},
"models": [],
"platforms": [],
"tools": [],
"activity": {},
"top_sessions": [],
}
# Compute insights
overview = self._compute_overview(sessions, message_stats)
models = self._compute_model_breakdown(sessions)
platforms = self._compute_platform_breakdown(sessions)
tools = self._compute_tool_breakdown(tool_usage)
activity = self._compute_activity_patterns(sessions)
top_sessions = self._compute_top_sessions(sessions)
return {
"days": days,
"source_filter": source,
"empty": False,
"generated_at": time.time(),
"overview": overview,
"models": models,
"platforms": platforms,
"tools": tools,
"activity": activity,
"top_sessions": top_sessions,
}
# =========================================================================
# Data gathering (SQL queries)
# =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window."""
if source:
cursor = self._conn.execute(
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ? AND source = ?
ORDER BY started_at DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ?
ORDER BY started_at DESC""",
(cutoff,),
)
return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
"""Get tool call counts from messages.
Uses two sources:
1. tool_name column on 'tool' role messages (set by gateway)
2. tool_calls JSON on 'assistant' role messages (covers CLI where
tool_name is not populated on tool responses)
"""
tool_counts = Counter()
# Source 1: explicit tool_name on tool response messages
if source:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff,),
)
for row in cursor.fetchall():
tool_counts[row["tool_name"]] += row["count"]
# Source 2: extract from tool_calls JSON on assistant messages
# (covers CLI sessions where tool_name is NULL on tool responses)
if source:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff, source),
)
else:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff,),
)
tool_calls_counts = Counter()
for row in cursor2.fetchall():
try:
calls = row["tool_calls"]
if isinstance(calls, str):
calls = json.loads(calls)
if isinstance(calls, list):
for call in calls:
func = call.get("function", {}) if isinstance(call, dict) else {}
name = func.get("name")
if name:
tool_calls_counts[name] += 1
except (json.JSONDecodeError, TypeError, AttributeError):
continue
# Merge: prefer tool_name source, supplement with tool_calls source
# for tools not already counted
if not tool_counts and tool_calls_counts:
# No tool_name data at all — use tool_calls exclusively
tool_counts = tool_calls_counts
elif tool_counts and tool_calls_counts:
# Both sources have data — use whichever has the higher count per tool
# (they may overlap, so take the max to avoid double-counting)
all_tools = set(tool_counts) | set(tool_calls_counts)
merged = Counter()
for tool in all_tools:
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
"""Get aggregate message statistics."""
if source:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?""",
(cutoff,),
)
row = cursor.fetchone()
return dict(row) if row else {
"total_messages": 0, "user_messages": 0,
"assistant_messages": 0, "tool_messages": 0,
}
# =========================================================================
# Computation
# =========================================================================
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
"""Compute high-level overview statistics."""
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
total_tokens = total_input + total_output
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
total_messages = sum(s.get("message_count") or 0 for s in sessions)
# Cost estimation (weighted by model)
total_cost = 0.0
models_with_pricing = set()
models_without_pricing = set()
for s in sessions:
model = s.get("model") or ""
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
total_cost += _estimate_cost(model, inp, out)
display = model.split("/")[-1] if "/" in model else (model or "unknown")
if _has_known_pricing(model):
models_with_pricing.add(display)
else:
models_without_pricing.add(display)
# Session duration stats (guard against negative durations from clock drift)
durations = []
for s in sessions:
start = s.get("started_at")
end = s.get("ended_at")
if start and end and end > start:
durations.append(end - start)
total_hours = sum(durations) / 3600 if durations else 0
avg_duration = sum(durations) / len(durations) if durations else 0
# Earliest and latest session
started_timestamps = [s["started_at"] for s in sessions if s.get("started_at")]
date_range_start = min(started_timestamps) if started_timestamps else None
date_range_end = max(started_timestamps) if started_timestamps else None
return {
"total_sessions": len(sessions),
"total_messages": total_messages,
"total_tool_calls": total_tool_calls,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_tokens": total_tokens,
"estimated_cost": total_cost,
"total_hours": total_hours,
"avg_session_duration": avg_duration,
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
"avg_tokens_per_session": total_tokens / len(sessions) if sessions else 0,
"user_messages": message_stats.get("user_messages") or 0,
"assistant_messages": message_stats.get("assistant_messages") or 0,
"tool_messages": message_stats.get("tool_messages") or 0,
"date_range_start": date_range_start,
"date_range_end": date_range_end,
"models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": sorted(models_without_pricing),
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by model."""
model_data = defaultdict(lambda: {
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
})
for s in sessions:
model = s.get("model") or "unknown"
# Normalize: strip provider prefix for display
display_model = model.split("/")[-1] if "/" in model else model
d = model_data[display_model]
d["sessions"] += 1
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["tool_calls"] += s.get("tool_call_count") or 0
d["cost"] += _estimate_cost(model, inp, out)
d["has_pricing"] = _has_known_pricing(model)
result = [
{"model": model, **data}
for model, data in model_data.items()
]
# Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by platform/source."""
platform_data = defaultdict(lambda: {
"sessions": 0, "messages": 0, "input_tokens": 0,
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
})
for s in sessions:
source = s.get("source") or "unknown"
d = platform_data[source]
d["sessions"] += 1
d["messages"] += s.get("message_count") or 0
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["tool_calls"] += s.get("tool_call_count") or 0
result = [
{"platform": platform, **data}
for platform, data in platform_data.items()
]
result.sort(key=lambda x: x["sessions"], reverse=True)
return result
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
"""Process tool usage data into a ranked list with percentages."""
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
result = []
for t in tool_usage:
pct = (t["count"] / total_calls * 100) if total_calls else 0
result.append({
"tool": t["tool_name"],
"count": t["count"],
"percentage": pct,
})
return result
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
"""Analyze activity patterns by day of week and hour."""
day_counts = Counter() # 0=Monday ... 6=Sunday
hour_counts = Counter()
daily_counts = Counter() # date string -> count
for s in sessions:
ts = s.get("started_at")
if not ts:
continue
dt = datetime.fromtimestamp(ts)
day_counts[dt.weekday()] += 1
hour_counts[dt.hour] += 1
daily_counts[dt.strftime("%Y-%m-%d")] += 1
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
day_breakdown = [
{"day": day_names[i], "count": day_counts.get(i, 0)}
for i in range(7)
]
hour_breakdown = [
{"hour": i, "count": hour_counts.get(i, 0)}
for i in range(24)
]
# Busiest day and hour
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
busiest_hour = max(hour_breakdown, key=lambda x: x["count"]) if hour_breakdown else None
# Active days (days with at least one session)
active_days = len(daily_counts)
# Streak calculation
if daily_counts:
all_dates = sorted(daily_counts.keys())
current_streak = 1
max_streak = 1
for i in range(1, len(all_dates)):
d1 = datetime.strptime(all_dates[i - 1], "%Y-%m-%d")
d2 = datetime.strptime(all_dates[i], "%Y-%m-%d")
if (d2 - d1).days == 1:
current_streak += 1
max_streak = max(max_streak, current_streak)
else:
current_streak = 1
else:
max_streak = 0
return {
"by_day": day_breakdown,
"by_hour": hour_breakdown,
"busiest_day": busiest_day,
"busiest_hour": busiest_hour,
"active_days": active_days,
"max_streak": max_streak,
}
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
"""Find notable sessions (longest, most messages, most tokens)."""
top = []
# Longest by duration
sessions_with_duration = [
s for s in sessions
if s.get("started_at") and s.get("ended_at")
]
if sessions_with_duration:
longest = max(
sessions_with_duration,
key=lambda s: (s["ended_at"] - s["started_at"]),
)
dur = longest["ended_at"] - longest["started_at"]
top.append({
"label": "Longest session",
"session_id": longest["id"][:16],
"value": _format_duration(dur),
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
})
# Most messages
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
if (most_msgs.get("message_count") or 0) > 0:
top.append({
"label": "Most messages",
"session_id": most_msgs["id"][:16],
"value": f"{most_msgs['message_count']} msgs",
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
})
# Most tokens
most_tokens = max(
sessions,
key=lambda s: (s.get("input_tokens") or 0) + (s.get("output_tokens") or 0),
)
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
if token_total > 0:
top.append({
"label": "Most tokens",
"session_id": most_tokens["id"][:16],
"value": f"{token_total:,} tokens",
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
})
# Most tool calls
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
if (most_tools.get("tool_call_count") or 0) > 0:
top.append({
"label": "Most tool calls",
"session_id": most_tools["id"][:16],
"value": f"{most_tools['tool_call_count']} calls",
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
})
return top
# =========================================================================
# Formatting
# =========================================================================
def format_terminal(self, report: Dict) -> str:
"""Format the insights report for terminal display (CLI)."""
if report.get("empty"):
days = report.get("days", 30)
src = f" (source: {report['source_filter']})" if report.get("source_filter") else ""
return f" No sessions found in the last {days} days{src}."
lines = []
o = report["overview"]
days = report["days"]
src_filter = report.get("source_filter")
# Header
lines.append("")
lines.append(" ╔══════════════════════════════════════════════════════════╗")
lines.append(" ║ 📊 Hermes Insights ║")
period_label = f"Last {days} days"
if src_filter:
period_label += f" ({src_filter})"
padding = 58 - len(period_label) - 2
left_pad = padding // 2
right_pad = padding - left_pad
lines.append(f"{' ' * left_pad} {period_label} {' ' * right_pad}")
lines.append(" ╚══════════════════════════════════════════════════════════╝")
lines.append("")
# Date range
if o.get("date_range_start") and o.get("date_range_end"):
start_str = datetime.fromtimestamp(o["date_range_start"]).strftime("%b %d, %Y")
end_str = datetime.fromtimestamp(o["date_range_end"]).strftime("%b %d, %Y")
lines.append(f" Period: {start_str}{end_str}")
lines.append("")
# Overview
lines.append(" 📋 Overview")
lines.append(" " + "" * 56)
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
cost_str = f"${o['estimated_cost']:.2f}"
if o.get("models_without_pricing"):
cost_str += " *"
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
if o["total_hours"] > 0:
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
lines.append("")
# Model breakdown
if report["models"]:
lines.append(" 🤖 Models Used")
lines.append(" " + "" * 56)
lines.append(f" {'Model':<30} {'Sessions':>8} {'Tokens':>12} {'Cost':>8}")
for m in report["models"]:
model_name = m["model"][:28]
if m.get("has_pricing"):
cost_cell = f"${m['cost']:>6.2f}"
else:
cost_cell = " N/A"
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
if o.get("models_without_pricing"):
lines.append(f" * Cost N/A for custom/self-hosted models")
lines.append("")
# Platform breakdown
if len(report["platforms"]) > 1 or (report["platforms"] and report["platforms"][0]["platform"] != "cli"):
lines.append(" 📱 Platforms")
lines.append(" " + "" * 56)
lines.append(f" {'Platform':<14} {'Sessions':>8} {'Messages':>10} {'Tokens':>14}")
for p in report["platforms"]:
lines.append(f" {p['platform']:<14} {p['sessions']:>8} {p['messages']:>10,} {p['total_tokens']:>14,}")
lines.append("")
# Tool usage
if report["tools"]:
lines.append(" 🔧 Top Tools")
lines.append(" " + "" * 56)
lines.append(f" {'Tool':<28} {'Calls':>8} {'%':>8}")
for t in report["tools"][:15]: # Top 15
lines.append(f" {t['tool']:<28} {t['count']:>8,} {t['percentage']:>7.1f}%")
if len(report["tools"]) > 15:
lines.append(f" ... and {len(report['tools']) - 15} more tools")
lines.append("")
# Activity patterns
act = report.get("activity", {})
if act.get("by_day"):
lines.append(" 📅 Activity Patterns")
lines.append(" " + "" * 56)
# Day of week chart
day_values = [d["count"] for d in act["by_day"]]
bars = _bar_chart(day_values, max_width=15)
for i, d in enumerate(act["by_day"]):
bar = bars[i]
lines.append(f" {d['day']} {bar:<15} {d['count']}")
lines.append("")
# Peak hours (show top 5 busiest hours)
busy_hours = sorted(act["by_hour"], key=lambda x: x["count"], reverse=True)
busy_hours = [h for h in busy_hours if h["count"] > 0][:5]
if busy_hours:
hour_strs = []
for h in busy_hours:
hr = h["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
hour_strs.append(f"{display_hr}{ampm} ({h['count']})")
lines.append(f" Peak hours: {', '.join(hour_strs)}")
if act.get("active_days"):
lines.append(f" Active days: {act['active_days']}")
if act.get("max_streak") and act["max_streak"] > 1:
lines.append(f" Best streak: {act['max_streak']} consecutive days")
lines.append("")
# Notable sessions
if report.get("top_sessions"):
lines.append(" 🏆 Notable Sessions")
lines.append(" " + "" * 56)
for ts in report["top_sessions"]:
lines.append(f" {ts['label']:<20} {ts['value']:<18} ({ts['date']}, {ts['session_id']})")
lines.append("")
return "\n".join(lines)
def format_gateway(self, report: Dict) -> str:
"""Format the insights report for gateway/messaging (shorter)."""
if report.get("empty"):
days = report.get("days", 30)
return f"No sessions found in the last {days} days."
lines = []
o = report["overview"]
days = report["days"]
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
# Overview
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
cost_note = ""
if o.get("models_without_pricing"):
cost_note = " _(excludes custom/self-hosted models)_"
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
if o["total_hours"] > 0:
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
lines.append("")
# Models (top 5)
if report["models"]:
lines.append("**🤖 Models:**")
for m in report["models"][:5]:
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
lines.append(f" {m['model'][:25]}{m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
lines.append("")
# Platforms (if multi-platform)
if len(report["platforms"]) > 1:
lines.append("**📱 Platforms:**")
for p in report["platforms"]:
lines.append(f" {p['platform']}{p['sessions']} sessions, {p['messages']:,} msgs")
lines.append("")
# Tools (top 8)
if report["tools"]:
lines.append("**🔧 Top Tools:**")
for t in report["tools"][:8]:
lines.append(f" {t['tool']}{t['count']:,} calls ({t['percentage']:.1f}%)")
lines.append("")
# Activity summary
act = report.get("activity", {})
if act.get("busiest_day") and act.get("busiest_hour"):
hr = act["busiest_hour"]["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
if act.get("active_days"):
lines.append(f"**Active days:** {act['active_days']}", )
if act.get("max_streak", 0) > 1:
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
return "\n".join(lines)

View File

@@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks.
"""
import logging
import os
import re
import time
from typing import Any, Dict, List
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
import yaml
from hermes_constants import OPENROUTER_MODELS_URL
@@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
2_000_000,
1_000_000,
512_000,
200_000,
128_000,
64_000,
32_000,
]
DEFAULT_CONTEXT_LENGTHS = {
"anthropic/claude-opus-4": 200000,
"anthropic/claude-opus-4.5": 200000,
@@ -33,6 +49,17 @@ DEFAULT_CONTEXT_LENGTHS = {
"meta-llama/llama-3.3-70b-instruct": 131072,
"deepseek/deepseek-chat-v3": 65536,
"qwen/qwen-2.5-72b-instruct": 32768,
"glm-4.7": 202752,
"glm-5": 202752,
"glm-4.5": 131072,
"glm-4.5-flash": 131072,
"kimi-k2.5": 262144,
"kimi-k2-thinking": 262144,
"kimi-k2-turbo-preview": 262144,
"kimi-k2-0905-preview": 131072,
"MiniMax-M2.5": 204800,
"MiniMax-M2.5-highspeed": 204800,
"MiniMax-M2.1": 204800,
}
@@ -71,17 +98,117 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
return _model_metadata_cache or {}
def get_model_context_length(model: str) -> int:
"""Get the context length for a model (API first, then fallback defaults)."""
def _get_context_cache_path() -> Path:
"""Return path to the persistent context length cache file."""
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
return hermes_home / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
"""Load the model+provider → context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
return {}
try:
with open(path) as f:
data = yaml.safe_load(f) or {}
return data.get("context_lengths", {})
except Exception as e:
logger.debug("Failed to load context length cache: %s", e)
return {}
def save_context_length(model: str, base_url: str, length: int) -> None:
"""Persist a discovered context length for a model+provider combo.
Cache key is ``model@base_url`` so the same model name served from
different providers can have different limits.
"""
key = f"{model}@{base_url}"
cache = _load_context_cache()
if cache.get(key) == length:
return # already stored
cache[key] = length
path = _get_context_cache_path()
try:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
logger.info("Cached context length %s%s tokens", key, f"{length:,}")
except Exception as e:
logger.debug("Failed to save context length cache: %s", e)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
"""Look up a previously discovered context length for model+provider."""
key = f"{model}@{base_url}"
cache = _load_context_cache()
return cache.get(key)
def get_next_probe_tier(current_length: int) -> Optional[int]:
"""Return the next lower probe tier, or None if already at minimum."""
for tier in CONTEXT_PROBE_TIERS:
if tier < current_length:
return tier
return None
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
"""Try to extract the actual context limit from an API error message.
Many providers include the limit in their error text, e.g.:
- "maximum context length is 32768 tokens"
- "context_length_exceeded: 131072"
- "Maximum context size 32768 exceeded"
- "model's max context length is 65536"
"""
error_lower = error_msg.lower()
# Pattern: look for numbers near context-related keywords
patterns = [
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
]
for pattern in patterns:
match = re.search(pattern, error_lower)
if match:
limit = int(match.group(1))
# Sanity check: must be a reasonable context length
if 1024 <= limit <= 10_000_000:
return limit
return None
def get_model_context_length(model: str, base_url: str = "") -> int:
"""Get the context length for a model.
Resolution order:
1. Persistent cache (previously discovered via probing)
2. OpenRouter API metadata
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
4. First probe tier (2M) — will be narrowed on first context error
"""
# 1. Check persistent cache (model+provider)
if base_url:
cached = get_cached_context_length(model, base_url)
if cached is not None:
return cached
# 2. OpenRouter API metadata
metadata = fetch_model_metadata()
if model in metadata:
return metadata[model].get("context_length", 128000)
# 3. Hardcoded defaults (fuzzy match)
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
if default_model in model or model in default_model:
return length
return 128000
# 4. Unknown model — start at highest probe tier
return CONTEXT_PROBE_TIERS[0]
def estimate_tokens_rough(text: str) -> int:

View File

@@ -66,7 +66,8 @@ DEFAULT_AGENT_IDENTITY = (
"range of tasks including answering questions, writing and editing code, "
"analyzing information, creative work, and executing actions via your tools. "
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
"being genuinely useful over being verbose unless otherwise directed below."
"being genuinely useful over being verbose unless otherwise directed below. "
"Be targeted and efficient in your exploration and investigations."
)
MEMORY_GUIDANCE = (
@@ -102,12 +103,24 @@ PLATFORM_HINTS = {
"You are on a text messaging communication platform, Telegram. "
"Please do not use markdown as it does not render. "
"You can send media files natively: to deliver a file to the user, "
"include MEDIA:/absolute/path/to/file in your response. Audio "
"(.ogg) sends as voice bubbles. You can also include image URLs "
"in markdown format ![alt](url) and they will be sent as native photos."
"include MEDIA:/absolute/path/to/file in your response. Images "
"(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
"bubbles, and videos (.mp4) play inline. You can also include image "
"URLs in markdown format ![alt](url) and they will be sent as native photos."
),
"discord": (
"You are in a Discord server or group chat communicating with your user."
"You are in a Discord server or group chat communicating with your user. "
"You can send media files natively: include MEDIA:/absolute/path/to/file "
"in your response. Images (.png, .jpg, .webp) are sent as photo "
"attachments, audio as file attachments. You can also include image URLs "
"in markdown format ![alt](url) and they will be sent as attachments."
),
"slack": (
"You are in a Slack workspace communicating with your user. "
"You can send media files natively: include MEDIA:/absolute/path/to/file "
"in your response. Images (.png, .jpg, .webp) are uploaded as photo "
"attachments, audio as file attachments. You can also include image URLs "
"in markdown format ![alt](url) and they will be uploaded as attachments."
),
"cli": (
"You are a CLI AI Agent. Try not to use markdown but simple text "
@@ -142,12 +155,28 @@ def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
return ""
def _skill_is_platform_compatible(skill_file: Path) -> bool:
"""Quick check if a SKILL.md is compatible with the current OS platform.
Reads just enough to parse the ``platforms`` frontmatter field.
Skills without the field (the vast majority) are always compatible.
"""
try:
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = _parse_frontmatter(raw)
return skill_matches_platform(frontmatter)
except Exception:
return True # Err on the side of showing the skill
def build_skills_system_prompt() -> str:
"""Build a compact skill index for the system prompt.
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
Includes per-skill descriptions from frontmatter so the model can
match skills by meaning, not just name.
Filters out skills incompatible with the current OS platform.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
skills_dir = hermes_home / "skills"
@@ -159,6 +188,9 @@ def build_skills_system_prompt() -> str:
# Each entry: (skill_name, description)
skills_by_category: dict[str, list[tuple[str, str]]] = {}
for skill_file in skills_dir.rglob("SKILL.md"):
# Skip skills incompatible with the current OS platform
if not _skill_is_platform_compatible(skill_file):
continue
rel_path = skill_file.relative_to(skills_dir)
parts = rel_path.parts
if len(parts) >= 2:

View File

@@ -22,7 +22,7 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
global _skill_commands
_skill_commands = {}
try:
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
if not SKILLS_DIR.exists():
return _skill_commands
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
@@ -31,6 +31,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
try:
content = skill_md.read_text(encoding='utf-8')
frontmatter, body = _parse_frontmatter(content)
# Skip skills incompatible with the current OS platform
if not skill_matches_platform(frontmatter):
continue
name = frontmatter.get('name', skill_md.parent.name)
description = frontmatter.get('description', '')
if not description:

View File

@@ -29,7 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
@@ -250,7 +249,7 @@ def _process_single_prompt(
task_id = f"task_{prompt_index}"
# Per-prompt container image override: if the dataset row has an 'image' field,
# register it for this task's sandbox. Works with Docker, Modal, and Singularity.
# register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
container_image = prompt_data.get("image") or prompt_data.get("docker_image")
if container_image:
# Verify the image is accessible before spending tokens on the agent loop.
@@ -292,6 +291,7 @@ def _process_single_prompt(
"docker_image": container_image,
"modal_image": container_image,
"singularity_image": f"docker://{container_image}",
"daytona_image": container_image,
}
if prompt_data.get("cwd"):
overrides["cwd"] = prompt_data["cwd"]
@@ -700,14 +700,13 @@ class BatchRunner:
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
from utils import atomic_json_write
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
atomic_json_write(self.checkpoint_file, checkpoint_data)
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
atomic_json_write(self.checkpoint_file, checkpoint_data)
def _scan_completed_prompts_by_content(self) -> set:
"""
@@ -832,13 +831,15 @@ class BatchRunner:
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Load existing checkpoint (so resume doesn't clobber prior progress)
checkpoint_data = self._load_checkpoint()
if checkpoint_data.get("run_name") != self.run_name:
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Prepare configuration for workers
config = {
@@ -860,7 +861,7 @@ class BatchRunner:
}
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Aggregate statistics across all batches
total_tool_stats = {}
@@ -869,6 +870,9 @@ class BatchRunner:
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
# Checkpoint writes happen in the parent process; keep a lock for safety.
checkpoint_lock = Lock()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -914,6 +918,25 @@ class BatchRunner:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
# Incremental checkpoint update (so resume works after crash)
try:
batch_num = result.get('batch_num')
completed = result.get('completed_prompts', []) or []
completed_prompts_set.update(completed)
if isinstance(batch_num, int):
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
'processed': result.get('processed', 0),
'skipped': result.get('skipped', 0),
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
}
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
# Don't fail the run if checkpoint write fails
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
except Exception as e:
logger.error("Batch worker failed: %s", e, exc_info=True)
raise
@@ -945,9 +968,12 @@ class BatchRunner:
for key in total_reasoning_stats:
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Save final checkpoint (best-effort; incremental writes already happened)
try:
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
# Calculate success rates
for tool_name in total_tool_stats:
@@ -1086,7 +1112,7 @@ def main(
batch_size: int = None,
run_name: str = None,
distribution: str = "default",
model: str = "anthropic/claude-sonnet-4-20250514",
model: str = "anthropic/claude-sonnet-4.6",
api_key: str = None,
base_url: str = "https://openrouter.ai/api/v1",
max_turns: int = 10,
@@ -1129,7 +1155,7 @@ def main(
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
@@ -1190,7 +1216,7 @@ def main(
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
# Build reasoning_config from CLI flags
# --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
# --reasoning_disabled takes priority, then --reasoning_effort, then default (medium)
reasoning_config = None
if reasoning_disabled:
# Completely disable reasoning/thinking tokens

View File

@@ -13,6 +13,10 @@ model:
# "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
# "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
# "nous" - Always use Nous Portal (requires: hermes login)
# "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
# "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
# "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
# "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
# Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
provider: "auto"
@@ -46,6 +50,16 @@ model:
# # Data policy: "allow" (default) or "deny" to exclude providers that may store data
# # data_collection: "deny"
# =============================================================================
# Git Worktree Isolation
# =============================================================================
# When enabled, each CLI session creates an isolated git worktree so multiple
# agents can work on the same repo concurrently without file collisions.
# Equivalent to always passing --worktree / -w on the command line.
#
# worktree: true # Always create a worktree when in a git repo
# worktree: false # Default — only create when -w flag is passed
# =============================================================================
# Terminal Tool Configuration
# =============================================================================
@@ -116,8 +130,23 @@ terminal:
# timeout: 180
# lifetime_seconds: 300
# modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"
# -----------------------------------------------------------------------------
# OPTION 6: Daytona cloud execution
# Commands run in Daytona cloud sandboxes
# Great for: Cloud dev environments, persistent workspaces, team collaboration
# Requires: pip install daytona, DAYTONA_API_KEY env var
# -----------------------------------------------------------------------------
# terminal:
# backend: "daytona"
# cwd: "~"
# timeout: 180
# lifetime_seconds: 300
# daytona_image: "nikolaik/python-nodejs:python3.11-nodejs20"
# container_disk: 10240 # Daytona max is 10GB per sandbox
#
# --- Container resource limits (docker, singularity, modal -- ignored for local/ssh) ---
# --- Container resource limits (docker, singularity, modal, daytona -- ignored for local/ssh) ---
# These settings apply to all container backends. They control the resources
# allocated to the sandbox and whether its filesystem persists across sessions.
container_cpu: 1 # CPU cores
@@ -266,7 +295,7 @@ agent:
# Reasoning effort level (OpenRouter and Nous Portal)
# Controls how much "thinking" the model does before responding.
# Options: "xhigh" (max), "high", "medium", "low", "minimal", "none" (disable)
reasoning_effort: "xhigh"
reasoning_effort: "medium"
# Predefined personalities (use with /personality command)
personalities:

990
cli.py

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,8 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, List, Any
from hermes_time import now as _hermes_now
try:
from croniter import croniter
HAS_CRONITER = True
@@ -128,7 +130,7 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
# Duration like "30m", "2h", "1d" → one-shot from now
try:
minutes = parse_duration(schedule)
run_at = datetime.now() + timedelta(minutes=minutes)
run_at = _hermes_now() + timedelta(minutes=minutes)
return {
"kind": "once",
"run_at": run_at.isoformat(),
@@ -146,37 +148,50 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
)
def _ensure_aware(dt: datetime) -> datetime:
"""Make a naive datetime tz-aware using the configured timezone.
Handles backward compatibility: timestamps stored before timezone support
are naive (server-local). We assume they were in the same timezone as
the current configuration so comparisons work without crashing.
"""
if dt.tzinfo is None:
tz = _hermes_now().tzinfo
return dt.replace(tzinfo=tz)
return dt
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
"""
Compute the next run time for a schedule.
Returns ISO timestamp string, or None if no more runs.
"""
now = datetime.now()
now = _hermes_now()
if schedule["kind"] == "once":
run_at = datetime.fromisoformat(schedule["run_at"])
run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
# If in the future, return it; if in the past, no more runs
return schedule["run_at"] if run_at > now else None
elif schedule["kind"] == "interval":
minutes = schedule["minutes"]
if last_run_at:
# Next run is last_run + interval
last = datetime.fromisoformat(last_run_at)
last = _ensure_aware(datetime.fromisoformat(last_run_at))
next_run = last + timedelta(minutes=minutes)
else:
# First run is now + interval
next_run = now + timedelta(minutes=minutes)
return next_run.isoformat()
elif schedule["kind"] == "cron":
if not HAS_CRONITER:
return None
cron = croniter(schedule["expr"], now)
next_run = cron.get_next(datetime)
return next_run.isoformat()
return None
@@ -204,7 +219,7 @@ def save_jobs(jobs: List[Dict[str, Any]]):
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, JOBS_FILE)
@@ -249,7 +264,7 @@ def create_job(
deliver = "origin" if origin else "local"
job_id = uuid.uuid4().hex[:12]
now = datetime.now().isoformat()
now = _hermes_now().isoformat()
job = {
"id": job_id,
@@ -328,7 +343,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
jobs = load_jobs()
for i, job in enumerate(jobs):
if job["id"] == job_id:
now = datetime.now().isoformat()
now = _hermes_now().isoformat()
job["last_run_at"] = now
job["last_status"] = "ok" if success else "error"
job["last_error"] = error if not success else None
@@ -361,7 +376,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
def get_due_jobs() -> List[Dict[str, Any]]:
"""Get all jobs that are due to run now."""
now = datetime.now()
now = _hermes_now()
jobs = load_jobs()
due = []
@@ -373,7 +388,7 @@ def get_due_jobs() -> List[Dict[str, Any]]:
if not next_run:
continue
next_run_dt = datetime.fromisoformat(next_run)
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
if next_run_dt <= now:
due.append(job)
@@ -386,7 +401,7 @@ def save_job_output(job_id: str, output: str):
job_output_dir = OUTPUT_DIR / job_id
job_output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
output_file = job_output_dir / f"{timestamp}.md"
with open(output_file, 'w', encoding='utf-8') as f:

View File

@@ -27,6 +27,8 @@ from datetime import datetime
from pathlib import Path
from typing import Optional
from hermes_time import now as _hermes_now
logger = logging.getLogger(__name__)
# Add parent directory to path for imports
@@ -174,6 +176,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
# Load config.yaml for model, reasoning, prefill, toolsets, provider routing
_cfg = {}
try:
import yaml
_cfg_path = str(_hermes_home / "config.yaml")
@@ -188,6 +192,41 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
except Exception:
pass
# Reasoning config from env or config.yaml
reasoning_config = None
effort = os.getenv("HERMES_REASONING_EFFORT", "")
if not effort:
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
if effort and effort.lower() != "none":
valid = ("xhigh", "high", "medium", "low", "minimal")
if effort.lower() in valid:
reasoning_config = {"enabled": True, "effort": effort.lower()}
elif effort.lower() == "none":
reasoning_config = {"enabled": False}
# Prefill messages from env or config.yaml
prefill_messages = None
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
if prefill_file:
import json as _json
pfpath = Path(prefill_file).expanduser()
if not pfpath.is_absolute():
pfpath = _hermes_home / pfpath
if pfpath.exists():
try:
with open(pfpath, "r", encoding="utf-8") as _pf:
prefill_messages = _json.load(_pf)
if not isinstance(prefill_messages, list):
prefill_messages = None
except Exception:
prefill_messages = None
# Max iterations
max_iterations = _cfg.get("agent", {}).get("max_turns") or _cfg.get("max_turns") or 90
# Provider routing
pr = _cfg.get("provider_routing", {})
from hermes_cli.runtime_provider import (
resolve_runtime_provider,
format_runtime_provider_error,
@@ -206,8 +245,15 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
base_url=runtime.get("base_url"),
provider=runtime.get("provider"),
api_mode=runtime.get("api_mode"),
max_iterations=max_iterations,
reasoning_config=reasoning_config,
prefill_messages=prefill_messages,
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
quiet_mode=True,
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
)
result = agent.run_conversation(prompt)
@@ -219,7 +265,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
output = f"""# Cron Job: {job_name}
**Job ID:** {job_id}
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
## Prompt
@@ -241,7 +287,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
output = f"""# Cron Job: {job_name} (FAILED)
**Job ID:** {job_id}
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
## Prompt
@@ -280,6 +326,7 @@ def tick(verbose: bool = True) -> int:
_LOCK_DIR.mkdir(parents=True, exist_ok=True)
# Cross-platform file locking: fcntl on Unix, msvcrt on Windows
lock_fd = None
try:
lock_fd = open(_LOCK_FILE, "w")
if fcntl:
@@ -288,17 +335,19 @@ def tick(verbose: bool = True) -> int:
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
except (OSError, IOError):
logger.debug("Tick skipped — another instance holds the lock")
if lock_fd is not None:
lock_fd.close()
return 0
try:
due_jobs = get_due_jobs()
if verbose and not due_jobs:
logger.info("%s - No jobs due", datetime.now().strftime('%H:%M:%S'))
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
return 0
if verbose:
logger.info("%s - %s job(s) due", datetime.now().strftime('%H:%M:%S'), len(due_jobs))
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
executed = 0
for job in due_jobs:

View File

@@ -0,0 +1,345 @@
# send_file Integration Map — Hermes Agent Codebase Deep Dive
## 1. environments/tool_context.py — Base64 File Transfer Implementation
### upload_file() (lines 153-205)
- Reads local file as raw bytes, base64-encodes to ASCII string
- Creates parent dirs in sandbox via `self.terminal(f"mkdir -p {parent}")`
- **Chunk size:** 60,000 chars (~60KB per shell command)
- **Small files (<=60KB b64):** Single `printf '%s' '{b64}' | base64 -d > {remote_path}`
- **Large files:** Writes chunks to `/tmp/_hermes_upload.b64` via `printf >> append`, then `base64 -d` to target
- **Error handling:** Checks local file exists; returns `{exit_code, output}`
- **Size limits:** No explicit limit, but shell arg limit ~2MB means chunking is necessary for files >~45KB raw
- **No theoretical max** — but very large files would be slow (many terminal round trips)
### download_file() (lines 234-278)
- Runs `base64 {remote_path}` inside sandbox, captures stdout
- Strips output, base64-decodes to raw bytes
- Writes to host filesystem with parent dir creation
- **Error handling:** Checks exit code, empty output, decode errors
- Returns `{success: bool, bytes: int}` or `{success: false, error: str}`
- **Size limit:** Bounded by terminal output buffer (practical limit ~few MB via base64 terminal output)
### Promotion potential:
- These methods work via `self.terminal()` — they're environment-agnostic
- Could be directly lifted into a new tool that operates on the agent's current sandbox
- For send_file, this `download_file()` pattern is the key: it extracts files from sandbox → host
## 2. tools/environments/base.py — BaseEnvironment Interface
### Current methods:
- `execute(command, cwd, timeout, stdin_data)``{output, returncode}`
- `cleanup()` — release resources
- `stop()` — alias for cleanup
- `_prepare_command()` — sudo transformation
- `_build_run_kwargs()` — subprocess kwargs
- `_timeout_result()` — standard timeout dict
### What would need to be added for file transfer:
- **Nothing required at this level.** File transfer can be implemented via `execute()` (base64 over terminal, like ToolContext does) or via environment-specific methods.
- Optional: `upload_file(local_path, remote_path)` and `download_file(remote_path, local_path)` methods could be added to BaseEnvironment for optimized per-backend transfers, but the base64-over-terminal approach already works universally.
## 3. tools/environments/docker.py — Docker Container Details
### Container ID tracking:
- `self._container_id` stored at init from `self._inner.container_id`
- Inner is `minisweagent.environments.docker.DockerEnvironment`
- Container ID is a standard Docker container hash
### docker cp feasibility:
- **YES**, `docker cp` could be used for optimized file transfer:
- `docker cp {container_id}:{remote_path} {local_path}` (download)
- `docker cp {local_path} {container_id}:{remote_path}` (upload)
- Much faster than base64-over-terminal for large files
- Container ID is directly accessible via `env._container_id` or `env._inner.container_id`
### Volumes mounted:
- **Persistent mode:** Bind mounts at `~/.hermes/sandboxes/docker/{task_id}/workspace``/workspace` and `.../home``/root`
- **Ephemeral mode:** tmpfs at `/workspace` (10GB), `/home` (1GB), `/root` (1GB)
- **User volumes:** From `config.yaml docker_volumes` (arbitrary `-v` mounts)
- **Security tmpfs:** `/tmp` (512MB), `/var/tmp` (256MB), `/run` (64MB)
### Direct host access for persistent mode:
- If persistent, files at `/workspace/foo.txt` are just `~/.hermes/sandboxes/docker/{task_id}/workspace/foo.txt` on host — no transfer needed!
## 4. tools/environments/ssh.py — SSH Connection Management
### Connection management:
- Uses SSH ControlMaster for persistent connection
- Control socket at `/tmp/hermes-ssh/{user}@{host}:{port}.sock`
- ControlPersist=300 (5 min keepalive)
- BatchMode=yes (non-interactive)
- Stores: `self.host`, `self.user`, `self.port`, `self.key_path`
### SCP/SFTP feasibility:
- **YES**, SCP can piggyback on the ControlMaster socket:
- `scp -o ControlPath={socket} {user}@{host}:{remote} {local}` (download)
- `scp -o ControlPath={socket} {local} {user}@{host}:{remote}` (upload)
- Same SSH key and connection reuse — zero additional auth
- Would be much faster than base64-over-terminal for large files
## 5. tools/environments/modal.py — Modal Sandbox Filesystem
### Filesystem API exposure:
- **Not directly.** The inner `SwerexModalEnvironment` wraps Modal's sandbox
- The sandbox object is accessible at: `env._inner.deployment._sandbox`
- Modal's Python SDK exposes `sandbox.open()` for file I/O — but only via async API
- Currently only used for `snapshot_filesystem()` during cleanup
- **Could use:** `sandbox.open(path, "rb")` to read files or `sandbox.open(path, "wb")` to write
- **Alternative:** Base64-over-terminal already works via `execute()` — simpler, no SDK dependency
## 6. gateway/platforms/base.py — MEDIA: Tag Flow (Complete)
### extract_media() (lines 587-620):
- **Pattern:** `MEDIA:\S+` — extracts file paths after MEDIA: prefix
- **Voice flag:** `[[audio_as_voice]]` global directive sets `is_voice=True` for all media in message
- Returns `List[Tuple[str, bool]]` (path, is_voice) and cleaned content
### _process_message_background() media routing (lines 752-786):
- After extracting MEDIA tags, routes by file extension:
- `.ogg .opus .mp3 .wav .m4a``send_voice()`
- `.mp4 .mov .avi .mkv .3gp``send_video()`
- `.jpg .jpeg .png .webp .gif``send_image_file()`
- **Everything else** → `send_document()`
- This routing already supports arbitrary files!
### send_* method inventory (base class):
- `send(chat_id, content, reply_to, metadata)` — ABSTRACT, text
- `send_image(chat_id, image_url, caption, reply_to)` — URL-based images
- `send_animation(chat_id, animation_url, caption, reply_to)` — GIF animations
- `send_voice(chat_id, audio_path, caption, reply_to)` — voice messages
- `send_video(chat_id, video_path, caption, reply_to)` — video files
- `send_document(chat_id, file_path, caption, file_name, reply_to)` — generic files
- `send_image_file(chat_id, image_path, caption, reply_to)` — local image files
- `send_typing(chat_id)` — typing indicator
- `edit_message(chat_id, message_id, content)` — edit sent messages
### What's missing:
- **Telegram:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **Discord:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **Slack:** No override for `send_document` — falls back to text! (`send_image_file` ✅ added)
- **WhatsApp:** Has `send_document` and `send_image_file` via bridge — COMPLETE.
- The base class defaults just send "📎 File: /path" as text — useless for actual file delivery.
## 7. gateway/platforms/telegram.py — Send Method Analysis
### Implemented send methods:
- `send()` — MarkdownV2 text with fallback to plain
- `send_voice()``.ogg`/`.opus` as `send_voice()`, others as `send_audio()`
- `send_image()` — URL-based via `send_photo()`
- `send_image_file()` — local file via `send_photo(photo=open(path, 'rb'))`
- `send_animation()` — GIF via `send_animation()`
- `send_typing()` — "typing" chat action
- `edit_message()` — edit text messages
### MISSING:
- **`send_document()` NOT overridden** — Need to add `self._bot.send_document(chat_id, document=open(file_path, 'rb'), ...)`
- **`send_video()` NOT overridden** — Need to add `self._bot.send_video(...)`
## 8. gateway/platforms/discord.py — Send Method Analysis
### Implemented send methods:
- `send()` — text messages with chunking
- `send_voice()` — discord.File attachment
- `send_image()` — downloads URL, creates discord.File attachment
- `send_image_file()` — local file via discord.File attachment ✅
- `send_typing()` — channel.typing()
- `edit_message()` — edit text messages
### MISSING:
- **`send_document()` NOT overridden** — Need to add discord.File attachment
- **`send_video()` NOT overridden** — Need to add discord.File attachment
## 9. gateway/run.py — User File Attachment Handling
### Current attachment flow:
1. **Telegram photos** (line 509-529): Download via `photo.get_file()``cache_image_from_bytes()` → vision auto-analysis
2. **Telegram voice** (line 532-541): Download → `cache_audio_from_bytes()` → STT transcription
3. **Telegram audio** (line 542-551): Same pattern
4. **Telegram documents** (line 553-617): Extension validation against `SUPPORTED_DOCUMENT_TYPES`, 20MB limit, content injection for text files
5. **Discord attachments** (line 717-751): Content-type detection, image/audio caching, URL fallback for other types
6. **Gateway run.py** (lines 818-883): Auto-analyzes images with vision, transcribes audio, enriches document messages with context notes
### Key insight: Files are always cached to host filesystem first, then processed. The agent sees local file paths.
## 10. tools/terminal_tool.py — Terminal Tool & Environment Interaction
### How it manages environments:
- Global dict `_active_environments: Dict[str, Any]` keyed by task_id
- Per-task creation locks prevent duplicate sandbox creation
- Auto-cleanup thread kills idle environments after `TERMINAL_LIFETIME_SECONDS`
- `_get_env_config()` reads all TERMINAL_* env vars for backend selection
- `_create_environment()` factory creates the right backend type
### Could send_file piggyback?
- **YES.** send_file needs access to the same environment to extract files from sandboxes.
- It can reuse `_active_environments[task_id]` to get the environment, then:
- Docker: Use `docker cp` via `env._container_id`
- SSH: Use `scp` via `env.control_socket`
- Local: Just read the file directly
- Modal: Use base64-over-terminal via `env.execute()`
- The file_tools.py module already does this with `ShellFileOperations` — read_file/write_file/search/patch all share the same env instance.
## 11. tools/tts_tool.py — Working Example of File Delivery
### Flow:
1. Generate audio file to `~/.hermes/audio_cache/tts_TIMESTAMP.{ogg,mp3}`
2. Return JSON with `media_tag: "MEDIA:/path/to/file"`
3. For Telegram voice: prepend `[[audio_as_voice]]` directive
4. The LLM includes the MEDIA tag in its response text
5. `BasePlatformAdapter._process_message_background()` calls `extract_media()` to find the tag
6. Routes by extension → `send_voice()` for audio files
7. Platform adapter sends the file natively
### Key pattern: Tool saves file to host → returns MEDIA: path → LLM echoes it → gateway extracts → platform delivers
## 12. tools/image_generation_tool.py — Working Example of Image Delivery
### Flow:
1. Call FAL.ai API → get image URL
2. Return JSON with `image: "https://fal.media/..."` URL
3. The LLM includes the URL in markdown: `![description](URL)`
4. `BasePlatformAdapter.extract_images()` finds `![alt](url)` patterns
5. Routes through `send_image()` (URL) or `send_animation()` (GIF)
6. Platform downloads and sends natively
### Key difference from TTS: Images are URL-based, not local files. The gateway downloads at send time.
---
# INTEGRATION MAP: Where send_file Hooks In
## Architecture Decision: MEDIA: Tag Protocol vs. New Tool
The MEDIA: tag protocol is already the established pattern for file delivery. Two options:
### Option A: Pure MEDIA: Tag (Minimal Change)
- No new tool needed
- Agent downloads file from sandbox to host using terminal (base64)
- Saves to known location (e.g., `~/.hermes/file_cache/`)
- Includes `MEDIA:/path` in response text
- Existing routing in `_process_message_background()` handles delivery
- **Problem:** Agent has to manually do base64 dance + know about MEDIA: convention
### Option B: Dedicated send_file Tool (Recommended)
- New tool that the agent calls with `(file_path, caption?)`
- Tool handles the sandbox → host extraction automatically
- Returns MEDIA: tag that gets routed through existing pipeline
- Much cleaner agent experience
## Implementation Plan for Option B
### Files to CREATE:
1. **`tools/send_file_tool.py`** — The new tool
- Accepts: `file_path` (path in sandbox), `caption` (optional)
- Detects environment backend from `_active_environments`
- Extracts file from sandbox:
- **local:** `shutil.copy()` or direct path
- **docker:** `docker cp {container_id}:{path} {local_cache}/`
- **ssh:** `scp -o ControlPath=... {user}@{host}:{path} {local_cache}/`
- **modal:** base64-over-terminal via `env.execute("base64 {path}")`
- Saves to `~/.hermes/file_cache/{uuid}_{filename}`
- Returns: `MEDIA:/cached/path` in response for gateway to pick up
- Register with `registry.register(name="send_file", toolset="file", ...)`
### Files to MODIFY:
2. **`gateway/platforms/telegram.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
with open(file_path, "rb") as f:
msg = await self._bot.send_document(
chat_id=int(chat_id), document=f,
caption=caption, filename=file_name or os.path.basename(file_path))
return SendResult(success=True, message_id=str(msg.message_id))
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
with open(image_path, "rb") as f:
msg = await self._bot.send_photo(chat_id=int(chat_id), photo=f, caption=caption)
return SendResult(success=True, message_id=str(msg.message_id))
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
with open(video_path, "rb") as f:
msg = await self._bot.send_video(chat_id=int(chat_id), video=f, caption=caption)
return SendResult(success=True, message_id=str(msg.message_id))
```
3. **`gateway/platforms/discord.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
with open(file_path, "rb") as f:
file = discord.File(io.BytesIO(f.read()), filename=file_name or os.path.basename(file_path))
msg = await channel.send(content=caption, file=file)
return SendResult(success=True, message_id=str(msg.id))
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
# Same pattern as send_document with image filename
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
# Same pattern, discord renders video attachments inline
```
4. **`toolsets.py`** — Add `"send_file"` to `_HERMES_CORE_TOOLS` list
5. **`agent/prompt_builder.py`** — Update platform hints to mention send_file tool
### Code that can be REUSED (zero rewrite):
- `BasePlatformAdapter.extract_media()` — Already extracts MEDIA: tags
- `BasePlatformAdapter._process_message_background()` — Already routes by extension
- `ToolContext.download_file()` — Base64-over-terminal extraction pattern
- `tools/terminal_tool.py` _active_environments dict — Environment access
- `tools/registry.py` — Tool registration infrastructure
- `gateway/platforms/base.py` send_document/send_image_file/send_video signatures — Already defined
### Code that needs to be WRITTEN from scratch:
1. `tools/send_file_tool.py` (~150 lines):
- File extraction from each environment backend type
- Local file cache management
- Registry registration
2. Telegram `send_document` + `send_image_file` + `send_video` overrides (~40 lines)
3. Discord `send_document` + `send_image_file` + `send_video` overrides (~50 lines)
### Total effort: ~240 lines of new code, ~5 lines of config changes
## Key Environment-Specific Extract Strategies
| Backend | Extract Method | Speed | Complexity |
|------------|-------------------------------|----------|------------|
| local | shutil.copy / direct path | Instant | None |
| docker | `docker cp container:path .` | Fast | Low |
| docker+vol | Direct host path access | Instant | None |
| ssh | `scp -o ControlPath=...` | Fast | Low |
| modal | base64-over-terminal | Moderate | Medium |
| singularity| Direct path (overlay mount) | Fast | Low |
## Data Flow Summary
```
Agent calls send_file(file_path="/workspace/output.pdf", caption="Here's the report")
send_file_tool.py:
1. Get environment from _active_environments[task_id]
2. Detect backend type (docker/ssh/modal/local)
3. Extract file to ~/.hermes/file_cache/{uuid}_{filename}
4. Return: '{"success": true, "media_tag": "MEDIA:/home/user/.hermes/file_cache/abc123_output.pdf"}'
LLM includes MEDIA: tag in its response text
BasePlatformAdapter._process_message_background():
1. extract_media(response) → finds MEDIA:/path
2. Checks extension: .pdf → send_document()
3. Calls platform-specific send_document(chat_id, file_path, caption)
TelegramAdapter.send_document() / DiscordAdapter.send_document():
Opens file, sends via platform API as native document attachment
User receives downloadable file in chat
```

View File

@@ -40,7 +40,7 @@ This directory contains the integration layer between **hermes-agent's** tool-ca
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, ssh, singularity)
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, daytona, ssh, singularity)
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
@@ -195,8 +195,12 @@ environments/
│ └── hermes_swe_env.py
└── benchmarks/ # Evaluation benchmarks
── terminalbench_2/
└── terminalbench2_env.py
── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
└── terminalbench2_env.py
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
│ └── tblite_env.py
└── yc_bench/ # Long-horizon strategic benchmark
└── yc_bench_env.py
```
## Concrete Environments
@@ -324,7 +328,7 @@ For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
| `distribution` | Probabilistic toolset distribution name | `None` |
| `max_agent_turns` | Max LLM calls per rollout | `30` |
| `agent_temperature` | Sampling temperature | `1.0` |
| `terminal_backend` | `local`, `docker`, `modal`, `ssh`, `singularity` | `local` |
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
| `system_prompt` | System message for the agent | `None` |
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |

View File

@@ -18,9 +18,14 @@ Benchmarks (eval-only):
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
"""
from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
try:
from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
except ImportError:
# atroposlib not installed — environments are unavailable but
# submodules like tool_call_parsers can still be imported directly.
pass
__all__ = [
"AgentResult",

View File

@@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Set
from model_tools import handle_function_call
# Thread pool for running sync tool calls that internally use asyncio.run()
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
# (e.g., mini-swe-agent's modal/docker/daytona backends). Running them in a separate
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
@@ -249,23 +249,62 @@ class HermesAgentLoop:
reasoning = _extract_reasoning_from_message(assistant_msg)
reasoning_per_turn.append(reasoning)
# Check for tool calls -- standard OpenAI spec
# Check for tool calls -- standard OpenAI spec.
# Fallback: if response has no structured tool_calls but content
# contains raw tool call tags (e.g. <tool_call>), parse them using
# hermes-agent's standalone parsers. This handles the case where
# ManagedServer's ToolCallTranslator couldn't parse because vLLM
# isn't installed.
if (
not assistant_msg.tool_calls
and assistant_msg.content
and self.tool_schemas
and "<tool_call>" in (assistant_msg.content or "")
):
try:
from environments.tool_call_parsers import get_parser
fallback_parser = get_parser("hermes")
parsed_content, parsed_calls = fallback_parser.parse(
assistant_msg.content
)
if parsed_calls:
assistant_msg.tool_calls = parsed_calls
if parsed_content is not None:
assistant_msg.content = parsed_content
logger.debug(
"Fallback parser extracted %d tool calls from raw content",
len(parsed_calls),
)
except Exception:
pass # Fall through to no tool calls
if assistant_msg.tool_calls:
# Normalize tool calls to dicts — they may come as objects
# (OpenAI API) or dicts (vLLM ToolCallTranslator).
def _tc_to_dict(tc):
if isinstance(tc, dict):
return {
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
"type": "function",
"function": {
"name": tc.get("function", {}).get("name", tc.get("name", "")),
"arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
},
}
return {
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
# Build the assistant message dict for conversation history
msg_dict: Dict[str, Any] = {
"role": "assistant",
"content": assistant_msg.content or "",
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in assistant_msg.tool_calls
],
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
}
# Preserve reasoning_content for multi-turn chat template handling
@@ -278,8 +317,13 @@ class HermesAgentLoop:
# Execute each tool call via hermes-agent's dispatch
for tc in assistant_msg.tool_calls:
tool_name = tc.function.name
tool_args_raw = tc.function.arguments
# Handle both object (OpenAI) and dict (vLLM) formats
if isinstance(tc, dict):
tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
else:
tool_name = tc.function.name
tool_args_raw = tc.function.arguments
# Validate tool name
if tool_name not in self.valid_tool_names:
@@ -336,7 +380,7 @@ class HermesAgentLoop:
tool_elapsed = _time.monotonic() - tool_submit_time
else:
# Run tool calls in a thread pool so backends that
# use asyncio.run() internally (modal, docker) get
# use asyncio.run() internally (modal, docker, daytona) get
# a clean event loop instead of deadlocking.
loop = asyncio.get_event_loop()
# Capture current tool_name/args for the lambda
@@ -390,10 +434,11 @@ class HermesAgentLoop:
pass
# Add tool response to conversation
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
messages.append(
{
"role": "tool",
"tool_call_id": tc.id,
"tool_call_id": tc_id,
"content": tool_result,
}
)

View File

@@ -0,0 +1,38 @@
# OpenThoughts-TBLite Evaluation -- Docker Backend (Local Compute)
#
# Runs tasks in Docker containers on the local machine.
# Sandboxed like Modal but no cloud costs. Good for dev/testing.
#
# Usage:
# python environments/benchmarks/tblite/tblite_env.py evaluate \
# --config environments/benchmarks/tblite/local.yaml
#
# # Override concurrency:
# python environments/benchmarks/tblite/tblite_env.py evaluate \
# --config environments/benchmarks/tblite/local.yaml \
# --env.eval_concurrency 4
env:
enabled_toolsets: ["terminal", "file"]
max_agent_turns: 60
max_token_length: 32000
agent_temperature: 0.8
terminal_backend: "docker"
terminal_timeout: 300
tool_pool_size: 16
dataset_name: "NousResearch/openthoughts-tblite"
test_timeout: 600
task_timeout: 1200
eval_concurrency: 8 # max 8 tasks at once
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
use_wandb: false
wandb_name: "openthoughts-tblite-local"
ensure_scores_are_not_same: false
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite-local"
openai:
base_url: "https://openrouter.ai/api/v1"
model_name: "anthropic/claude-sonnet-4"
server_type: "openai"
health_check: false
# api_key loaded from OPENROUTER_API_KEY in .env

View File

@@ -0,0 +1,40 @@
# OpenThoughts-TBLite Evaluation -- Local vLLM Backend
#
# Runs against a local vLLM server with Docker sandboxes.
#
# Start the vLLM server from the atropos directory:
# python -m example_trainer.vllm_api_server \
# --model Qwen/Qwen3-4B-Instruct-2507 \
# --port 9001 \
# --gpu-memory-utilization 0.8 \
# --max-model-len=32000
#
# Then run:
# python environments/benchmarks/tblite/tblite_env.py evaluate \
# --config environments/benchmarks/tblite/local_vllm.yaml
env:
enabled_toolsets: ["terminal", "file"]
max_agent_turns: 60
max_token_length: 16000
agent_temperature: 0.6
terminal_backend: "docker"
terminal_timeout: 300
tool_pool_size: 16
dataset_name: "NousResearch/openthoughts-tblite"
test_timeout: 600
task_timeout: 1200
eval_concurrency: 8
tool_call_parser: "hermes"
system_prompt: "You are an expert terminal agent. You MUST use the provided tools to complete tasks. Use the terminal tool to run shell commands, read_file to read files, write_file to write files, search_files to search, and patch to edit files. Do NOT write out solutions as text - execute them using the tools. Always start by exploring the environment with terminal commands."
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
use_wandb: false
wandb_name: "tblite-qwen3-4b-instruct"
ensure_scores_are_not_same: false
data_dir_to_save_evals: "environments/benchmarks/evals/tblite-qwen3-4b-local"
openai:
base_url: "http://localhost:9001"
model_name: "Qwen/Qwen3-4B-Instruct-2507"
server_type: "vllm"
health_check: false

View File

@@ -118,6 +118,14 @@ class TerminalBench2EvalConfig(HermesAgentEnvConfig):
"Tasks exceeding this are scored as FAIL. Default 30 minutes.",
)
# --- Eval concurrency ---
eval_concurrency: int = Field(
default=0,
description="Maximum number of tasks to evaluate in parallel. "
"0 means unlimited (all tasks run concurrently). "
"Set to 8 for local backends to avoid overwhelming the machine.",
)
# Tasks that cannot run properly on Modal and are excluded from scoring.
MODAL_INCOMPATIBLE_TASKS = {
@@ -429,8 +437,13 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
"error": "no_image",
}
# --- 2. Register per-task Modal image override ---
register_task_env_overrides(task_id, {"modal_image": modal_image})
# --- 2. Register per-task image override ---
# Set both modal_image and docker_image so the task image is used
# regardless of which backend is configured.
register_task_env_overrides(task_id, {
"modal_image": modal_image,
"docker_image": modal_image,
})
logger.info(
"Task %s: registered image override for task_id %s",
task_name, task_id[:8],
@@ -445,17 +458,37 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
# --- 4. Run agent loop ---
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# Use ManagedServer (Phase 2) for vLLM/SGLang backends to get
# token-level tracking via /generate. Falls back to direct
# ServerManager (Phase 1) for OpenAI endpoints.
if self._use_managed_server():
async with self.server.managed_server(
tokenizer=self.tokenizer,
preserve_think_blocks=bool(self.config.thinking_mode),
) as managed:
agent = HermesAgentLoop(
server=managed,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
else:
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# --- 5. Verify -- run test suite in the agent's sandbox ---
# Skip verification if the agent produced no meaningful output
@@ -655,13 +688,19 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
"""
Wrap rollout_and_score_eval with a per-task wall-clock timeout.
Wrap rollout_and_score_eval with a per-task wall-clock timeout
and optional concurrency limit via semaphore.
If the task exceeds task_timeout seconds, it's automatically scored
as FAIL. This prevents any single task from hanging indefinitely.
"""
task_name = item.get("task_name", "unknown")
category = item.get("category", "unknown")
# Acquire concurrency semaphore if configured
if self._eval_semaphore:
await self._eval_semaphore.acquire()
try:
return await asyncio.wait_for(
self.rollout_and_score_eval(item),
@@ -679,6 +718,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
}
self._save_result(out)
return out
finally:
if self._eval_semaphore:
self._eval_semaphore.release()
async def evaluate(self, *args, **kwargs) -> None:
"""
@@ -696,6 +738,13 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
"""
start_time = time.time()
# Set up concurrency limit if configured
if self.config.eval_concurrency > 0:
self._eval_semaphore = asyncio.Semaphore(self.config.eval_concurrency)
print(f" Eval concurrency: {self.config.eval_concurrency} tasks at a time")
else:
self._eval_semaphore = None
# Route all logging through tqdm.write() so the progress bar stays
# pinned at the bottom while log lines scroll above it.
from tqdm import tqdm

View File

@@ -0,0 +1,115 @@
# YC-Bench: Long-Horizon Agent Benchmark
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
## Setup
```bash
# Install yc-bench (optional dependency)
pip install "hermes-agent[yc-bench]"
# Or install from source
git clone https://github.com/collinear-ai/yc-bench
cd yc-bench && pip install -e .
# Verify
yc-bench --help
```
## Running
```bash
# From the repo root:
bash environments/benchmarks/yc_bench/run_eval.sh
# Or directly:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml
# Override model:
bash environments/benchmarks/yc_bench/run_eval.sh \
--openai.model_name anthropic/claude-opus-4-20250514
# Quick single-preset test:
bash environments/benchmarks/yc_bench/run_eval.sh \
--env.presets '["fast_test"]' --env.seeds '[1]'
```
## How It Works
### Architecture
```
HermesAgentLoop (our agent)
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
-> ... (100-500 turns per run)
```
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
### Simulation Mechanics
- **4 skill domains**: research, inference, data_environment, training
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
- **Financial pressure**: Monthly payroll, bankruptcy = game over
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
### Difficulty Presets
| Preset | Employees | Tasks | Focus |
|-----------|-----------|-------|-------|
| tutorial | 3 | 50 | Basic loop mechanics |
| easy | 5 | 100 | Throughput awareness |
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
| **hard** | 7 | 200 | Precise ETA reasoning |
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
### Scoring
```
composite = 0.5 × survival + 0.5 × normalised_funds
```
- **Survival** (binary): Did the company avoid bankruptcy?
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
## Configuration
Key fields in `default.yaml`:
| Field | Default | Description |
|-------|---------|-------------|
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
| `max_agent_turns` | 200 | Max LLM calls per run |
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
| `survival_weight` | 0.5 | Weight of survival in composite score |
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
| `horizon_years` | null | Override horizon (null = auto from preset) |
## Cost & Time Estimates
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
| Preset | Turns | Time | Est. Cost |
|--------|-------|------|-----------|
| fast_test | ~50 | 5-10 min | $1-5 |
| medium | ~200 | 20-40 min | $5-15 |
| hard | ~300 | 30-60 min | $10-25 |
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
## References
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)

View File

@@ -0,0 +1,43 @@
# YC-Bench Evaluation -- Default Configuration
#
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Usage:
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
# --config environments/benchmarks/yc_bench/default.yaml
#
# # Override model:
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
# --config environments/benchmarks/yc_bench/default.yaml \
# --openai.model_name anthropic/claude-opus-4-20250514
env:
enabled_toolsets: ["terminal"]
max_agent_turns: 200
max_token_length: 32000
agent_temperature: 0.0
terminal_backend: "local"
terminal_timeout: 60
presets: ["fast_test", "medium", "hard"]
seeds: [1, 2, 3]
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
survival_weight: 0.5 # weight of binary survival in composite score
funds_weight: 0.5 # weight of normalised final funds in composite score
db_dir: "/tmp/yc_bench_dbs"
company_name: "BenchCo"
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
use_wandb: true
wandb_name: "yc-bench"
ensure_scores_are_not_same: false
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
openai:
base_url: "https://openrouter.ai/api/v1"
model_name: "anthropic/claude-sonnet-4.6"
server_type: "openai"
health_check: false
# api_key loaded from OPENROUTER_API_KEY in .env

View File

@@ -0,0 +1,34 @@
#!/bin/bash
# YC-Bench Evaluation
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Run from repo root:
# bash environments/benchmarks/yc_bench/run_eval.sh
#
# Override model:
# bash environments/benchmarks/yc_bench/run_eval.sh \
# --openai.model_name anthropic/claude-opus-4-20250514
#
# Run a single preset:
# bash environments/benchmarks/yc_bench/run_eval.sh \
# --env.presets '["fast_test"]' --env.seeds '[1]'
set -euo pipefail
mkdir -p logs evals/yc-bench
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
echo "YC-Bench Evaluation"
echo "Log: $LOG_FILE"
echo ""
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml \
"$@" \
2>&1 | tee "$LOG_FILE"
echo ""
echo "Log saved to: $LOG_FILE"

View File

@@ -0,0 +1,847 @@
"""
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
interacting exclusively via CLI subprocess calls against a SQLite-backed
discrete-event simulation.
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
multi-turn strategic coherence -- whether an agent can manage compounding
decisions over hundreds of turns without going bankrupt.
This is an eval-only environment. Run via:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml
The evaluate flow:
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
2. evaluate() -- Iterates over all runs sequentially through:
a. rollout_and_score_eval() -- Per-run agent loop
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
- Runs HermesAgentLoop with terminal tool only
- Reads final SQLite DB to extract score
- Returns survival (0/1) + normalised funds score
b. Aggregates per-preset and overall metrics
c. Logs results via evaluate_log() and wandb
Key features:
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
- Deterministic: same seed + preset = same world (SHA256-based RNG)
- Multi-dimensional scoring: survival + normalised final funds
- Per-preset difficulty breakdown in results
- Isolated SQLite DB per run (no cross-run state leakage)
Requires: pip install hermes-agent[yc-bench]
"""
import asyncio
import datetime
import json
import logging
import math
import os
import sqlite3
import subprocess
import sys
import threading
import time
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
from pydantic import Field
from atroposlib.envs.base import EvalHandlingEnum
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from environments.agent_loop import HermesAgentLoop
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
logger = logging.getLogger(__name__)
# =============================================================================
# System prompt
# =============================================================================
YC_BENCH_SYSTEM_PROMPT = """\
You are the autonomous CEO of an early-stage AI startup in a deterministic
business simulation. You manage the company exclusively through the `yc-bench`
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
without going bankrupt, while **maximising final funds**.
## Simulation Mechanics
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige 1))`.
- **Domains**: There are 4 skill domains: **research**, **inference**,
**data_environment**, and **training**. Each has its own prestige level
(1.0-10.0). Higher prestige unlocks better-paying tasks.
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
is the number of active tasks assigned to that employee. Focus beats breadth.
- **Payroll**: Deducted automatically on the first business day of each month.
Running out of funds = bankruptcy = game over.
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
Time only advances when you call `yc-bench sim resume`.
## Task Lifecycle
1. Browse market tasks with `market browse`
2. Accept a task with `task accept` (this sets its deadline)
3. Assign employees with `task assign`
4. Dispatch with `task dispatch` to start work
5. Call `sim resume` to advance time and let employees make progress
6. Tasks complete when all domain requirements are fulfilled
**Penalties for failure vary by difficulty preset.** Completing a task on time
earns full reward + prestige gain. Missing a deadline or cancelling a task
incurs prestige penalties -- cancelling is always more costly than letting a
task fail, so cancel only as a last resort.
## CLI Commands
### Observe
- `yc-bench company status` -- funds, prestige, runway
- `yc-bench employee list` -- skills, salary, active tasks
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
- `yc-bench task list [--status active|planned]` -- your tasks
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
- `yc-bench report monthly` -- monthly P&L
### Act
- `yc-bench task accept --task-id UUID` -- accept from market
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
- `yc-bench sim resume` -- advance simulation clock
### Memory (persists across context truncation)
- `yc-bench scratchpad read` -- read your persistent notes
- `yc-bench scratchpad write --content "text"` -- overwrite notes
- `yc-bench scratchpad append --content "text"` -- append to notes
- `yc-bench scratchpad clear` -- clear notes
## Strategy Guidelines
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
high-reward tasks. Don't spread thin across all 4 domains early on.
2. **Focus employees** -- assigning one employee to many tasks halves their
throughput per additional task. Keep assignments concentrated.
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
employee assignments. This persists even if conversation context is truncated.
4. **Monitor runway** -- always know how many months of payroll you can cover.
Accept high-reward tasks before payroll dates.
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
into prestige loss, locking you out of profitable contracts.
6. Use `finance ledger` and `report monthly` to track revenue trends.
## Your Turn
Each turn:
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
2. Check for completed tasks and pending deadlines.
3. Browse market for profitable tasks within your prestige level.
4. Accept, assign, and dispatch tasks strategically.
5. Call `yc-bench sim resume` to advance time.
6. Repeat until the simulation ends.
Think step by step before acting."""
# Starting funds in cents ($250,000)
INITIAL_FUNDS_CENTS = 25_000_000
# Default horizon per preset (years)
_PRESET_HORIZONS = {
"tutorial": 1,
"easy": 1,
"medium": 1,
"hard": 1,
"nightmare": 1,
"fast_test": 1,
"default": 3,
"high_reward": 1,
}
# =============================================================================
# Configuration
# =============================================================================
class YCBenchEvalConfig(HermesAgentEnvConfig):
"""
Configuration for the YC-Bench evaluation environment.
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
preset selection, seed control, scoring, and simulation parameters.
"""
presets: List[str] = Field(
default=["fast_test", "medium", "hard"],
description="YC-Bench preset names to evaluate.",
)
seeds: List[int] = Field(
default=[1, 2, 3],
description="Random seeds -- each preset x seed = one run.",
)
run_timeout: int = Field(
default=3600,
description="Maximum wall-clock seconds per run. Default 60 minutes.",
)
survival_weight: float = Field(
default=0.5,
description="Weight of survival (0/1) in composite score.",
)
funds_weight: float = Field(
default=0.5,
description="Weight of normalised final funds in composite score.",
)
db_dir: str = Field(
default="/tmp/yc_bench_dbs",
description="Directory for per-run SQLite databases.",
)
horizon_years: Optional[int] = Field(
default=None,
description=(
"Simulation horizon in years. If None (default), inferred from "
"preset name (1 year for most, 3 for 'default')."
),
)
company_name: str = Field(
default="BenchCo",
description="Name of the simulated company.",
)
start_date: str = Field(
default="01/01/2025",
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
)
# =============================================================================
# Scoring helpers
# =============================================================================
def _read_final_score(db_path: str) -> Dict[str, Any]:
"""
Read final game state from a YC-Bench SQLite database.
Returns dict with final_funds_cents (int), survived (bool),
terminal_reason (str).
Note: yc-bench table names are plural -- 'companies' not 'company',
'sim_events' not 'simulation_log'.
"""
if not os.path.exists(db_path):
logger.warning("DB not found at %s", db_path)
return {
"final_funds_cents": 0,
"survived": False,
"terminal_reason": "db_missing",
}
conn = None
try:
conn = sqlite3.connect(db_path)
cur = conn.cursor()
# Read final funds from the 'companies' table
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
row = cur.fetchone()
funds = row[0] if row else 0
# Determine terminal reason from 'sim_events' table
terminal_reason = "unknown"
try:
cur.execute(
"SELECT event_type FROM sim_events "
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
"ORDER BY scheduled_at DESC LIMIT 1"
)
event_row = cur.fetchone()
if event_row:
terminal_reason = event_row[0]
except sqlite3.OperationalError:
# Table may not exist if simulation didn't progress
pass
survived = funds >= 0 and terminal_reason != "bankruptcy"
return {
"final_funds_cents": funds,
"survived": survived,
"terminal_reason": terminal_reason,
}
except Exception as e:
logger.error("Failed to read DB %s: %s", db_path, e)
return {
"final_funds_cents": 0,
"survived": False,
"terminal_reason": f"db_error: {e}",
}
finally:
if conn:
conn.close()
def _compute_composite_score(
final_funds_cents: int,
survived: bool,
survival_weight: float = 0.5,
funds_weight: float = 0.5,
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
) -> float:
"""
Compute composite score from survival and final funds.
Score = survival_weight * survival_score
+ funds_weight * normalised_funds_score
Normalised funds uses log-scale relative to initial capital:
- funds <= 0: 0.0
- funds == initial: ~0.15
- funds == 10x: ~0.52
- funds == 100x: 1.0
"""
survival_score = 1.0 if survived else 0.0
if final_funds_cents <= 0:
funds_score = 0.0
else:
max_ratio = 100.0
ratio = final_funds_cents / max(initial_funds_cents, 1)
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
return survival_weight * survival_score + funds_weight * funds_score
# =============================================================================
# Main Environment
# =============================================================================
class YCBenchEvalEnv(HermesAgentBaseEnv):
"""
YC-Bench long-horizon agent benchmark environment (eval-only).
Each eval item is a (preset, seed) pair. The environment initialises the
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
a competing built-in agent loop). The HermesAgentLoop then drives the
interaction by calling individual yc-bench CLI commands via the terminal tool.
After the agent loop ends, the SQLite DB is read to extract the final score.
Scoring:
composite = 0.5 * survival + 0.5 * normalised_funds
"""
name = "yc-bench"
env_config_cls = YCBenchEvalConfig
@classmethod
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
env_config = YCBenchEvalConfig(
enabled_toolsets=["terminal"],
disabled_toolsets=None,
distribution=None,
max_agent_turns=200,
max_token_length=32000,
agent_temperature=0.0,
system_prompt=YC_BENCH_SYSTEM_PROMPT,
terminal_backend="local",
terminal_timeout=60,
presets=["fast_test", "medium", "hard"],
seeds=[1, 2, 3],
run_timeout=3600,
survival_weight=0.5,
funds_weight=0.5,
db_dir="/tmp/yc_bench_dbs",
eval_handling=EvalHandlingEnum.STOP_TRAIN,
group_size=1,
steps_per_eval=1,
total_steps=1,
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
use_wandb=True,
wandb_name="yc-bench",
ensure_scores_are_not_same=False,
)
server_configs = [
APIServerConfig(
base_url="https://openrouter.ai/api/v1",
model_name="anthropic/claude-sonnet-4.6",
server_type="openai",
api_key=os.getenv("OPENROUTER_API_KEY", ""),
health_check=False,
)
]
return env_config, server_configs
# =========================================================================
# Setup
# =========================================================================
async def setup(self):
"""Verify yc-bench is installed and build the eval matrix."""
# Verify yc-bench CLI is available
try:
result = subprocess.run(
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
)
if result.returncode != 0:
raise FileNotFoundError
except (FileNotFoundError, subprocess.TimeoutExpired):
raise RuntimeError(
"yc-bench CLI not found. Install with:\n"
' pip install "hermes-agent[yc-bench]"\n'
"Or: git clone https://github.com/collinear-ai/yc-bench "
"&& cd yc-bench && pip install -e ."
)
print("yc-bench CLI verified.")
# Build eval matrix: preset x seed
self.all_eval_items = [
{"preset": preset, "seed": seed}
for preset in self.config.presets
for seed in self.config.seeds
]
self.iter = 0
os.makedirs(self.config.db_dir, exist_ok=True)
self.eval_metrics: List[Tuple[str, float]] = []
# Streaming JSONL log for crash-safe result persistence
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
self._streaming_file = open(self._streaming_path, "w")
self._streaming_lock = threading.Lock()
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
for item in self.all_eval_items:
print(f" preset={item['preset']!r} seed={item['seed']}")
print(f"Streaming results to: {self._streaming_path}\n")
def _save_result(self, result: Dict[str, Any]):
"""Write a single run result to the streaming JSONL file immediately."""
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
return
with self._streaming_lock:
self._streaming_file.write(
json.dumps(result, ensure_ascii=False, default=str) + "\n"
)
self._streaming_file.flush()
# =========================================================================
# Training pipeline stubs (eval-only -- not used)
# =========================================================================
async def get_next_item(self):
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
self.iter += 1
return item
def format_prompt(self, item: Dict[str, Any]) -> str:
preset = item["preset"]
seed = item["seed"]
return (
f"A new YC-Bench simulation has been initialized "
f"(preset='{preset}', seed={seed}).\n"
f"Your company '{self.config.company_name}' is ready.\n\n"
"Begin by calling:\n"
"1. `yc-bench company status` -- see your starting funds and prestige\n"
"2. `yc-bench employee list` -- see your team and their skills\n"
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
"you can take\n\n"
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
"`yc-bench sim resume` to advance time. Repeat this loop until the "
"simulation ends (horizon reached or bankruptcy)."
)
async def compute_reward(self, item, result, ctx) -> float:
return 0.0
async def collect_trajectories(self, item):
return None, []
async def score(self, rollout_group_data):
return None
# =========================================================================
# Per-run evaluation
# =========================================================================
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
"""
Evaluate a single (preset, seed) run.
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
3. Runs HermesAgentLoop with terminal tool
4. Reads SQLite DB to compute final score
5. Returns result dict with survival, funds, and composite score
"""
preset = eval_item["preset"]
seed = eval_item["seed"]
run_id = str(uuid.uuid4())[:8]
run_key = f"{preset}_seed{seed}_{run_id}"
from tqdm import tqdm
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
run_start = time.time()
# Isolated DB per run -- prevents cross-run state leakage
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
os.environ["YC_BENCH_EXPERIMENT"] = preset
# Determine horizon: explicit config override > preset lookup > default 1
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
try:
# ----------------------------------------------------------
# Step 1: Initialise the simulation via CLI
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
# `yc-bench run` starts yc-bench's own LLM agent loop (via
# LiteLLM), which would compete with our HermesAgentLoop.
# `sim init` just sets up the world and returns.
# ----------------------------------------------------------
init_cmd = [
"yc-bench", "sim", "init",
"--seed", str(seed),
"--start-date", self.config.start_date,
"--company-name", self.config.company_name,
"--horizon-years", str(horizon),
]
init_result = subprocess.run(
init_cmd, capture_output=True, text=True, timeout=30,
)
if init_result.returncode != 0:
error_msg = (init_result.stderr or init_result.stdout).strip()
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
# ----------------------------------------------------------
# Step 2: Run the HermesAgentLoop
# ----------------------------------------------------------
tools, valid_names = self._resolve_tools_for_group()
messages: List[Dict[str, Any]] = [
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
{"role": "user", "content": self.format_prompt(eval_item)},
]
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=run_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# ----------------------------------------------------------
# Step 3: Read final score from the simulation DB
# ----------------------------------------------------------
score_data = _read_final_score(db_path)
final_funds = score_data["final_funds_cents"]
survived = score_data["survived"]
terminal_reason = score_data["terminal_reason"]
composite = _compute_composite_score(
final_funds_cents=final_funds,
survived=survived,
survival_weight=self.config.survival_weight,
funds_weight=self.config.funds_weight,
)
elapsed = time.time() - run_start
status = "SURVIVED" if survived else "BANKRUPT"
if final_funds >= 0:
funds_str = f"${final_funds / 100:,.0f}"
else:
funds_str = f"-${abs(final_funds) / 100:,.0f}"
tqdm.write(
f" [{status}] preset={preset!r} seed={seed} "
f"funds={funds_str} score={composite:.3f} "
f"turns={result.turns_used} ({elapsed:.0f}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": survived,
"final_funds_cents": final_funds,
"final_funds_usd": final_funds / 100,
"terminal_reason": terminal_reason,
"composite_score": composite,
"turns_used": result.turns_used,
"finished_naturally": result.finished_naturally,
"elapsed_seconds": elapsed,
"db_path": db_path,
"messages": result.messages,
}
self._save_result(out)
return out
except Exception as e:
elapsed = time.time() - run_start
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
tqdm.write(
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": False,
"final_funds_cents": 0,
"final_funds_usd": 0.0,
"terminal_reason": f"error: {e}",
"composite_score": 0.0,
"turns_used": 0,
"error": str(e),
"elapsed_seconds": elapsed,
}
self._save_result(out)
return out
# =========================================================================
# Evaluate
# =========================================================================
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
"""Wrap a single rollout with a wall-clock timeout."""
preset = item["preset"]
seed = item["seed"]
try:
return await asyncio.wait_for(
self.rollout_and_score_eval(item),
timeout=self.config.run_timeout,
)
except asyncio.TimeoutError:
from tqdm import tqdm
tqdm.write(
f" [TIMEOUT] preset={preset!r} seed={seed} "
f"(exceeded {self.config.run_timeout}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": False,
"final_funds_cents": 0,
"final_funds_usd": 0.0,
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
"composite_score": 0.0,
"turns_used": 0,
"error": "timeout",
}
self._save_result(out)
return out
async def evaluate(self, *args, **kwargs) -> None:
"""
Run YC-Bench evaluation over all (preset, seed) combinations.
Runs sequentially -- each run is 100-500 turns, parallelising would
be prohibitively expensive and cause env var conflicts.
"""
start_time = time.time()
from tqdm import tqdm
# --- tqdm-compatible logging handler (TB2 pattern) ---
class _TqdmHandler(logging.Handler):
def emit(self, record):
try:
tqdm.write(self.format(record))
except Exception:
self.handleError(record)
root = logging.getLogger()
handler = _TqdmHandler()
handler.setFormatter(
logging.Formatter("%(levelname)s %(name)s: %(message)s")
)
root.handlers = [handler]
for noisy in ("httpx", "openai"):
logging.getLogger(noisy).setLevel(logging.WARNING)
# --- Print config summary ---
print(f"\n{'='*60}")
print("Starting YC-Bench Evaluation")
print(f"{'='*60}")
print(f" Presets: {self.config.presets}")
print(f" Seeds: {self.config.seeds}")
print(f" Total runs: {len(self.all_eval_items)}")
print(f" Max turns/run: {self.config.max_agent_turns}")
print(f" Run timeout: {self.config.run_timeout}s")
print(f"{'='*60}\n")
results = []
pbar = tqdm(
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
)
try:
for item in self.all_eval_items:
result = await self._run_with_timeout(item)
results.append(result)
survived_count = sum(1 for r in results if r.get("survived"))
pbar.set_postfix_str(
f"survived={survived_count}/{len(results)}"
)
pbar.update(1)
except (KeyboardInterrupt, asyncio.CancelledError):
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
pbar.close()
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
self._streaming_file.close()
return
pbar.close()
end_time = time.time()
# --- Compute metrics ---
valid = [r for r in results if r is not None]
if not valid:
print("Warning: No valid results.")
return
total = len(valid)
survived_total = sum(1 for r in valid if r.get("survived"))
survival_rate = survived_total / total if total else 0.0
avg_score = (
sum(r.get("composite_score", 0) for r in valid) / total
if total
else 0.0
)
preset_results: Dict[str, List[Dict]] = defaultdict(list)
for r in valid:
preset_results[r["preset"]].append(r)
eval_metrics = {
"eval/survival_rate": survival_rate,
"eval/avg_composite_score": avg_score,
"eval/total_runs": total,
"eval/survived_runs": survived_total,
"eval/evaluation_time_seconds": end_time - start_time,
}
for preset, items in sorted(preset_results.items()):
ps = sum(1 for r in items if r.get("survived"))
pt = len(items)
pa = (
sum(r.get("composite_score", 0) for r in items) / pt
if pt
else 0
)
key = preset.replace("-", "_")
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
eval_metrics[f"eval/avg_score_{key}"] = pa
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
# --- Print summary ---
print(f"\n{'='*60}")
print("YC-Bench Evaluation Results")
print(f"{'='*60}")
print(
f"Overall survival rate: {survival_rate:.1%} "
f"({survived_total}/{total})"
)
print(f"Average composite score: {avg_score:.4f}")
print(f"Evaluation time: {end_time - start_time:.1f}s")
print("\nPer-preset breakdown:")
for preset, items in sorted(preset_results.items()):
ps = sum(1 for r in items if r.get("survived"))
pt = len(items)
pa = (
sum(r.get("composite_score", 0) for r in items) / pt
if pt
else 0
)
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
for r in items:
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
funds = r.get("final_funds_usd", 0)
print(
f" seed={r['seed']} [{status}] "
f"${funds:,.0f} "
f"score={r.get('composite_score', 0):.3f}"
)
print(f"{'='*60}\n")
# --- Log results ---
samples = [
{k: v for k, v in r.items() if k != "messages"} for r in valid
]
try:
await self.evaluate_log(
metrics=eval_metrics,
samples=samples,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": self.config.agent_temperature,
"max_tokens": self.config.max_token_length,
"max_agent_turns": self.config.max_agent_turns,
},
)
except Exception as e:
print(f"Error logging results: {e}")
# --- Cleanup (TB2 pattern) ---
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
self._streaming_file.close()
print(f"Results saved to: {self._streaming_path}")
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
try:
from environments.agent_loop import _tool_executor
_tool_executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
# =========================================================================
# Wandb logging
# =========================================================================
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log YC-Bench-specific metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
for k, v in self.eval_metrics:
wandb_metrics[k] = v
self.eval_metrics = []
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
YCBenchEvalEnv.cli()

View File

@@ -114,8 +114,8 @@ class HermesAgentEnvConfig(BaseEnvConfig):
# --- Terminal backend ---
terminal_backend: str = Field(
default="local",
description="Terminal backend: 'local', 'docker', 'modal', 'ssh', 'singularity'. "
"Modal recommended for production RL (cloud isolation per rollout).",
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
)
terminal_timeout: int = Field(
default=120,
@@ -229,6 +229,12 @@ class HermesAgentBaseEnv(BaseEnv):
from environments.agent_loop import resize_tool_pool
resize_tool_pool(config.tool_pool_size)
# Set tool_parser on the ServerManager so ManagedServer uses it
# for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
if hasattr(self.server, 'tool_parser'):
self.server.tool_parser = config.tool_call_parser
print(f"🔧 Tool parser: {config.tool_call_parser}")
# Current group's resolved tools (set in collect_trajectories)
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
@@ -466,22 +472,14 @@ class HermesAgentBaseEnv(BaseEnv):
# Run the agent loop
result: AgentResult
if self._use_managed_server():
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
# Load the tool call parser from registry based on config
from environments.tool_call_parsers import get_parser
try:
tc_parser = get_parser(self.config.tool_call_parser)
except KeyError:
logger.warning(
"Tool call parser '%s' not found, falling back to 'hermes'",
self.config.tool_call_parser,
)
tc_parser = get_parser("hermes")
# Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs
# tool_parser is set on ServerManager in __init__ and passed through
# to ManagedServer, which uses ToolCallTranslator for bidirectional
# translation between raw text and OpenAI tool_calls.
try:
async with self.server.managed_server(
tokenizer=self.tokenizer,
tool_call_parser=tc_parser,
preserve_think_blocks=bool(self.config.thinking_mode),
) as managed:
agent = HermesAgentLoop(
server=managed,

View File

@@ -114,11 +114,27 @@ def _patch_swerex_modal():
self._worker = _AsyncWorker()
self._worker.start()
# Pre-build a modal.Image with pip fix for Modal's legacy image builder.
# Modal requires `python -m pip` to work during image build, but some
# task images (e.g., TBLite's broken-python) have intentionally broken pip.
# Fix: remove stale pip dist-info and reinstall via ensurepip before Modal
# tries to use it. This is a no-op for images where pip already works.
import modal as _modal
image_spec = self.config.image
if isinstance(image_spec, str):
image_spec = _modal.Image.from_registry(
image_spec,
setup_dockerfile_commands=[
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
],
)
# Create AND start the deployment entirely on the worker's loop/thread
# so all gRPC channels and async state are bound to that loop
async def _create_and_start():
deployment = ModalDeployment(
image=self.config.image,
image=image_spec,
startup_timeout=self.config.startup_timeout,
runtime_timeout=self.config.runtime_timeout,
deployment_timeout=self.config.deployment_timeout,

View File

@@ -35,7 +35,8 @@ class DeepSeekV31ToolCallParser(ToolCallParser):
# Regex captures: function_name, function_arguments
PATTERN = re.compile(
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>"
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>",
re.DOTALL,
)
def parse(self, text: str) -> ParseResult:

View File

@@ -38,7 +38,8 @@ class DeepSeekV3ToolCallParser(ToolCallParser):
# Regex captures: type, function_name, function_arguments
PATTERN = re.compile(
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>"
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>",
re.DOTALL,
)
def parse(self, text: str) -> ParseResult:

View File

@@ -44,7 +44,7 @@ _tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
"""
Run a tool call in a thread pool executor so backends that use asyncio.run()
internally (modal, docker) get a clean event loop.
internally (modal, docker, daytona) get a clean event loop.
If we're already in an async context, executes handle_function_call() in a
disposable worker thread and blocks for the result.
@@ -95,7 +95,7 @@ class ToolContext:
backend = os.getenv("TERMINAL_ENV", "local")
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
# Run via thread helper so modal/docker backends' asyncio.run() doesn't deadlock
# Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
result = _run_tool_in_thread(
"terminal",
{"command": command, "timeout": timeout},

View File

@@ -701,6 +701,8 @@ class BasePlatformAdapter(ABC):
# Extract image URLs and send them as native platform attachments
images, text_content = self.extract_images(response)
if images:
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
# Send the text portion first (if any remains after extractions)
if text_content:
@@ -727,10 +729,13 @@ class BasePlatformAdapter(ABC):
human_delay = self._get_human_delay()
# Send extracted images as native attachments
if images:
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
for image_url, alt_text in images:
if human_delay > 0:
await asyncio.sleep(human_delay)
try:
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
# Route animated GIFs through send_animation for proper playback
if self._is_animation_url(image_url):
img_result = await self.send_animation(
@@ -745,9 +750,9 @@ class BasePlatformAdapter(ABC):
caption=alt_text if alt_text else None,
)
if not img_result.success:
print(f"[{self.name}] Failed to send image: {img_result.error}")
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
except Exception as img_err:
print(f"[{self.name}] Error sending image: {img_err}")
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
# Send extracted media files — route by file type
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}

View File

@@ -267,6 +267,43 @@ class DiscordAdapter(BasePlatformAdapter):
print(f"[{self.name}] Failed to send audio: {e}")
return await super().send_voice(chat_id, audio_path, caption, reply_to)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
) -> SendResult:
"""Send a local image file natively as a Discord file attachment."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
import io
channel = self._client.get_channel(int(chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
if not channel:
return SendResult(success=False, error=f"Channel {chat_id} not found")
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
filename = os.path.basename(image_path)
with open(image_path, "rb") as f:
file = discord.File(io.BytesIO(f.read()), filename=filename)
msg = await channel.send(
content=caption if caption else None,
file=file,
)
return SendResult(success=True, message_id=str(msg.id))
except Exception as e:
print(f"[{self.name}] Failed to send local image: {e}")
return await super().send_image_file(chat_id, image_path, caption, reply_to)
async def send_image(
self,
chat_id: str,

View File

@@ -179,6 +179,35 @@ class SlackAdapter(BasePlatformAdapter):
"""Slack doesn't have a direct typing indicator API for bots."""
pass
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
) -> SendResult:
"""Send a local image file to Slack by uploading it."""
if not self._app:
return SendResult(success=False, error="Not connected")
try:
import os
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
result = await self._app.client.files_upload_v2(
channel=chat_id,
file=image_path,
filename=os.path.basename(image_path),
initial_comment=caption or "",
thread_ts=reply_to,
)
return SendResult(success=True, raw_response=result)
except Exception as e:
print(f"[{self.name}] Failed to send local image: {e}")
return await super().send_image_file(chat_id, image_path, caption, reply_to)
async def send_image(
self,
chat_id: str,

View File

@@ -8,10 +8,13 @@ Uses python-telegram-bot library for:
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
try:
from telegram import Update, Bot, Message
from telegram.ext import (
@@ -73,6 +76,19 @@ def _escape_mdv2(text: str) -> str:
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
def _strip_mdv2(text: str) -> str:
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
doesn't show stray asterisks from header/bold conversion.
"""
# Remove escape backslashes before special characters
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
# Remove MarkdownV2 bold markers that format_message converted from **bold**
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
return cleaned
class TelegramAdapter(BasePlatformAdapter):
"""
Telegram bot adapter.
@@ -199,9 +215,13 @@ class TelegramAdapter(BasePlatformAdapter):
except Exception as md_error:
# Markdown parsing failed, try plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
# Strip MDV2 escape backslashes so the user doesn't
# see raw backslashes littered through the message.
plain_chunk = _strip_mdv2(chunk)
msg = await self._bot.send_message(
chat_id=int(chat_id),
text=chunk,
text=plain_chunk,
parse_mode=None, # Plain text
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
message_thread_id=int(thread_id) if thread_id else None,
@@ -286,6 +306,34 @@ class TelegramAdapter(BasePlatformAdapter):
print(f"[{self.name}] Failed to send voice/audio: {e}")
return await super().send_voice(chat_id, audio_path, caption, reply_to)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
) -> SendResult:
"""Send a local image file natively as a Telegram photo."""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
import os
if not os.path.exists(image_path):
return SendResult(success=False, error=f"Image file not found: {image_path}")
with open(image_path, "rb") as image_file:
msg = await self._bot.send_photo(
chat_id=int(chat_id),
photo=image_file,
caption=caption[:1024] if caption else None,
reply_to_message_id=int(reply_to) if reply_to else None,
)
return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e:
print(f"[{self.name}] Failed to send local image: {e}")
return await super().send_image_file(chat_id, image_path, caption, reply_to)
async def send_image(
self,
chat_id: str,
@@ -293,12 +341,16 @@ class TelegramAdapter(BasePlatformAdapter):
caption: Optional[str] = None,
reply_to: Optional[str] = None,
) -> SendResult:
"""Send an image natively as a Telegram photo."""
"""Send an image natively as a Telegram photo.
Tries URL-based send first (fast, works for <5MB images).
Falls back to downloading and uploading as file (supports up to 10MB).
"""
if not self._bot:
return SendResult(success=False, error="Not connected")
try:
# Telegram can send photos directly from URLs
# Telegram can send photos directly from URLs (up to ~5MB)
msg = await self._bot.send_photo(
chat_id=int(chat_id),
photo=image_url,
@@ -307,9 +359,26 @@ class TelegramAdapter(BasePlatformAdapter):
)
return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e:
print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
# Fallback: send as text link
return await super().send_image(chat_id, image_url, caption, reply_to)
logger.warning("[%s] URL-based send_photo failed (%s), trying file upload", self.name, e)
# Fallback: download and upload as file (supports up to 10MB)
try:
import httpx
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.get(image_url)
resp.raise_for_status()
image_data = resp.content
msg = await self._bot.send_photo(
chat_id=int(chat_id),
photo=image_data,
caption=caption[:1024] if caption else None,
reply_to_message_id=int(reply_to) if reply_to else None,
)
return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e2:
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
# Final fallback: send URL as text
return await super().send_image(chat_id, image_url, caption, reply_to)
async def send_animation(
self,

View File

@@ -28,6 +28,41 @@ from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
def _kill_port_process(port: int) -> None:
"""Kill any process listening on the given TCP port."""
try:
if _IS_WINDOWS:
# Use netstat to find the PID bound to this port, then taskkill
result = subprocess.run(
["netstat", "-ano", "-p", "TCP"],
capture_output=True, text=True, timeout=5,
)
for line in result.stdout.splitlines():
parts = line.split()
if len(parts) >= 5 and parts[3] == "LISTENING":
local_addr = parts[1]
if local_addr.endswith(f":{port}"):
try:
subprocess.run(
["taskkill", "/PID", parts[4], "/F"],
capture_output=True, timeout=5,
)
except subprocess.SubprocessError:
pass
else:
result = subprocess.run(
["fuser", f"{port}/tcp"],
capture_output=True, timeout=5,
)
if result.returncode == 0:
subprocess.run(
["fuser", "-k", f"{port}/tcp"],
capture_output=True, timeout=5,
)
except Exception:
pass
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
@@ -145,21 +180,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._session_path.mkdir(parents=True, exist_ok=True)
# Kill any orphaned bridge from a previous gateway run
try:
result = subprocess.run(
["fuser", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
if result.returncode == 0:
# Port is in use — kill the process
subprocess.run(
["fuser", "-k", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
import time
time.sleep(2)
except Exception:
pass
_kill_port_process(self._bridge_port)
import time
time.sleep(1)
# Start the bridge process in its own process group.
# Route output to a log file so QR codes, errors, and reconnection
@@ -293,13 +316,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
print(f"[{self.name}] Error stopping bridge: {e}")
# Also kill any orphaned bridge processes on our port
try:
subprocess.run(
["fuser", "-k", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
except Exception:
pass
_kill_port_process(self._bridge_port)
self._running = False
self._bridge_process = None

View File

@@ -66,6 +66,7 @@ if _config_path.exists():
"docker_image": "TERMINAL_DOCKER_IMAGE",
"singularity_image": "TERMINAL_SINGULARITY_IMAGE",
"modal_image": "TERMINAL_MODAL_IMAGE",
"daytona_image": "TERMINAL_DAYTONA_IMAGE",
"ssh_host": "TERMINAL_SSH_HOST",
"ssh_user": "TERMINAL_SSH_USER",
"ssh_port": "TERMINAL_SSH_PORT",
@@ -74,6 +75,7 @@ if _config_path.exists():
"container_memory": "TERMINAL_CONTAINER_MEMORY",
"container_disk": "TERMINAL_CONTAINER_DISK",
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
"sandbox_dir": "TERMINAL_SANDBOX_DIR",
}
for _cfg_key, _env_var in _terminal_env_map.items():
if _cfg_key in _terminal_cfg:
@@ -92,6 +94,11 @@ if _config_path.exists():
if _agent_cfg and isinstance(_agent_cfg, dict):
if "max_turns" in _agent_cfg:
os.environ["HERMES_MAX_ITERATIONS"] = str(_agent_cfg["max_turns"])
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
_tz_cfg = _cfg.get("timezone", "")
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
except Exception:
pass # Non-fatal; gateway can still run with .env values
@@ -101,11 +108,13 @@ os.environ["HERMES_QUIET"] = "1"
# Enable interactive exec approval for dangerous commands on messaging platforms
os.environ["HERMES_EXEC_ASK"] = "1"
# Set terminal working directory for messaging platforms
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
# This is separate from CLI which uses the directory where `hermes` is run
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
os.environ["TERMINAL_CWD"] = messaging_cwd
# Set terminal working directory for messaging platforms.
# If the user set an explicit path in config.yaml (not "." or "auto"),
# respect it. Otherwise use MESSAGING_CWD or default to home directory.
_configured_cwd = os.environ.get("TERMINAL_CWD", "")
if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"):
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
os.environ["TERMINAL_CWD"] = messaging_cwd
from gateway.config import (
Platform,
@@ -172,7 +181,6 @@ class GatewayRunner:
self.session_store = SessionStore(
self.config.sessions_dir, self.config,
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
on_auto_reset=self._flush_memories_before_reset,
)
self.delivery_router = DeliveryRouter(self.config)
self._running = False
@@ -203,15 +211,14 @@ class GatewayRunner:
from gateway.hooks import HookRegistry
self.hooks = HookRegistry()
def _flush_memories_before_reset(self, old_entry):
"""Prompt the agent to save memories/skills before an auto-reset.
Called synchronously by SessionStore before destroying an expired session.
Loads the transcript, gives the agent a real turn with memory + skills
tools, and explicitly asks it to preserve anything worth keeping.
def _flush_memories_for_session(self, old_session_id: str):
"""Prompt the agent to save memories/skills before context is lost.
Synchronous worker — meant to be called via run_in_executor from
an async context so it doesn't block the event loop.
"""
try:
history = self.session_store.load_transcript(old_entry.session_id)
history = self.session_store.load_transcript(old_session_id)
if not history or len(history) < 4:
return
@@ -225,7 +232,7 @@ class GatewayRunner:
max_iterations=8,
quiet_mode=True,
enabled_toolsets=["memory", "skills"],
session_id=old_entry.session_id,
session_id=old_session_id,
)
# Build conversation history from transcript
@@ -254,9 +261,14 @@ class GatewayRunner:
user_message=flush_prompt,
conversation_history=msgs,
)
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
except Exception as e:
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
async def _async_flush_memories(self, old_session_id: str):
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
@staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]:
@@ -324,7 +336,7 @@ class GatewayRunner:
Checks HERMES_REASONING_EFFORT env var first, then agent.reasoning_effort
in config.yaml. Valid: "xhigh", "high", "medium", "low", "minimal", "none".
Returns None to use default (xhigh).
Returns None to use default (medium).
"""
effort = os.getenv("HERMES_REASONING_EFFORT", "")
if not effort:
@@ -345,7 +357,7 @@ class GatewayRunner:
valid = ("xhigh", "high", "medium", "low", "minimal")
if effort in valid:
return {"enabled": True, "effort": effort}
logger.warning("Unknown reasoning_effort '%s', using default (xhigh)", effort)
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
return None
@staticmethod
@@ -458,10 +470,50 @@ class GatewayRunner:
# Check if we're restarting after a /update command
await self._send_update_notification()
# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())
logger.info("Press Ctrl+C to stop")
return True
async def _session_expiry_watcher(self, interval: int = 300):
"""Background task that proactively flushes memories for expired sessions.
Runs every `interval` seconds (default 5 min). For each session that
has expired according to its reset policy, flushes memories in a thread
pool and marks the session so it won't be flushed again.
This means memories are already saved by the time the user sends their
next message, so there's no blocking delay.
"""
await asyncio.sleep(60) # initial delay — let the gateway fully start
while self._running:
try:
self.session_store._ensure_loaded()
for key, entry in list(self.session_store._entries.items()):
if entry.session_id in self.session_store._pre_flushed_sessions:
continue # already flushed this session
if not self.session_store._is_session_expired(entry):
continue # session still active
# Session has expired — flush memories in the background
logger.info(
"Session %s expired (key=%s), flushing memories proactively",
entry.session_id, key,
)
try:
await self._async_flush_memories(entry.session_id)
self.session_store._pre_flushed_sessions.add(entry.session_id)
except Exception as e:
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
except Exception as e:
logger.debug("Session expiry watcher error: %s", e)
# Sleep in small increments so we can stop quickly
for _ in range(interval):
if not self._running:
break
await asyncio.sleep(1)
async def stop(self) -> None:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
@@ -658,7 +710,8 @@ class GatewayRunner:
# Emit command:* hook for any recognized slash command
_known_commands = {"new", "reset", "help", "status", "stop", "model",
"personality", "retry", "undo", "sethome", "set-home",
"compress", "usage", "reload-mcp", "update"}
"compress", "usage", "insights", "reload-mcp", "update",
"title"}
if command and command in _known_commands:
await self.hooks.emit(f"command:{command}", {
"platform": source.platform.value if source.platform else "",
@@ -682,6 +735,9 @@ class GatewayRunner:
if command == "model":
return await self._handle_model_command(event)
if command == "provider":
return await self._handle_provider_command(event)
if command == "personality":
return await self._handle_personality_command(event)
@@ -700,11 +756,17 @@ class GatewayRunner:
if command == "usage":
return await self._handle_usage_command(event)
if command == "insights":
return await self._handle_insights_command(event)
if command == "reload-mcp":
return await self._handle_reload_mcp_command(event)
if command == "update":
return await self._handle_update_command(event)
if command == "title":
return await self._handle_title_command(event)
# Skill slash commands: /skill-name loads the skill and sends to agent
if command:
@@ -779,6 +841,167 @@ class GatewayRunner:
# Load conversation history from transcript
history = self.session_store.load_transcript(session_entry.session_id)
# -----------------------------------------------------------------
# Session hygiene: auto-compress pathologically large transcripts
#
# Long-lived gateway sessions can accumulate enough history that
# every new message rehydrates an oversized transcript, causing
# repeated truncation/context failures. Detect this early and
# compress proactively — before the agent even starts. (#628)
# -----------------------------------------------------------------
if history and len(history) >= 4:
from agent.model_metadata import estimate_messages_tokens_rough
# Read thresholds from config.yaml → session_hygiene section
_hygiene_cfg = {}
try:
_hyg_cfg_path = _hermes_home / "config.yaml"
if _hyg_cfg_path.exists():
import yaml as _hyg_yaml
with open(_hyg_cfg_path) as _hyg_f:
_hyg_data = _hyg_yaml.safe_load(_hyg_f) or {}
_hygiene_cfg = _hyg_data.get("session_hygiene", {})
if not isinstance(_hygiene_cfg, dict):
_hygiene_cfg = {}
except Exception:
pass
_compress_token_threshold = int(
_hygiene_cfg.get("auto_compress_tokens", 100_000)
)
_compress_msg_threshold = int(
_hygiene_cfg.get("auto_compress_messages", 200)
)
_warn_token_threshold = int(
_hygiene_cfg.get("warn_tokens", 200_000)
)
_msg_count = len(history)
_approx_tokens = estimate_messages_tokens_rough(history)
_needs_compress = (
_approx_tokens >= _compress_token_threshold
or _msg_count >= _compress_msg_threshold
)
if _needs_compress:
logger.info(
"Session hygiene: %s messages, ~%s tokens — auto-compressing "
"(thresholds: %s msgs / %s tokens)",
_msg_count, f"{_approx_tokens:,}",
_compress_msg_threshold, f"{_compress_token_threshold:,}",
)
_hyg_adapter = self.adapters.get(source.platform)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"🗜️ Session is large ({_msg_count} messages, "
f"~{_approx_tokens:,} tokens). Auto-compressing..."
)
except Exception:
pass
try:
from run_agent import AIAgent
_hyg_runtime = _resolve_runtime_agent_kwargs()
if _hyg_runtime.get("api_key"):
_hyg_msgs = [
{"role": m.get("role"), "content": m.get("content")}
for m in history
if m.get("role") in ("user", "assistant")
and m.get("content")
]
if len(_hyg_msgs) >= 4:
_hyg_agent = AIAgent(
**_hyg_runtime,
max_iterations=4,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=session_entry.session_id,
)
loop = asyncio.get_event_loop()
_compressed, _ = await loop.run_in_executor(
None,
lambda: _hyg_agent._compress_context(
_hyg_msgs, "",
approx_tokens=_approx_tokens,
),
)
self.session_store.rewrite_transcript(
session_entry.session_id, _compressed
)
history = _compressed
_new_count = len(_compressed)
_new_tokens = estimate_messages_tokens_rough(
_compressed
)
logger.info(
"Session hygiene: compressed %s%s msgs, "
"~%s → ~%s tokens",
_msg_count, _new_count,
f"{_approx_tokens:,}", f"{_new_tokens:,}",
)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"🗜️ Compressed: {_msg_count}"
f"{_new_count} messages, "
f"~{_approx_tokens:,}"
f"~{_new_tokens:,} tokens"
)
except Exception:
pass
# Still too large after compression — warn user
if _new_tokens >= _warn_token_threshold:
logger.warning(
"Session hygiene: still ~%s tokens after "
"compression — suggesting /reset",
f"{_new_tokens:,}",
)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
"⚠️ Session is still very large "
"after compression "
f"(~{_new_tokens:,} tokens). "
"Consider using /reset to start "
"fresh if you experience issues."
)
except Exception:
pass
except Exception as e:
logger.warning(
"Session hygiene auto-compress failed: %s", e
)
# Compression failed and session is dangerously large
if _approx_tokens >= _warn_token_threshold:
_hyg_adapter = self.adapters.get(source.platform)
if _hyg_adapter:
try:
await _hyg_adapter.send(
source.chat_id,
f"⚠️ Session is very large "
f"({_msg_count} messages, "
f"~{_approx_tokens:,} tokens) and "
"auto-compression failed. Consider "
"using /compress or /reset to avoid "
"issues."
)
except Exception:
pass
# First-message onboarding -- only on the very first interaction ever
if not history and not self.session_store.has_any_sessions():
context_prompt += (
@@ -1003,33 +1226,12 @@ class GatewayRunner:
# Get existing session key
session_key = self.session_store._generate_session_key(source)
# Memory flush before reset: load the old transcript and let a
# temporary agent save memories before the session is wiped.
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
try:
old_entry = self.session_store._entries.get(session_key)
if old_entry:
old_history = self.session_store.load_transcript(old_entry.session_id)
if old_history:
from run_agent import AIAgent
loop = asyncio.get_event_loop()
_flush_kwargs = _resolve_runtime_agent_kwargs()
def _do_flush():
tmp_agent = AIAgent(
**_flush_kwargs,
max_iterations=5,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=old_entry.session_id,
)
# Build simple message list from transcript
msgs = []
for m in old_history:
role = m.get("role")
content = m.get("content")
if role in ("user", "assistant") and content:
msgs.append({"role": role, "content": content})
tmp_agent.flush_memories(msgs)
await loop.run_in_executor(None, _do_flush)
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
except Exception as e:
logger.debug("Gateway memory flush on reset failed: %s", e)
@@ -1096,13 +1298,16 @@ class GatewayRunner:
"`/reset` — Reset conversation history",
"`/status` — Show session info",
"`/stop` — Interrupt the running agent",
"`/model [name]` — Show or change the model",
"`/model [provider:model]` — Show/change model (or switch provider)",
"`/provider` — Show available providers and auth status",
"`/personality [name]` — Set a personality",
"`/retry` — Retry your last message",
"`/undo` — Remove the last exchange",
"`/sethome` — Set this chat as the home channel",
"`/compress` — Compress conversation context",
"`/title [name]` — Set or show the session title",
"`/usage` — Show token usage for this session",
"`/insights [days]` — Show usage insights and analytics",
"`/reload-mcp` — Reload MCP servers from config",
"`/update` — Update Hermes Agent to the latest version",
"`/help` — Show this message",
@@ -1121,13 +1326,20 @@ class GatewayRunner:
async def _handle_model_command(self, event: MessageEvent) -> str:
"""Handle /model command - show or change the current model."""
import yaml
from hermes_cli.models import (
parse_model_input,
validate_requested_model,
curated_models_for_provider,
normalize_provider,
_PROVIDER_LABELS,
)
args = event.get_command_args().strip()
config_path = _hermes_home / 'config.yaml'
# Resolve current model the same way the agent init does:
# env vars first, then config.yaml always overrides.
# Resolve current model and provider from config
current = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
current_provider = "openrouter"
try:
if config_path.exists():
with open(config_path) as f:
@@ -1137,39 +1349,164 @@ class GatewayRunner:
current = model_cfg
elif isinstance(model_cfg, dict):
current = model_cfg.get("default", current)
current_provider = model_cfg.get("provider", current_provider)
except Exception:
pass
# Resolve "auto" to the actual provider using credential detection
current_provider = normalize_provider(current_provider)
if current_provider == "auto":
try:
from hermes_cli.auth import resolve_provider as _resolve_provider
current_provider = _resolve_provider(current_provider)
except Exception:
current_provider = "openrouter"
if not args:
return f"🤖 **Current model:** `{current}`\n\nTo change: `/model provider/model-name`"
provider_label = _PROVIDER_LABELS.get(current_provider, current_provider)
lines = [
f"🤖 **Current model:** `{current}`",
f"**Provider:** {provider_label}",
"",
]
curated = curated_models_for_provider(current_provider)
if curated:
lines.append(f"**Available models ({provider_label}):**")
for mid, desc in curated:
marker = "" if mid == current else ""
label = f" _{desc}_" if desc else ""
lines.append(f"• `{mid}`{label}{marker}")
lines.append("")
lines.append("To change: `/model model-name`")
lines.append("Switch provider: `/model provider:model-name`")
return "\n".join(lines)
if "/" not in args:
return (
f"🤖 Invalid model format: `{args}`\n\n"
f"Use `provider/model-name` format, e.g.:\n"
f"• `anthropic/claude-sonnet-4`\n"
f"• `google/gemini-2.5-pro`\n"
f"• `openai/gpt-4o`"
)
# Parse provider:model syntax
target_provider, new_model = parse_model_input(args, current_provider)
provider_changed = target_provider != current_provider
# Write to config.yaml (source of truth), same pattern as CLI save_config_value.
# Resolve credentials for the target provider (for API probe)
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
base_url = "https://openrouter.ai/api/v1"
if provider_changed:
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=target_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception as e:
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
else:
# Use current provider's base_url from config or registry
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=current_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception:
pass
# Validate the model against the live API
try:
validation = validate_requested_model(
new_model,
target_provider,
api_key=api_key,
base_url=base_url,
)
except Exception:
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
if not validation.get("accepted"):
msg = validation.get("message", "Invalid model")
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
return f"⚠️ {msg}{tip}"
# Persist to config only if validation approves
if validation.get("persist"):
try:
user_config = {}
if config_path.exists():
with open(config_path) as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = new_model
if provider_changed:
user_config["model"]["provider"] = target_provider
with open(config_path, 'w') as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"
# Set env vars so the next agent run picks up the change
os.environ["HERMES_MODEL"] = new_model
if provider_changed:
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""
warning = ""
if validation.get("message"):
warning = f"\n⚠️ {validation['message']}"
if validation.get("persist"):
persist_note = "saved to config"
else:
persist_note = "this session only — will revert on restart"
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}\n_(takes effect on next message)_"
async def _handle_provider_command(self, event: MessageEvent) -> str:
"""Handle /provider command - show available providers."""
import yaml
from hermes_cli.models import (
list_available_providers,
normalize_provider,
_PROVIDER_LABELS,
)
# Resolve current provider from config
current_provider = "openrouter"
config_path = _hermes_home / 'config.yaml'
try:
user_config = {}
if config_path.exists():
with open(config_path) as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = args
with open(config_path, 'w') as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"
cfg = yaml.safe_load(f) or {}
model_cfg = cfg.get("model", {})
if isinstance(model_cfg, dict):
current_provider = model_cfg.get("provider", current_provider)
except Exception:
pass
# Also set env var so code reading it before the next agent init sees the update.
os.environ["HERMES_MODEL"] = args
current_provider = normalize_provider(current_provider)
if current_provider == "auto":
try:
from hermes_cli.auth import resolve_provider as _resolve_provider
current_provider = _resolve_provider(current_provider)
except Exception:
current_provider = "openrouter"
return f"🤖 Model changed to `{args}`\n_(takes effect on next message)_"
current_label = _PROVIDER_LABELS.get(current_provider, current_provider)
lines = [
f"🔌 **Current provider:** {current_label} (`{current_provider}`)",
"",
"**Available providers:**",
]
providers = list_available_providers()
for p in providers:
marker = " ← active" if p["id"] == current_provider else ""
auth = "" if p["authenticated"] else ""
aliases = f" _(also: {', '.join(p['aliases'])})_" if p["aliases"] else ""
lines.append(f"{auth} `{p['id']}` — {p['label']}{aliases}{marker}")
lines.append("")
lines.append("Switch: `/model provider:model-name`")
lines.append("Setup: `hermes setup`")
return "\n".join(lines)
async def _handle_personality_command(self, event: MessageEvent) -> str:
"""Handle /personality command - list or set a personality."""
@@ -1253,8 +1590,7 @@ class GatewayRunner:
)
# Let the normal message handler process it
await self._handle_message(retry_event)
return None # Response sent through normal flow
return await self._handle_message(retry_event)
async def _handle_undo_command(self, event: MessageEvent) -> str:
"""Handle /undo command - remove the last user/assistant exchange."""
@@ -1360,6 +1696,40 @@ class GatewayRunner:
logger.warning("Manual compress failed: %s", e)
return f"Compression failed: {e}"
async def _handle_title_command(self, event: MessageEvent) -> str:
"""Handle /title command — set or show the current session's title."""
source = event.source
session_entry = self.session_store.get_or_create_session(source)
session_id = session_entry.session_id
if not self._session_db:
return "Session database not available."
title_arg = event.get_command_args().strip()
if title_arg:
# Sanitize the title before setting
try:
sanitized = self._session_db.sanitize_title(title_arg)
except ValueError as e:
return f"⚠️ {e}"
if not sanitized:
return "⚠️ Title is empty after cleanup. Please use printable characters."
# Set the title
try:
if self._session_db.set_session_title(session_id, sanitized):
return f"✏️ Session title set: **{sanitized}**"
else:
return "Session not found in database."
except ValueError as e:
return f"⚠️ {e}"
else:
# Show the current title
title = self._session_db.get_session_title(session_id)
if title:
return f"📌 Session title: **{title}**"
else:
return "No title set. Usage: `/title My Session Name`"
async def _handle_usage_command(self, event: MessageEvent) -> str:
"""Handle /usage command -- show token usage for the session's last agent run."""
source = event.source
@@ -1397,6 +1767,53 @@ class GatewayRunner:
)
return "No usage data available for this session."
async def _handle_insights_command(self, event: MessageEvent) -> str:
"""Handle /insights command -- show usage insights and analytics."""
import asyncio as _asyncio
args = event.get_command_args().strip()
days = 30
source = None
# Parse simple args: /insights 7 or /insights --days 7
if args:
parts = args.split()
i = 0
while i < len(parts):
if parts[i] == "--days" and i + 1 < len(parts):
try:
days = int(parts[i + 1])
except ValueError:
return f"Invalid --days value: {parts[i + 1]}"
i += 2
elif parts[i] == "--source" and i + 1 < len(parts):
source = parts[i + 1]
i += 2
elif parts[i].isdigit():
days = int(parts[i])
i += 1
else:
i += 1
try:
from hermes_state import SessionDB
from agent.insights import InsightsEngine
loop = _asyncio.get_event_loop()
def _run_insights():
db = SessionDB()
engine = InsightsEngine(db)
report = engine.generate(days=days, source=source)
result = engine.format_gateway(report)
db.close()
return result
return await loop.run_in_executor(None, _run_insights)
except Exception as e:
logger.error("Insights command error: %s", e, exc_info=True)
return f"Error generating insights: {e}"
async def _handle_reload_mcp_command(self, event: MessageEvent) -> str:
"""Handle /reload-mcp command -- disconnect and reconnect all MCP servers."""
loop = asyncio.get_event_loop()
@@ -2041,7 +2458,7 @@ class GatewayRunner:
os.environ["HERMES_SESSION_KEY"] = session_key or ""
# Read from env var or use default (same as CLI)
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
# Map platform enum to the platform hint key the agent understands.
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
@@ -2381,14 +2798,85 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
logger.info("Cron ticker stopped")
async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool:
"""
Start the gateway and run until interrupted.
This is the main entry point for running the gateway.
Returns True if the gateway ran successfully, False if it failed to start.
A False return causes a non-zero exit code so systemd can auto-restart.
Args:
config: Optional gateway configuration override.
replace: If True, kill any existing gateway instance before starting.
Useful for systemd services to avoid restart-loop deadlocks
when the previous process hasn't fully exited yet.
"""
# ── Duplicate-instance guard ──────────────────────────────────────
# Prevent two gateways from running under the same HERMES_HOME.
# The PID file is scoped to HERMES_HOME, so future multi-profile
# setups (each profile using a distinct HERMES_HOME) will naturally
# allow concurrent instances without tripping this guard.
import time as _time
from gateway.status import get_running_pid, remove_pid_file
existing_pid = get_running_pid()
if existing_pid is not None and existing_pid != os.getpid():
if replace:
logger.info(
"Replacing existing gateway instance (PID %d) with --replace.",
existing_pid,
)
try:
os.kill(existing_pid, signal.SIGTERM)
except ProcessLookupError:
pass # Already gone
except PermissionError:
logger.error(
"Permission denied killing PID %d. Cannot replace.",
existing_pid,
)
return False
# Wait up to 10 seconds for the old process to exit
for _ in range(20):
try:
os.kill(existing_pid, 0)
_time.sleep(0.5)
except (ProcessLookupError, PermissionError):
break # Process is gone
else:
# Still alive after 10s — force kill
logger.warning(
"Old gateway (PID %d) did not exit after SIGTERM, sending SIGKILL.",
existing_pid,
)
try:
os.kill(existing_pid, signal.SIGKILL)
_time.sleep(0.5)
except (ProcessLookupError, PermissionError):
pass
remove_pid_file()
else:
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
logger.error(
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
existing_pid, hermes_home,
)
print(
f"\n❌ Gateway already running (PID {existing_pid}).\n"
f" Use 'hermes gateway restart' to replace it,\n"
f" or 'hermes gateway stop' to kill it first.\n"
f" Or use 'hermes gateway run --replace' to auto-replace.\n"
)
return False
# Sync bundled skills on gateway start (fast -- skips unchanged)
try:
from tools.skills_sync import sync_skills
sync_skills(quiet=True)
except Exception:
pass
# Configure rotating file log so gateway output is persisted for debugging
log_dir = _hermes_home / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -311,7 +311,9 @@ class SessionStore:
self._entries: Dict[str, SessionEntry] = {}
self._loaded = False
self._has_active_processes_fn = has_active_processes_fn
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
# Initialize SQLite session database
self._db = None
@@ -353,6 +355,44 @@ class SessionStore:
"""Generate a session key from a source."""
return build_session_key(source)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.
Works from the entry alone — no SessionSource needed.
Used by the background expiry watcher to proactively flush memories.
Sessions with active background processes are never considered expired.
"""
if self._has_active_processes_fn:
if self._has_active_processes_fn(entry.session_key):
return False
policy = self.config.get_reset_policy(
platform=entry.platform,
session_type=entry.chat_type,
)
if policy.mode == "none":
return False
now = datetime.now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
if now > idle_deadline:
return True
if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0, second=0, microsecond=0,
)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
if entry.updated_at < today_reset:
return True
return False
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
"""
Check if a session should be reset based on policy.
@@ -439,13 +479,11 @@ class SessionStore:
self._save()
return entry
else:
# Session is being auto-reset — flush memories before destroying
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
if self._on_auto_reset:
try:
self._on_auto_reset(entry)
except Exception as e:
logger.debug("Auto-reset callback failed: %s", e)
self._pre_flushed_sessions.discard(entry.session_id)
if self._db:
try:
self._db.end_session(entry.session_id, "session_reset")

View File

@@ -3,37 +3,59 @@ Gateway runtime status helpers.
Provides PID-file based detection of whether the gateway daemon is running,
used by send_message's check_fn to gate availability in the CLI.
The PID file lives at ``{HERMES_HOME}/gateway.pid``. HERMES_HOME defaults to
``~/.hermes`` but can be overridden via the environment variable. This means
separate HERMES_HOME directories naturally get separate PID files — a property
that will be useful when we add named profiles (multiple agents running
concurrently under distinct configurations).
"""
import os
from pathlib import Path
from typing import Optional
_PID_FILE = Path.home() / ".hermes" / "gateway.pid"
def _get_pid_path() -> Path:
"""Return the path to the gateway PID file, respecting HERMES_HOME."""
home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
return home / "gateway.pid"
def write_pid_file() -> None:
"""Write the current process PID to the gateway PID file."""
_PID_FILE.parent.mkdir(parents=True, exist_ok=True)
_PID_FILE.write_text(str(os.getpid()))
pid_path = _get_pid_path()
pid_path.parent.mkdir(parents=True, exist_ok=True)
pid_path.write_text(str(os.getpid()))
def remove_pid_file() -> None:
"""Remove the gateway PID file if it exists."""
try:
_PID_FILE.unlink(missing_ok=True)
_get_pid_path().unlink(missing_ok=True)
except Exception:
pass
def get_running_pid() -> Optional[int]:
"""Return the PID of a running gateway instance, or ``None``.
Checks the PID file and verifies the process is actually alive.
Cleans up stale PID files automatically.
"""
pid_path = _get_pid_path()
if not pid_path.exists():
return None
try:
pid = int(pid_path.read_text().strip())
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
return pid
except (ValueError, ProcessLookupError, PermissionError):
# Stale PID file — process is gone
remove_pid_file()
return None
def is_gateway_running() -> bool:
"""Check if the gateway daemon is currently running."""
if not _PID_FILE.exists():
return False
try:
pid = int(_PID_FILE.read_text().strip())
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
return True
except (ValueError, ProcessLookupError, PermissionError):
# Stale PID file -- process is gone
remove_pid_file()
return False
return get_running_pid() is not None

View File

@@ -72,15 +72,19 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
@dataclass
class ProviderConfig:
"""Describes a known OAuth provider."""
"""Describes a known inference provider."""
id: str
name: str
auth_type: str # "oauth_device_code" or "api_key"
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
portal_base_url: str = ""
inference_base_url: str = ""
client_id: str = ""
scope: str = ""
extra: Dict[str, Any] = field(default_factory=dict)
# For API-key providers: env vars to check (in priority order)
api_key_env_vars: tuple = ()
# Optional env var for base URL override
base_url_env_var: str = ""
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
@@ -99,9 +103,118 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
auth_type="oauth_external",
inference_base_url=DEFAULT_CODEX_BASE_URL,
),
"zai": ProviderConfig(
id="zai",
name="Z.AI / GLM",
auth_type="api_key",
inference_base_url="https://api.z.ai/api/paas/v4",
api_key_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
base_url_env_var="GLM_BASE_URL",
),
"kimi-coding": ProviderConfig(
id="kimi-coding",
name="Kimi / Moonshot",
auth_type="api_key",
inference_base_url="https://api.moonshot.ai/v1",
api_key_env_vars=("KIMI_API_KEY",),
base_url_env_var="KIMI_BASE_URL",
),
"minimax": ProviderConfig(
id="minimax",
name="MiniMax",
auth_type="api_key",
inference_base_url="https://api.minimax.io/v1",
api_key_env_vars=("MINIMAX_API_KEY",),
base_url_env_var="MINIMAX_BASE_URL",
),
"minimax-cn": ProviderConfig(
id="minimax-cn",
name="MiniMax (China)",
auth_type="api_key",
inference_base_url="https://api.minimaxi.com/v1",
api_key_env_vars=("MINIMAX_CN_API_KEY",),
base_url_env_var="MINIMAX_CN_BASE_URL",
),
}
# =============================================================================
# Kimi Code Endpoint Detection
# =============================================================================
# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
# KIMI_BASE_URL explicitly.
KIMI_CODE_BASE_URL = "https://api.kimi.com/coding/v1"
def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) -> str:
"""Return the correct Kimi base URL based on the API key prefix.
If the user has explicitly set KIMI_BASE_URL, that always wins.
Otherwise, sk-kimi- prefixed keys route to api.kimi.com/coding/v1.
"""
if env_override:
return env_override
if api_key.startswith("sk-kimi-"):
return KIMI_CODE_BASE_URL
return default_url
# =============================================================================
# Z.AI Endpoint Detection
# =============================================================================
# Z.AI has separate billing for general vs coding plans, and global vs China
# endpoints. A key that works on one may return "Insufficient balance" on
# another. We probe at setup time and store the working endpoint.
ZAI_ENDPOINTS = [
# (id, base_url, default_model, label)
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
]
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
"""Probe z.ai endpoints to find one that accepts this API key.
Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
first working endpoint, or None if all fail.
"""
for ep_id, base_url, model, label in ZAI_ENDPOINTS:
try:
resp = httpx.post(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json={
"model": model,
"stream": False,
"max_tokens": 1,
"messages": [{"role": "user", "content": "ping"}],
},
timeout=timeout,
)
if resp.status_code == 200:
logger.debug("Z.AI endpoint probe: %s (%s) OK", ep_id, base_url)
return {
"id": ep_id,
"base_url": base_url,
"model": model,
"label": label,
}
logger.debug("Z.AI endpoint probe: %s returned %s", ep_id, resp.status_code)
except Exception as exc:
logger.debug("Z.AI endpoint probe: %s failed: %s", ep_id, exc)
return None
# =============================================================================
# Error Types
# =============================================================================
@@ -355,10 +468,19 @@ def resolve_provider(
1. active_provider in auth.json with valid credentials
2. Explicit CLI api_key/base_url -> "openrouter"
3. OPENAI_API_KEY or OPENROUTER_API_KEY env vars -> "openrouter"
4. Fallback: "openrouter"
4. Provider-specific API keys (GLM, Kimi, MiniMax) -> that provider
5. Fallback: "openrouter"
"""
normalized = (requested or "auto").strip().lower()
# Normalize provider aliases
_PROVIDER_ALIASES = {
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
"kimi": "kimi-coding", "moonshot": "kimi-coding",
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
}
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
if normalized in {"openrouter", "custom"}:
return "openrouter"
if normalized in PROVIDER_REGISTRY:
@@ -387,6 +509,14 @@ def resolve_provider(
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
return "openrouter"
# Auto-detect API-key providers by checking their env vars
for pid, pconfig in PROVIDER_REGISTRY.items():
if pconfig.auth_type != "api_key":
continue
for env_var in pconfig.api_key_env_vars:
if os.getenv(env_var, "").strip():
return pid
return "openrouter"
@@ -1230,6 +1360,42 @@ def get_codex_auth_status() -> Dict[str, Any]:
}
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
return {"configured": False}
api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break
env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if provider_id == "kimi-coding":
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
elif env_url:
base_url = env_url
else:
base_url = pconfig.inference_base_url
return {
"configured": bool(api_key),
"provider": provider_id,
"name": pconfig.name,
"key_source": key_source,
"base_url": base_url,
"logged_in": bool(api_key), # compat with OAuth status shape
}
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
"""Generic auth status dispatcher."""
target = provider_id or get_active_provider()
@@ -1237,9 +1403,54 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
return get_nous_auth_status()
if target == "openai-codex":
return get_codex_auth_status()
# API-key providers
pconfig = PROVIDER_REGISTRY.get(target)
if pconfig and pconfig.auth_type == "api_key":
return get_api_key_provider_status(target)
return {"logged_in": False}
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
"""Resolve API key and base URL for an API-key provider.
Returns dict with: provider, api_key, base_url, source.
"""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
raise AuthError(
f"Provider '{provider_id}' is not an API-key provider.",
provider=provider_id,
code="invalid_provider",
)
api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break
env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if provider_id == "kimi-coding":
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
elif env_url:
base_url = env_url.rstrip("/")
else:
base_url = pconfig.inference_base_url
return {
"provider": provider_id,
"api_key": api_key,
"base_url": base_url.rstrip("/"),
"source": key_source or "default",
}
# =============================================================================
# External credential detection
# =============================================================================

View File

@@ -1,10 +1,15 @@
"""Welcome banner, ASCII art, and skills summary for the CLI.
"""Welcome banner, ASCII art, skills summary, and update check for the CLI.
Pure display functions with no HermesCLI state dependency.
"""
import json
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional
from rich.console import Console
from rich.panel import Panel
@@ -13,6 +18,8 @@ from rich.table import Table
from prompt_toolkit import print_formatted_text as _pt_print
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
logger = logging.getLogger(__name__)
# =========================================================================
# ANSI building blocks for conversation display
@@ -95,15 +102,93 @@ def get_available_skills() -> Dict[str, List[str]]:
return skills_by_category
# =========================================================================
# Update check
# =========================================================================
# Cache update check results for 6 hours to avoid repeated git fetches
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
def check_for_updates() -> Optional[int]:
"""Check how many commits behind origin/main the local repo is.
Does a ``git fetch`` at most once every 6 hours (cached to
``~/.hermes/.update_check``). Returns the number of commits behind,
or ``None`` if the check fails or isn't applicable.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
repo_dir = hermes_home / "hermes-agent"
cache_file = hermes_home / ".update_check"
# Must be a git repo
if not (repo_dir / ".git").exists():
return None
# Read cache
now = time.time()
try:
if cache_file.exists():
cached = json.loads(cache_file.read_text())
if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS:
return cached.get("behind")
except Exception:
pass
# Fetch latest refs (fast — only downloads ref metadata, no files)
try:
subprocess.run(
["git", "fetch", "origin", "--quiet"],
capture_output=True, timeout=10,
cwd=str(repo_dir),
)
except Exception:
pass # Offline or timeout — use stale refs, that's fine
# Count commits behind
try:
result = subprocess.run(
["git", "rev-list", "--count", "HEAD..origin/main"],
capture_output=True, text=True, timeout=5,
cwd=str(repo_dir),
)
if result.returncode == 0:
behind = int(result.stdout.strip())
else:
behind = None
except Exception:
behind = None
# Write cache
try:
cache_file.write_text(json.dumps({"ts": now, "behind": behind}))
except Exception:
pass
return behind
# =========================================================================
# Welcome banner
# =========================================================================
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
val = tokens / 1_000_000
return f"{val:g}M"
elif tokens >= 1_000:
val = tokens / 1_000
return f"{val:g}K"
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str,
tools: List[dict] = None,
enabled_toolsets: List[str] = None,
session_id: str = None,
get_toolset_for_tool=None):
get_toolset_for_tool=None,
context_length: int = None):
"""Build and print a welcome banner with caduceus on left and info on right.
Args:
@@ -114,6 +199,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
enabled_toolsets: List of enabled toolset names.
session_id: Session identifier.
get_toolset_for_tool: Callable to map tool name -> toolset name.
context_length: Model's context window size in tokens.
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
if get_toolset_for_tool is None:
@@ -135,7 +221,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
model_short = model.split("/")[-1] if "/" in model else model
if len(model_short) > 28:
model_short = model_short[:25] + "..."
left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
if session_id:
left_lines.append(f"[dim #8B8682]Session: {session_id}[/]")
@@ -245,6 +332,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
summary_parts.append("/help for commands")
right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")
# Update check — show if behind origin/main
try:
behind = check_for_updates()
if behind and behind > 0:
commits_word = "commit" if behind == 1 else "commits"
right_lines.append(
f"[bold yellow]⚠ {behind} {commits_word} behind[/]"
f"[dim yellow] — run [bold]hermes update[/bold] to update[/]"
)
except Exception:
pass # Never break the banner over an update check
right_content = "\n".join(right_lines)
layout_table.add_row(left_content, right_content)

352
hermes_cli/clipboard.py Normal file
View File

@@ -0,0 +1,352 @@
"""Clipboard image extraction for macOS, Linux, and WSL2.
Provides a single function `save_clipboard_image(dest)` that checks the
system clipboard for image data, saves it to *dest* as PNG, and returns
True on success. No external Python dependencies — uses only OS-level
CLI tools that ship with the platform (or are commonly installed).
Platform support:
macOS — osascript (always available), pngpaste (if installed)
WSL2 — powershell.exe via .NET System.Windows.Forms.Clipboard
Linux — wl-paste (Wayland), xclip (X11)
"""
import base64
import logging
import os
import subprocess
import sys
from pathlib import Path
logger = logging.getLogger(__name__)
# Cache WSL detection (checked once per process)
_wsl_detected: bool | None = None
def save_clipboard_image(dest: Path) -> bool:
"""Extract an image from the system clipboard and save it as PNG.
Returns True if an image was found and saved, False otherwise.
"""
dest.parent.mkdir(parents=True, exist_ok=True)
if sys.platform == "darwin":
return _macos_save(dest)
return _linux_save(dest)
def has_clipboard_image() -> bool:
"""Quick check: does the clipboard currently contain an image?
Lighter than save_clipboard_image — doesn't extract or write anything.
"""
if sys.platform == "darwin":
return _macos_has_image()
if _is_wsl():
return _wsl_has_image()
if os.environ.get("WAYLAND_DISPLAY"):
return _wayland_has_image()
return _xclip_has_image()
# ── macOS ────────────────────────────────────────────────────────────────
def _macos_save(dest: Path) -> bool:
"""Try pngpaste first (fast, handles more formats), fall back to osascript."""
return _macos_pngpaste(dest) or _macos_osascript(dest)
def _macos_has_image() -> bool:
"""Check if macOS clipboard contains image data."""
try:
info = subprocess.run(
["osascript", "-e", "clipboard info"],
capture_output=True, text=True, timeout=3,
)
return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
except Exception:
return False
def _macos_pngpaste(dest: Path) -> bool:
"""Use pngpaste (brew install pngpaste) — fastest, cleanest."""
try:
r = subprocess.run(
["pngpaste", str(dest)],
capture_output=True, timeout=3,
)
if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
return True
except FileNotFoundError:
pass # pngpaste not installed
except Exception as e:
logger.debug("pngpaste failed: %s", e)
return False
def _macos_osascript(dest: Path) -> bool:
"""Use osascript to extract PNG data from clipboard (always available)."""
if not _macos_has_image():
return False
# Extract as PNG
script = (
'try\n'
' set imgData to the clipboard as «class PNGf»\n'
f' set f to open for access POSIX file "{dest}" with write permission\n'
' write imgData to f\n'
' close access f\n'
'on error\n'
' return "fail"\n'
'end try\n'
)
try:
r = subprocess.run(
["osascript", "-e", script],
capture_output=True, text=True, timeout=5,
)
if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
return True
except Exception as e:
logger.debug("osascript clipboard extract failed: %s", e)
return False
# ── Linux ────────────────────────────────────────────────────────────────
def _is_wsl() -> bool:
"""Detect if running inside WSL (1 or 2)."""
global _wsl_detected
if _wsl_detected is not None:
return _wsl_detected
try:
with open("/proc/version", "r") as f:
_wsl_detected = "microsoft" in f.read().lower()
except Exception:
_wsl_detected = False
return _wsl_detected
def _linux_save(dest: Path) -> bool:
"""Try clipboard backends in priority order: WSL → Wayland → X11."""
if _is_wsl():
if _wsl_save(dest):
return True
# Fall through — WSLg might have wl-paste or xclip working
if os.environ.get("WAYLAND_DISPLAY"):
if _wayland_save(dest):
return True
return _xclip_save(dest)
# ── WSL2 (powershell.exe) ────────────────────────────────────────────────
# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
_PS_CHECK_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
"[System.Windows.Forms.Clipboard]::ContainsImage()"
)
_PS_EXTRACT_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
"Add-Type -AssemblyName System.Drawing;"
"$img = [System.Windows.Forms.Clipboard]::GetImage();"
"if ($null -eq $img) { exit 1 }"
"$ms = New-Object System.IO.MemoryStream;"
"$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png);"
"[System.Convert]::ToBase64String($ms.ToArray())"
)
def _wsl_has_image() -> bool:
"""Check if Windows clipboard has an image (via powershell.exe)."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_CHECK_IMAGE],
capture_output=True, text=True, timeout=8,
)
return r.returncode == 0 and "True" in r.stdout
except FileNotFoundError:
logger.debug("powershell.exe not found — WSL clipboard unavailable")
except Exception as e:
logger.debug("WSL clipboard check failed: %s", e)
return False
def _wsl_save(dest: Path) -> bool:
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_EXTRACT_IMAGE],
capture_output=True, text=True, timeout=15,
)
if r.returncode != 0:
return False
b64_data = r.stdout.strip()
if not b64_data:
return False
png_bytes = base64.b64decode(b64_data)
dest.write_bytes(png_bytes)
return dest.exists() and dest.stat().st_size > 0
except FileNotFoundError:
logger.debug("powershell.exe not found — WSL clipboard unavailable")
except Exception as e:
logger.debug("WSL clipboard extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
def _wayland_has_image() -> bool:
"""Check if Wayland clipboard has image content."""
try:
r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
)
return r.returncode == 0 and any(
t.startswith("image/") for t in r.stdout.splitlines()
)
except FileNotFoundError:
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
except Exception:
pass
return False
def _wayland_save(dest: Path) -> bool:
"""Use wl-paste to extract clipboard image (Wayland sessions)."""
try:
# Check available MIME types
types_r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
)
if types_r.returncode != 0:
return False
types = types_r.stdout.splitlines()
# Prefer PNG, fall back to other image formats
mime = None
for preferred in ("image/png", "image/jpeg", "image/bmp",
"image/gif", "image/webp"):
if preferred in types:
mime = preferred
break
if not mime:
return False
# Extract the image data
with open(dest, "wb") as f:
subprocess.run(
["wl-paste", "--type", mime],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
)
if not dest.exists() or dest.stat().st_size == 0:
return False
# BMP needs conversion to PNG (common in WSLg where only BMP
# is bridged from Windows clipboard via RDP).
if mime == "image/bmp":
return _convert_to_png(dest)
return True
except FileNotFoundError:
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
except Exception as e:
logger.debug("wl-paste clipboard extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False
def _convert_to_png(path: Path) -> bool:
"""Convert an image file to PNG in-place (requires Pillow or ImageMagick)."""
# Try Pillow first (likely installed in the venv)
try:
from PIL import Image
img = Image.open(path)
img.save(path, "PNG")
return True
except ImportError:
pass
except Exception as e:
logger.debug("Pillow BMP→PNG conversion failed: %s", e)
# Fall back to ImageMagick convert
try:
tmp = path.with_suffix(".bmp")
path.rename(tmp)
r = subprocess.run(
["convert", str(tmp), "png:" + str(path)],
capture_output=True, timeout=5,
)
tmp.unlink(missing_ok=True)
if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
return True
except FileNotFoundError:
logger.debug("ImageMagick not installed — cannot convert BMP to PNG")
except Exception as e:
logger.debug("ImageMagick BMP→PNG conversion failed: %s", e)
# Can't convert — BMP is still usable as-is for most APIs
return path.exists() and path.stat().st_size > 0
# ── X11 (xclip) ─────────────────────────────────────────────────────────
def _xclip_has_image() -> bool:
"""Check if X11 clipboard has image content."""
try:
r = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
)
return r.returncode == 0 and "image/png" in r.stdout
except FileNotFoundError:
pass
except Exception:
pass
return False
def _xclip_save(dest: Path) -> bool:
"""Use xclip to extract clipboard image (X11 sessions)."""
# Check if clipboard has image content
try:
targets = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
)
if "image/png" not in targets.stdout:
return False
except FileNotFoundError:
logger.debug("xclip not installed — X11 clipboard image paste unavailable")
return False
except Exception:
return False
# Extract PNG data
try:
with open(dest, "wb") as f:
subprocess.run(
["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
)
if dest.exists() and dest.stat().st_size > 0:
return True
except Exception as e:
logger.debug("xclip image extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False

View File

@@ -1,9 +1,15 @@
"""Slash command definitions and autocomplete for the Hermes CLI.
Contains the COMMANDS dict and the SlashCommandCompleter class.
These are pure data/UI with no HermesCLI state dependency.
Contains the shared built-in ``COMMANDS`` dict and ``SlashCommandCompleter``.
The completer can optionally include dynamic skill slash commands supplied by the
interactive CLI.
"""
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any
from prompt_toolkit.completion import Completer, Completion
@@ -12,6 +18,7 @@ COMMANDS = {
"/tools": "List available tools",
"/toolsets": "List available toolsets",
"/model": "Show or change the current model",
"/provider": "Show available providers and current provider",
"/prompt": "View/set custom system prompt",
"/personality": "Set a predefined personality",
"/clear": "Clear screen and reset conversation (fresh start)",
@@ -27,25 +34,68 @@ COMMANDS = {
"/platforms": "Show gateway/messaging platform status",
"/verbose": "Cycle tool progress display: off → new → all → verbose",
"/compress": "Manually compress conversation context (flush memories + summarize)",
"/title": "Set a title for the current session (usage: /title My Session Name)",
"/usage": "Show token usage for the current session",
"/insights": "Show usage insights and analytics (last 30 days)",
"/paste": "Check clipboard for an image and attach it",
"/reload-mcp": "Reload MCP servers from config.yaml",
"/quit": "Exit the CLI (also: /exit, /q)",
}
class SlashCommandCompleter(Completer):
"""Autocomplete for /commands in the input area."""
"""Autocomplete for built-in slash commands and optional skill commands."""
def __init__(
self,
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
) -> None:
self._skill_commands_provider = skill_commands_provider
def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
if self._skill_commands_provider is None:
return {}
try:
return self._skill_commands_provider() or {}
except Exception:
return {}
@staticmethod
def _completion_text(cmd_name: str, word: str) -> str:
"""Return replacement text for a completion.
When the user has already typed the full command exactly (``/help``),
returning ``help`` would be a no-op and prompt_toolkit suppresses the
menu. Appending a trailing space keeps the dropdown visible and makes
backspacing retrigger it naturally.
"""
return f"{cmd_name} " if cmd_name == word else cmd_name
def get_completions(self, document, complete_event):
text = document.text_before_cursor
if not text.startswith("/"):
return
word = text[1:]
for cmd, desc in COMMANDS.items():
cmd_name = cmd[1:]
if cmd_name.startswith(word):
yield Completion(
cmd_name,
self._completion_text(cmd_name, word),
start_position=-len(word),
display=cmd,
display_meta=desc,
)
for cmd, info in self._iter_skill_commands().items():
cmd_name = cmd[1:]
if cmd_name.startswith(word):
description = str(info.get("description", "Skill command"))
short_desc = description[:50] + ("..." if len(description) > 50 else "")
yield Completion(
self._completion_text(cmd_name, word),
start_position=-len(word),
display=cmd,
display_meta=f"{short_desc}",
)

View File

@@ -71,7 +71,8 @@ DEFAULT_CONFIG = {
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
# Container resource limits (docker, singularity, modal — ignored for local/ssh)
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
# Container resource limits (docker, singularity, modal, daytona — ignored for local/ssh)
"container_cpu": 1,
"container_memory": 5120, # MB (default 5GB)
"container_disk": 51200, # MB (default 50GB)
@@ -140,9 +141,13 @@ DEFAULT_CONFIG = {
# (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
"honcho": {},
# IANA timezone (e.g. "Asia/Kolkata", "America/New_York").
# Empty string means use server-local time.
"timezone": "",
# Permanently allowed dangerous command patterns (added via "always" approval)
"command_allowlist": [],
# Config schema version - bump this when adding new required fields
"_config_version": 5,
}
@@ -151,6 +156,15 @@ DEFAULT_CONFIG = {
# Config Migration System
# =============================================================================
# Track which env vars were introduced in each config version.
# Migration only mentions vars new since the user's previous version.
ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
3: ["FIRECRAWL_API_KEY", "BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "FAL_KEY"],
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
}
# Required environment variables with metadata for migration prompts.
# LLM provider is required but handled in the setup wizard's provider
# selection step (Nous Portal / OpenRouter / Custom endpoint), so this
@@ -169,6 +183,86 @@ OPTIONAL_ENV_VARS = {
"category": "provider",
"advanced": True,
},
"GLM_API_KEY": {
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
"prompt": "Z.AI / GLM API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"ZAI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"Z_AI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"GLM_BASE_URL": {
"description": "Z.AI / GLM base URL override",
"prompt": "Z.AI / GLM base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"KIMI_API_KEY": {
"description": "Kimi / Moonshot API key",
"prompt": "Kimi API key",
"url": "https://platform.moonshot.cn/",
"password": True,
"category": "provider",
"advanced": True,
},
"KIMI_BASE_URL": {
"description": "Kimi / Moonshot base URL override",
"prompt": "Kimi base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_API_KEY": {
"description": "MiniMax API key (international)",
"prompt": "MiniMax API key",
"url": "https://www.minimax.io/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_BASE_URL": {
"description": "MiniMax base URL override",
"prompt": "MiniMax base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_API_KEY": {
"description": "MiniMax API key (China endpoint)",
"prompt": "MiniMax (China) API key",
"url": "https://www.minimaxi.com/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_BASE_URL": {
"description": "MiniMax (China) base URL override",
"prompt": "MiniMax (China) base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
# ── Tool API keys ──
"FIRECRAWL_API_KEY": {
@@ -179,8 +273,16 @@ OPTIONAL_ENV_VARS = {
"password": True,
"category": "tool",
},
"FIRECRAWL_API_URL": {
"description": "Firecrawl API URL for self-hosted instances (optional)",
"prompt": "Firecrawl API URL (leave empty for cloud)",
"url": None,
"password": False,
"category": "tool",
"advanced": True,
},
"BROWSERBASE_API_KEY": {
"description": "Browserbase API key for browser automation",
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
"prompt": "Browserbase API key",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -188,7 +290,7 @@ OPTIONAL_ENV_VARS = {
"category": "tool",
},
"BROWSERBASE_PROJECT_ID": {
"description": "Browserbase project ID",
"description": "Browserbase project ID (optional — only needed for cloud browser)",
"prompt": "Browserbase project ID",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -476,6 +578,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
if not quiet:
print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}")
# ── Version 4 → 5: add timezone field ──
if current_ver < 5:
config = load_config()
if "timezone" not in config:
old_tz = os.getenv("HERMES_TIMEZONE", "")
if old_tz and old_tz.strip():
config["timezone"] = old_tz.strip()
results["config_added"].append(f"timezone={old_tz.strip()} (from HERMES_TIMEZONE)")
else:
config["timezone"] = ""
results["config_added"].append("timezone= (empty, uses server-local)")
save_config(config)
if not quiet:
tz_display = config["timezone"] or "(server-local)"
print(f" ✓ Added timezone to config.yaml: {tz_display}")
if current_ver < latest_ver and not quiet:
print(f"Config version: {current_ver}{latest_ver}")
@@ -516,34 +634,47 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
if v["name"] not in required_names and not v.get("advanced")
]
if interactive and missing_optional:
print(" Would you like to configure any optional keys now?")
try:
answer = input(" Configure optional keys? [y/N]: ").strip().lower()
except (EOFError, KeyboardInterrupt):
answer = "n"
if answer in ("y", "yes"):
# Only offer to configure env vars that are NEW since the user's previous version
new_var_names = set()
for ver in range(current_ver + 1, latest_ver + 1):
new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))
if new_var_names and interactive and not quiet:
new_and_unset = [
(name, OPTIONAL_ENV_VARS[name])
for name in sorted(new_var_names)
if not get_env_value(name) and name in OPTIONAL_ENV_VARS
]
if new_and_unset:
print(f"\n {len(new_and_unset)} new optional key(s) in this update:")
for name, info in new_and_unset:
print(f"{name}{info.get('description', '')}")
print()
for var in missing_optional:
desc = var.get("description", "")
if var.get("url"):
print(f" {desc}")
print(f" Get your key at: {var['url']}")
else:
print(f" {desc}")
if var.get("password"):
import getpass
value = getpass.getpass(f" {var['prompt']} (Enter to skip): ")
else:
value = input(f" {var['prompt']} (Enter to skip): ").strip()
if value:
save_env_value(var["name"], value)
results["env_added"].append(var["name"])
print(f" ✓ Saved {var['name']}")
try:
answer = input(" Configure new keys? [y/N]: ").strip().lower()
except (EOFError, KeyboardInterrupt):
answer = "n"
if answer in ("y", "yes"):
print()
for name, info in new_and_unset:
if info.get("url"):
print(f" {info.get('description', name)}")
print(f" Get your key at: {info['url']}")
else:
print(f" {info.get('description', name)}")
if info.get("password"):
import getpass
value = getpass.getpass(f" {info.get('prompt', name)} (Enter to skip): ")
else:
value = input(f" {info.get('prompt', name)} (Enter to skip): ").strip()
if value:
save_env_value(name, value)
results["env_added"].append(name)
print(f" ✓ Saved {name}")
print()
else:
print(" Set later with: hermes config set KEY VALUE")
# Check for missing config fields
missing_config = get_missing_config_fields()
@@ -753,12 +884,25 @@ def show_config():
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
modal_token = get_env_value('MODAL_TOKEN_ID')
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
elif terminal.get('backend') == 'daytona':
print(f" Daytona image: {terminal.get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
daytona_key = get_env_value('DAYTONA_API_KEY')
print(f" API key: {'configured' if daytona_key else '(not set)'}")
elif terminal.get('backend') == 'ssh':
ssh_host = get_env_value('TERMINAL_SSH_HOST')
ssh_user = get_env_value('TERMINAL_SSH_USER')
print(f" SSH host: {ssh_host or '(not set)'}")
print(f" SSH user: {ssh_user or '(not set)'}")
# Timezone
print()
print(color("◆ Timezone", Colors.CYAN, Colors.BOLD))
tz = config.get('timezone', '')
if tz:
print(f" Timezone: {tz}")
else:
print(f" Timezone: {color('(server-local)', Colors.DIM)}")
# Compression
print()
print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
@@ -820,15 +964,16 @@ def set_config_value(key: str, value: str):
"""Set a configuration value."""
# Check if it's an API key (goes to .env)
api_keys = [
'OPENROUTER_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
'FIRECRAWL_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
'GITHUB_TOKEN', 'HONCHO_API_KEY',
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'NOUS_API_KEY', 'WANDB_API_KEY',
'TINKER_API_KEY',
]
if key.upper() in api_keys or key.upper().startswith('TERMINAL_SSH'):
if key.upper() in api_keys or key.upper().endswith('_API_KEY') or key.upper().endswith('_TOKEN') or key.upper().startswith('TERMINAL_SSH'):
save_env_value(key.upper(), value)
print(f"✓ Set {key} in {get_env_path()}")
return
@@ -878,8 +1023,10 @@ def set_config_value(key: str, value: str):
"terminal.docker_image": "TERMINAL_DOCKER_IMAGE",
"terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE",
"terminal.modal_image": "TERMINAL_MODAL_IMAGE",
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
"terminal.cwd": "TERMINAL_CWD",
"terminal.timeout": "TERMINAL_TIMEOUT",
"terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
}
if key in _config_to_env_sync:
save_env_value(_config_to_env_sync[key], str(value))

View File

@@ -33,6 +33,26 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
from hermes_cli.colors import Colors, color
from hermes_constants import OPENROUTER_MODELS_URL
_PROVIDER_ENV_HINTS = (
"OPENROUTER_API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"OPENAI_BASE_URL",
"GLM_API_KEY",
"ZAI_API_KEY",
"Z_AI_API_KEY",
"KIMI_API_KEY",
"MINIMAX_API_KEY",
"MINIMAX_CN_API_KEY",
)
def _has_provider_env_config(content: str) -> bool:
"""Return True when ~/.hermes/.env contains provider auth/base URL settings."""
return any(key in content for key in _PROVIDER_ENV_HINTS)
def check_ok(text: str, detail: str = ""):
print(f" {color('', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
@@ -132,8 +152,8 @@ def run_doctor(args):
# Check for common issues
content = env_path.read_text()
if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
check_ok("API key configured")
if _has_provider_env_config(content):
check_ok("API key or custom endpoint configured")
else:
check_warn("No API key found in ~/.hermes/.env")
issues.append("Run 'hermes setup' to configure API keys")
@@ -355,6 +375,21 @@ def run_doctor(args):
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
issues.append("Set TERMINAL_SSH_HOST in .env")
# Daytona (if using daytona backend)
if terminal_env == "daytona":
daytona_key = os.getenv("DAYTONA_API_KEY")
if daytona_key:
check_ok("Daytona API key", "(configured)")
else:
check_fail("DAYTONA_API_KEY not set", "(required for TERMINAL_ENV=daytona)")
issues.append("Set DAYTONA_API_KEY environment variable")
try:
from daytona import Daytona
check_ok("daytona SDK", "(installed)")
except ImportError:
check_fail("daytona SDK not installed", "(pip install daytona)")
issues.append("Install daytona SDK: pip install daytona")
# Node.js + agent-browser (for browser automation tools)
if shutil.which("node"):
check_ok("Node.js")
@@ -453,7 +488,48 @@ def run_doctor(args):
print(f"\r {color('', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
except Exception as e:
print(f"\r {color('', Colors.YELLOW)} Anthropic API {color(f'({e})', Colors.DIM)} ")
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
_apikey_providers = [
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
]
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
_key = ""
for _ev in _env_vars:
_key = os.getenv(_ev, "")
if _key:
break
if _key:
_label = _pname.ljust(20)
print(f" Checking {_pname} API...", end="", flush=True)
try:
import httpx
_base = os.getenv(_base_env, "")
# Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
if not _base and _key.startswith("sk-kimi-"):
_base = "https://api.kimi.com/coding/v1"
_url = (_base.rstrip("/") + "/models") if _base else _default_url
_headers = {"Authorization": f"Bearer {_key}"}
if "api.kimi.com" in _url.lower():
_headers["User-Agent"] = "KimiCLI/1.0"
_resp = httpx.get(
_url,
headers=_headers,
timeout=10,
)
if _resp.status_code == 200:
print(f"\r {color('', Colors.GREEN)} {_label} ")
elif _resp.status_code == 401:
print(f"\r {color('', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
issues.append(f"Check {_env_vars[0]} in .env")
else:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
except Exception as _e:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
# =========================================================================
# Check: Submodules
# =========================================================================

View File

@@ -154,19 +154,33 @@ def get_hermes_cli_path() -> str:
# =============================================================================
def generate_systemd_unit() -> str:
import shutil
python_path = get_python_path()
working_dir = str(PROJECT_ROOT)
venv_dir = str(PROJECT_ROOT / "venv")
venv_bin = str(PROJECT_ROOT / "venv" / "bin")
node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")
# Build a PATH that includes the venv, node_modules, and standard system dirs
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
return f"""[Unit]
Description={SERVICE_DESCRIPTION}
After=network.target
[Service]
Type=simple
ExecStart={python_path} -m hermes_cli.main gateway run
ExecStart={python_path} -m hermes_cli.main gateway run --replace
ExecStop={hermes_cli} gateway stop
WorkingDirectory={working_dir}
Environment="PATH={sane_path}"
Environment="VIRTUAL_ENV={venv_dir}"
Restart=on-failure
RestartSec=10
KillMode=mixed
KillSignal=SIGTERM
TimeoutStopSec=15
StandardOutput=journal
StandardError=journal
@@ -377,8 +391,15 @@ def launchd_status(deep: bool = False):
# Gateway Runner
# =============================================================================
def run_gateway(verbose: bool = False):
"""Run the gateway in foreground."""
def run_gateway(verbose: bool = False, replace: bool = False):
"""Run the gateway in foreground.
Args:
verbose: Enable verbose logging output.
replace: If True, kill any existing gateway instance before starting.
This prevents systemd restart loops when the old process
hasn't fully exited yet.
"""
sys.path.insert(0, str(PROJECT_ROOT))
from gateway.run import start_gateway
@@ -393,7 +414,7 @@ def run_gateway(verbose: bool = False):
# Exit with code 1 if gateway fails to connect any platform,
# so systemd Restart=on-failure will retry on transient errors
success = asyncio.run(start_gateway())
success = asyncio.run(start_gateway(replace=replace))
if not success:
sys.exit(1)
@@ -765,7 +786,8 @@ def gateway_command(args):
# Default to run if no subcommand
if subcmd is None or subcmd == "run":
verbose = getattr(args, 'verbose', False)
run_gateway(verbose)
replace = getattr(args, 'replace', False)
run_gateway(verbose, replace=replace)
return
if subcmd == "setup":

View File

@@ -64,7 +64,13 @@ def _has_any_provider_configured() -> bool:
# Check env vars (may be set by .env or shell).
# OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.)
# often don't require an API key.
provider_env_vars = ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL")
from hermes_cli.auth import PROVIDER_REGISTRY
# Collect all provider env vars
provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"}
for pconfig in PROVIDER_REGISTRY.values():
if pconfig.auth_type == "api_key":
provider_env_vars.update(pconfig.api_key_env_vars)
if any(os.getenv(v) for v in provider_env_vars):
return True
@@ -114,16 +120,63 @@ def _resolve_last_cli_session() -> Optional[str]:
return None
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
"""Resolve a session name (title) or ID to a session ID.
- If it looks like a session ID (contains underscore + hex), try direct lookup first.
- Otherwise, treat it as a title and use resolve_session_by_title (auto-latest).
- Falls back to the other method if the first doesn't match.
"""
try:
from hermes_state import SessionDB
db = SessionDB()
# Try as exact session ID first
session = db.get_session(name_or_id)
if session:
db.close()
return session["id"]
# Try as title (with auto-latest for lineage)
session_id = db.resolve_session_by_title(name_or_id)
db.close()
return session_id
except Exception:
pass
return None
def cmd_chat(args):
"""Run interactive chat CLI."""
# Resolve --continue into --resume with the latest CLI session
if getattr(args, "continue_last", False) and not getattr(args, "resume", None):
last_id = _resolve_last_cli_session()
if last_id:
args.resume = last_id
# Resolve --continue into --resume with the latest CLI session or by name
continue_val = getattr(args, "continue_last", None)
if continue_val and not getattr(args, "resume", None):
if isinstance(continue_val, str):
# -c "session name" — resolve by title or ID
resolved = _resolve_session_by_name_or_id(continue_val)
if resolved:
args.resume = resolved
else:
print(f"No session found matching '{continue_val}'.")
print("Use 'hermes sessions list' to see available sessions.")
sys.exit(1)
else:
print("No previous CLI session found to continue.")
sys.exit(1)
# -c with no argument — continue the most recent session
last_id = _resolve_last_cli_session()
if last_id:
args.resume = last_id
else:
print("No previous CLI session found to continue.")
sys.exit(1)
# Resolve --resume by title if it's not a direct session ID
resume_val = getattr(args, "resume", None)
if resume_val:
resolved = _resolve_session_by_name_or_id(resume_val)
if resolved:
args.resume = resolved
# If resolution fails, keep the original value — _init_agent will
# report "Session not found" with the original input
# First-run guard: check if any provider is configured before launching
if not _has_any_provider_configured():
@@ -143,6 +196,13 @@ def cmd_chat(args):
print("You can run 'hermes setup' at any time to configure.")
sys.exit(1)
# Sync bundled skills on every CLI launch (fast -- skips unchanged skills)
try:
from tools.skills_sync import sync_skills
sync_skills(quiet=True)
except Exception:
pass
# Import and run the CLI
from cli import main as cli_main
@@ -154,6 +214,7 @@ def cmd_chat(args):
"verbose": args.verbose,
"query": args.query,
"resume": getattr(args, "resume", None),
"worktree": getattr(args, "worktree", False),
}
# Filter out None values
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@@ -404,6 +465,10 @@ def cmd_model(args):
"openrouter": "OpenRouter",
"nous": "Nous Portal",
"openai-codex": "OpenAI Codex",
"zai": "Z.AI / GLM",
"kimi-coding": "Kimi / Moonshot",
"minimax": "MiniMax",
"minimax-cn": "MiniMax (China)",
"custom": "Custom endpoint",
}
active_label = provider_labels.get(active, active)
@@ -418,11 +483,16 @@ def cmd_model(args):
("openrouter", "OpenRouter (100+ models, pay-per-use)"),
("nous", "Nous Portal (Nous Research subscription)"),
("openai-codex", "OpenAI Codex"),
("zai", "Z.AI / GLM (Zhipu AI direct API)"),
("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
("minimax", "MiniMax (global direct API)"),
("minimax-cn", "MiniMax China (domestic direct API)"),
("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
]
# Reorder so the active provider is at the top
active_key = active if active in ("openrouter", "nous", "openai-codex") else "custom"
known_keys = {k for k, _ in providers}
active_key = active if active in known_keys else "custom"
ordered = []
for key, label in providers:
if key == active_key:
@@ -447,6 +517,8 @@ def cmd_model(args):
_model_flow_openai_codex(config, current_model)
elif selected_provider == "custom":
_model_flow_custom(config)
elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
_model_flow_api_key_provider(config, selected_provider, current_model)
def _prompt_provider_choice(choices):
@@ -716,6 +788,117 @@ def _model_flow_custom(config):
print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")
# Curated model lists for direct API-key providers
_PROVIDER_MODELS = {
"zai": [
"glm-5",
"glm-4.7",
"glm-4.5",
"glm-4.5-flash",
],
"kimi-coding": [
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-turbo-preview",
"kimi-k2-0905-preview",
],
"minimax": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
"minimax-cn": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
}
def _model_flow_api_key_provider(config, provider_id, current_model=""):
"""Generic flow for API-key providers (z.ai, Kimi, MiniMax)."""
from hermes_cli.auth import (
PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
_update_config_for_provider, deactivate_provider,
)
from hermes_cli.config import get_env_value, save_env_value, load_config, save_config
pconfig = PROVIDER_REGISTRY[provider_id]
key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else ""
base_url_env = pconfig.base_url_env_var or ""
# Check / prompt for API key
existing_key = ""
for ev in pconfig.api_key_env_vars:
existing_key = get_env_value(ev) or os.getenv(ev, "")
if existing_key:
break
if not existing_key:
print(f"No {pconfig.name} API key configured.")
if key_env:
try:
new_key = input(f"{key_env} (or Enter to cancel): ").strip()
except (KeyboardInterrupt, EOFError):
print()
return
if not new_key:
print("Cancelled.")
return
save_env_value(key_env, new_key)
print("API key saved.")
print()
else:
print(f" {pconfig.name} API key: {existing_key[:8]}... ✓")
print()
# Optional base URL override
current_base = ""
if base_url_env:
current_base = get_env_value(base_url_env) or os.getenv(base_url_env, "")
effective_base = current_base or pconfig.inference_base_url
try:
override = input(f"Base URL [{effective_base}]: ").strip()
except (KeyboardInterrupt, EOFError):
print()
override = ""
if override and base_url_env:
save_env_value(base_url_env, override)
effective_base = override
# Model selection
model_list = _PROVIDER_MODELS.get(provider_id, [])
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
else:
try:
selected = input("Model name: ").strip()
except (KeyboardInterrupt, EOFError):
selected = None
if selected:
# Clear custom endpoint if set (avoid confusion)
if get_env_value("OPENAI_BASE_URL"):
save_env_value("OPENAI_BASE_URL", "")
save_env_value("OPENAI_API_KEY", "")
_save_model_choice(selected)
# Update config with provider and base URL
cfg = load_config()
model = cfg.get("model")
if isinstance(model, dict):
model["provider"] = provider_id
model["base_url"] = effective_base
save_config(cfg)
deactivate_provider()
print(f"Default model set to: {selected} (via {pconfig.name})")
else:
print("No change.")
def cmd_login(args):
"""Authenticate Hermes CLI with a provider."""
from hermes_cli.auth import login_command
@@ -851,11 +1034,17 @@ def _update_via_zip(args):
# Sync skills
try:
from tools.skills_sync import sync_skills
print("Checking for new bundled skills...")
print("Syncing bundled skills...")
result = sync_skills(quiet=True)
if result["copied"]:
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
else:
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
if result.get("updated"):
print(f"{len(result['updated'])} updated: {', '.join(result['updated'])}")
if result.get("user_modified"):
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
if result.get("cleaned"):
print(f" {len(result['cleaned'])} removed from manifest")
if not result["copied"] and not result.get("updated"):
print(" ✓ Skills are up to date")
except Exception:
pass
@@ -961,15 +1150,21 @@ def cmd_update(args):
print()
print("✓ Code updated!")
# Sync any new bundled skills (manifest-based -- won't overwrite or re-add deleted skills)
# Sync bundled skills (copies new, updates changed, respects user deletions)
try:
from tools.skills_sync import sync_skills
print()
print("Checking for new bundled skills...")
print("Syncing bundled skills...")
result = sync_skills(quiet=True)
if result["copied"]:
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
else:
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
if result.get("updated"):
print(f"{len(result['updated'])} updated: {', '.join(result['updated'])}")
if result.get("user_modified"):
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
if result.get("cleaned"):
print(f" {len(result['cleaned'])} removed from manifest")
if not result["copied"] and not result.get("updated"):
print(" ✓ Skills are up to date")
except Exception as e:
logger.debug("Skills sync during update failed: %s", e)
@@ -1061,8 +1256,9 @@ def main():
Examples:
hermes Start interactive chat
hermes chat -q "Hello" Single query mode
hermes --continue Resume the most recent session
hermes --resume <session_id> Resume a specific session
hermes -c Resume the most recent session
hermes -c "my project" Resume a session by name (latest in lineage)
hermes --resume <session_id> Resume a specific session by ID
hermes setup Run setup wizard
hermes logout Clear stored authentication
hermes model Select default model
@@ -1070,8 +1266,10 @@ Examples:
hermes config edit Edit config in $EDITOR
hermes config set model gpt-4 Set a config value
hermes gateway Run messaging gateway
hermes -w Start in isolated git worktree
hermes gateway install Install as system service
hermes sessions list List past sessions
hermes sessions rename ID T Rename/title a session
hermes update Update to latest version
For more help on a command:
@@ -1086,16 +1284,24 @@ For more help on a command:
)
parser.add_argument(
"--resume", "-r",
metavar="SESSION_ID",
metavar="SESSION",
default=None,
help="Resume a previous session by ID (shortcut for: hermes chat --resume ID)"
help="Resume a previous session by ID or title"
)
parser.add_argument(
"--continue", "-c",
dest="continue_last",
nargs="?",
const=True,
default=None,
metavar="SESSION_NAME",
help="Resume a session by name, or the most recent if no name given"
)
parser.add_argument(
"--worktree", "-w",
action="store_true",
default=False,
help="Resume the most recent CLI session"
help="Run in an isolated git worktree (for parallel agents)"
)
subparsers = parser.add_subparsers(dest="command", help="Command to run")
@@ -1122,7 +1328,7 @@ For more help on a command:
)
chat_parser.add_argument(
"--provider",
choices=["auto", "openrouter", "nous", "openai-codex"],
choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"],
default=None,
help="Inference provider (default: auto)"
)
@@ -1139,9 +1345,17 @@ For more help on a command:
chat_parser.add_argument(
"--continue", "-c",
dest="continue_last",
nargs="?",
const=True,
default=None,
metavar="SESSION_NAME",
help="Resume a session by name, or the most recent if no name given"
)
chat_parser.add_argument(
"--worktree", "-w",
action="store_true",
default=False,
help="Resume the most recent CLI session"
help="Run in an isolated git worktree (for parallel agents on the same repo)"
)
chat_parser.set_defaults(func=cmd_chat)
@@ -1168,6 +1382,8 @@ For more help on a command:
# gateway run (default)
gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
gateway_run.add_argument("-v", "--verbose", action="store_true")
gateway_run.add_argument("--replace", action="store_true",
help="Replace any existing gateway instance (useful for systemd)")
# gateway start
gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
@@ -1200,7 +1416,15 @@ For more help on a command:
setup_parser = subparsers.add_parser(
"setup",
help="Interactive setup wizard",
description="Configure Hermes Agent with an interactive wizard"
description="Configure Hermes Agent with an interactive wizard. "
"Run a specific section: hermes setup model|terminal|gateway|tools|agent"
)
setup_parser.add_argument(
"section",
nargs="?",
choices=["model", "terminal", "gateway", "tools", "agent"],
default=None,
help="Run a specific setup section instead of the full wizard"
)
setup_parser.add_argument(
"--non-interactive",
@@ -1424,9 +1648,16 @@ For more help on a command:
)
skills_subparsers = skills_parser.add_subparsers(dest="skills_action")
skills_browse = skills_subparsers.add_parser("browse", help="Browse all available skills (paginated)")
skills_browse.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
skills_browse.add_argument("--size", type=int, default=20, help="Results per page (default: 20)")
skills_browse.add_argument("--source", default="all",
choices=["all", "official", "github", "clawhub", "lobehub"],
help="Filter by source (default: all)")
skills_search = skills_subparsers.add_parser("search", help="Search skill registries")
skills_search.add_argument("query", help="Search query")
skills_search.add_argument("--source", default="all", choices=["all", "github", "clawhub", "lobehub"])
skills_search.add_argument("--source", default="all", choices=["all", "official", "github", "clawhub", "lobehub"])
skills_search.add_argument("--limit", type=int, default=10, help="Max results")
skills_install = skills_subparsers.add_parser("install", help="Install a skill")
@@ -1493,7 +1724,7 @@ For more help on a command:
# =========================================================================
sessions_parser = subparsers.add_parser(
"sessions",
help="Manage session history (list, export, prune, delete)",
help="Manage session history (list, rename, export, prune, delete)",
description="View and manage the SQLite session store"
)
sessions_subparsers = sessions_parser.add_subparsers(dest="sessions_action")
@@ -1518,6 +1749,10 @@ For more help on a command:
sessions_stats = sessions_subparsers.add_parser("stats", help="Show session store statistics")
sessions_rename = sessions_subparsers.add_parser("rename", help="Set or change a session's title")
sessions_rename.add_argument("session_id", help="Session ID to rename")
sessions_rename.add_argument("title", nargs="+", help="New title for the session")
def cmd_sessions(args):
import json as _json
try:
@@ -1530,18 +1765,51 @@ For more help on a command:
action = args.sessions_action
if action == "list":
sessions = db.search_sessions(source=args.source, limit=args.limit)
sessions = db.list_sessions_rich(source=args.source, limit=args.limit)
if not sessions:
print("No sessions found.")
return
print(f"{'ID':<30} {'Source':<12} {'Model':<30} {'Messages':>8} {'Started'}")
print("" * 100)
from datetime import datetime
import time as _time
def _relative_time(ts):
"""Format a timestamp as relative time (e.g., '2h ago', 'yesterday')."""
if not ts:
return "?"
delta = _time.time() - ts
if delta < 60:
return "just now"
elif delta < 3600:
mins = int(delta / 60)
return f"{mins}m ago"
elif delta < 86400:
hours = int(delta / 3600)
return f"{hours}h ago"
elif delta < 172800:
return "yesterday"
elif delta < 604800:
days = int(delta / 86400)
return f"{days}d ago"
else:
return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")
has_titles = any(s.get("title") for s in sessions)
if has_titles:
print(f"{'Title':<22} {'Preview':<40} {'Last Active':<13} {'ID'}")
print("" * 100)
else:
print(f"{'Preview':<50} {'Last Active':<13} {'Src':<6} {'ID'}")
print("" * 90)
for s in sessions:
started = datetime.fromtimestamp(s["started_at"]).strftime("%Y-%m-%d %H:%M") if s["started_at"] else "?"
model = (s.get("model") or "?")[:28]
ended = " (ended)" if s.get("ended_at") else ""
print(f"{s['id']:<30} {s['source']:<12} {model:<30} {s['message_count']:>8} {started}{ended}")
last_active = _relative_time(s.get("last_active"))
preview = s.get("preview", "")[:38] if has_titles else s.get("preview", "")[:48]
if has_titles:
title = (s.get("title") or "")[:20]
sid = s["id"][:20]
print(f"{title:<22} {preview:<40} {last_active:<13} {sid}")
else:
sid = s["id"][:20]
print(f"{preview:<50} {last_active:<13} {s['source']:<6} {sid}")
elif action == "export":
if args.session_id:
@@ -1581,6 +1849,16 @@ For more help on a command:
count = db.prune_sessions(older_than_days=days, source=args.source)
print(f"Pruned {count} session(s).")
elif action == "rename":
title = " ".join(args.title)
try:
if db.set_session_title(args.session_id, title):
print(f"Session '{args.session_id}' renamed to: {title}")
else:
print(f"Session '{args.session_id}' not found.")
except ValueError as e:
print(f"Error: {e}")
elif action == "stats":
total = db.session_count()
msgs = db.message_count()
@@ -1603,6 +1881,32 @@ For more help on a command:
sessions_parser.set_defaults(func=cmd_sessions)
# =========================================================================
# insights command
# =========================================================================
insights_parser = subparsers.add_parser(
"insights",
help="Show usage insights and analytics",
description="Analyze session history to show token usage, costs, tool patterns, and activity trends"
)
insights_parser.add_argument("--days", type=int, default=30, help="Number of days to analyze (default: 30)")
insights_parser.add_argument("--source", help="Filter by platform (cli, telegram, discord, etc.)")
def cmd_insights(args):
try:
from hermes_state import SessionDB
from agent.insights import InsightsEngine
db = SessionDB()
engine = InsightsEngine(db)
report = engine.generate(days=args.days, source=args.source)
print(engine.format_terminal(report))
db.close()
except Exception as e:
print(f"Error generating insights: {e}")
insights_parser.set_defaults(func=cmd_insights)
# =========================================================================
# version command
# =========================================================================
@@ -1660,6 +1964,8 @@ For more help on a command:
args.provider = None
args.toolsets = None
args.verbose = False
if not hasattr(args, "worktree"):
args.worktree = False
cmd_chat(args)
return
@@ -1671,7 +1977,9 @@ For more help on a command:
args.toolsets = None
args.verbose = False
args.resume = None
args.continue_last = False
args.continue_last = None
if not hasattr(args, "worktree"):
args.worktree = False
cmd_chat(args)
return

View File

@@ -1,27 +1,85 @@
"""
Canonical list of OpenRouter models offered in CLI and setup wizards.
Canonical model catalogs and lightweight validation helpers.
Add, remove, or reorder entries here — both `hermes setup` and
`hermes` provider-selection will pick up the change automatically.
"""
from __future__ import annotations
import json
import urllib.request
import urllib.error
from difflib import get_close_matches
from typing import Any, Optional
# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
("anthropic/claude-opus-4.6", "recommended"),
("anthropic/claude-sonnet-4.5", ""),
("anthropic/claude-opus-4.5", ""),
("openai/gpt-5.2", ""),
("openai/gpt-5.4-pro", ""),
("openai/gpt-5.4", ""),
("openai/gpt-5.3-codex", ""),
("google/gemini-3-pro-preview", ""),
("google/gemini-3-flash-preview", ""),
("z-ai/glm-4.7", ""),
("qwen/qwen3.5-plus-02-15", ""),
("qwen/qwen3.5-35b-a3b", ""),
("stepfun/step-3.5-flash", ""),
("z-ai/glm-5", ""),
("moonshotai/kimi-k2.5", ""),
("minimax/minimax-m2.1", ""),
("minimax/minimax-m2.5", ""),
]
_PROVIDER_MODELS: dict[str, list[str]] = {
"zai": [
"glm-5",
"glm-4.7",
"glm-4.5",
"glm-4.5-flash",
],
"kimi-coding": [
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-turbo-preview",
"kimi-k2-0905-preview",
],
"minimax": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
"minimax-cn": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
}
_PROVIDER_LABELS = {
"openrouter": "OpenRouter",
"openai-codex": "OpenAI Codex",
"nous": "Nous Portal",
"zai": "Z.AI / GLM",
"kimi-coding": "Kimi / Moonshot",
"minimax": "MiniMax",
"minimax-cn": "MiniMax (China)",
"custom": "custom endpoint",
}
_PROVIDER_ALIASES = {
"glm": "zai",
"z-ai": "zai",
"z.ai": "zai",
"zhipu": "zai",
"kimi": "kimi-coding",
"moonshot": "kimi-coding",
"minimax-china": "minimax-cn",
"minimax_cn": "minimax-cn",
}
def model_ids() -> list[str]:
"""Return just the model-id strings (convenience helper)."""
"""Return just the OpenRouter model-id strings."""
return [mid for mid, _ in OPENROUTER_MODELS]
@@ -31,3 +89,231 @@ def menu_labels() -> list[str]:
for mid, desc in OPENROUTER_MODELS:
labels.append(f"{mid} ({desc})" if desc else mid)
return labels
# All provider IDs and aliases that are valid for the provider:model syntax.
_KNOWN_PROVIDER_NAMES: set[str] = (
set(_PROVIDER_LABELS.keys())
| set(_PROVIDER_ALIASES.keys())
| {"openrouter", "custom"}
)
def list_available_providers() -> list[dict[str, str]]:
"""Return info about all providers the user could use with ``provider:model``.
Each dict has ``id``, ``label``, and ``aliases``.
Checks which providers have valid credentials configured.
"""
# Canonical providers in display order
_PROVIDER_ORDER = [
"openrouter", "nous", "openai-codex",
"zai", "kimi-coding", "minimax", "minimax-cn",
]
# Build reverse alias map
aliases_for: dict[str, list[str]] = {}
for alias, canonical in _PROVIDER_ALIASES.items():
aliases_for.setdefault(canonical, []).append(alias)
result = []
for pid in _PROVIDER_ORDER:
label = _PROVIDER_LABELS.get(pid, pid)
alias_list = aliases_for.get(pid, [])
# Check if this provider has credentials available
has_creds = False
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=pid)
has_creds = bool(runtime.get("api_key"))
except Exception:
pass
result.append({
"id": pid,
"label": label,
"aliases": alias_list,
"authenticated": has_creds,
})
return result
def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
"""Parse ``/model`` input into ``(provider, model)``.
Supports ``provider:model`` syntax to switch providers at runtime::
openrouter:anthropic/claude-sonnet-4.5 → ("openrouter", "anthropic/claude-sonnet-4.5")
nous:hermes-3 → ("nous", "hermes-3")
anthropic/claude-sonnet-4.5 → (current_provider, "anthropic/claude-sonnet-4.5")
gpt-5.4 → (current_provider, "gpt-5.4")
The colon is only treated as a provider delimiter if the left side is a
recognized provider name or alias. This avoids misinterpreting model names
that happen to contain colons (e.g. ``anthropic/claude-3.5-sonnet:beta``).
Returns ``(provider, model)`` where *provider* is either the explicit
provider from the input or *current_provider* if none was specified.
"""
stripped = raw.strip()
colon = stripped.find(":")
if colon > 0:
provider_part = stripped[:colon].strip().lower()
model_part = stripped[colon + 1:].strip()
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
return (normalize_provider(provider_part), model_part)
return (current_provider, stripped)
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
"""Return ``(model_id, description)`` tuples for a provider's curated list."""
normalized = normalize_provider(provider)
if normalized == "openrouter":
return list(OPENROUTER_MODELS)
models = _PROVIDER_MODELS.get(normalized, [])
return [(m, "") for m in models]
def normalize_provider(provider: Optional[str]) -> str:
"""Normalize provider aliases to Hermes' canonical provider ids.
Note: ``"auto"`` passes through unchanged — use
``hermes_cli.auth.resolve_provider()`` to resolve it to a concrete
provider based on credentials and environment.
"""
normalized = (provider or "openrouter").strip().lower()
return _PROVIDER_ALIASES.get(normalized, normalized)
def provider_model_ids(provider: Optional[str]) -> list[str]:
"""Return the best known model catalog for a provider."""
normalized = normalize_provider(provider)
if normalized == "openrouter":
return model_ids()
if normalized == "openai-codex":
from hermes_cli.codex_models import get_codex_model_ids
return get_codex_model_ids()
return list(_PROVIDER_MODELS.get(normalized, []))
def fetch_api_models(
api_key: Optional[str],
base_url: Optional[str],
timeout: float = 5.0,
) -> Optional[list[str]]:
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
Returns a list of model ID strings, or ``None`` if the endpoint could not
be reached (network error, timeout, auth failure, etc.).
"""
if not base_url:
return None
url = base_url.rstrip("/") + "/models"
headers: dict[str, str] = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
req = urllib.request.Request(url, headers=headers)
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
data = json.loads(resp.read().decode())
# Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
return [m.get("id", "") for m in data.get("data", [])]
except Exception:
return None
def validate_requested_model(
model_name: str,
provider: Optional[str],
*,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
) -> dict[str, Any]:
"""
Validate a ``/model`` value for the active provider.
Performs format checks first, then probes the live API to confirm
the model actually exists.
Returns a dict with:
- accepted: whether the CLI should switch to the requested model now
- persist: whether it is safe to save to config
- recognized: whether it matched a known provider catalog
- message: optional warning / guidance for the user
"""
requested = (model_name or "").strip()
normalized = normalize_provider(provider)
if normalized == "openrouter" and base_url and "openrouter.ai" not in base_url:
normalized = "custom"
if not requested:
return {
"accepted": False,
"persist": False,
"recognized": False,
"message": "Model name cannot be empty.",
}
if any(ch.isspace() for ch in requested):
return {
"accepted": False,
"persist": False,
"recognized": False,
"message": "Model names cannot contain spaces.",
}
# Probe the live API to check if the model actually exists
api_models = fetch_api_models(api_key, base_url)
if api_models is not None:
if requested in set(api_models):
# API confirmed the model exists
return {
"accepted": True,
"persist": True,
"recognized": True,
"message": None,
}
else:
# API responded but model is not listed
suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
suggestion_text = ""
if suggestions:
suggestion_text = "\n Did you mean: " + ", ".join(f"`{s}`" for s in suggestions)
return {
"accepted": False,
"persist": False,
"recognized": False,
"message": (
f"Error: `{requested}` is not a valid model for this provider."
f"{suggestion_text}"
),
}
# api_models is None — couldn't reach API, fall back to catalog check
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
known_models = provider_model_ids(normalized)
if requested in known_models:
return {
"accepted": True,
"persist": True,
"recognized": True,
"message": None,
}
# Can't validate — accept for session only
suggestion = get_close_matches(requested, known_models, n=1, cutoff=0.6)
suggestion_text = f" Did you mean `{suggestion[0]}`?" if suggestion else ""
return {
"accepted": True,
"persist": False,
"recognized": False,
"message": (
f"Could not validate `{requested}` against the live {provider_label} API. "
"Using it for this session only; config unchanged."
f"{suggestion_text}"
),
}

View File

@@ -7,10 +7,12 @@ from typing import Any, Dict, Optional
from hermes_cli.auth import (
AuthError,
PROVIDER_REGISTRY,
format_auth_error,
resolve_provider,
resolve_nous_runtime_credentials,
resolve_codex_runtime_credentials,
resolve_api_key_provider_credentials,
)
from hermes_cli.config import load_config
from hermes_constants import OPENROUTER_BASE_URL
@@ -72,12 +74,26 @@ def _resolve_openrouter_runtime(
or OPENROUTER_BASE_URL
).rstrip("/")
api_key = (
explicit_api_key
or os.getenv("OPENROUTER_API_KEY")
or os.getenv("OPENAI_API_KEY")
or ""
)
# Choose API key based on whether the resolved base_url targets OpenRouter.
# When hitting OpenRouter, prefer OPENROUTER_API_KEY (issue #289).
# When hitting a custom endpoint (e.g. Z.ai, local LLM), prefer
# OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
# provider (issues #420, #560).
_is_openrouter_url = "openrouter.ai" in base_url
if _is_openrouter_url:
api_key = (
explicit_api_key
or os.getenv("OPENROUTER_API_KEY")
or os.getenv("OPENAI_API_KEY")
or ""
)
else:
api_key = (
explicit_api_key
or os.getenv("OPENAI_API_KEY")
or os.getenv("OPENROUTER_API_KEY")
or ""
)
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
@@ -132,6 +148,19 @@ def resolve_runtime_provider(
"requested_provider": requested_provider,
}
# API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
pconfig = PROVIDER_REGISTRY.get(provider)
if pconfig and pconfig.auth_type == "api_key":
creds = resolve_api_key_provider_credentials(provider)
return {
"provider": provider,
"api_mode": "chat_completions",
"base_url": creds.get("base_url", "").rstrip("/"),
"api_key": creds.get("api_key", ""),
"source": creds.get("source", "env"),
"requested_provider": requested_provider,
}
runtime = _resolve_openrouter_runtime(
requested_provider=requested_provider,
explicit_api_key=explicit_api_key,

File diff suppressed because it is too large Load Diff

View File

@@ -57,8 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
table.add_column("Trust", style="dim")
table.add_column("Identifier", style="bold cyan")
for r in exact:
trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
table.add_row(r.source, f"[{trust_style}]{r.trust_level}[/]", r.identifier)
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_label = "official" if r.source == "official" else r.trust_level
table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
c.print(table)
c.print("[bold]Use the full identifier to install a specific one.[/]\n")
return ""
@@ -99,12 +100,13 @@ def do_search(query: str, source: str = "all", limit: int = 10,
table.add_column("Identifier", style="dim")
for r in results:
trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_label = "official" if r.source == "official" else r.trust_level
table.add_row(
r.name,
r.description[:60] + ("..." if len(r.description) > 60 else ""),
r.source,
f"[{trust_style}]{r.trust_level}[/]",
f"[{trust_style}]{trust_label}[/]",
r.identifier,
)
@@ -113,6 +115,130 @@ def do_search(query: str, source: str = "all", limit: int = 10,
"hermes skills install <identifier> to install[/]\n")
def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
console: Optional[Console] = None) -> None:
"""Browse all available skills across registries, paginated.
Official skills are always shown first, regardless of source filter.
"""
from tools.skills_hub import (
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
)
# Clamp page_size to safe range
page_size = max(1, min(page_size, 100))
c = console or _console
auth = GitHubAuth()
sources = create_source_router(auth)
# Collect results from all (or filtered) sources
# Use empty query to get everything; per-source limits prevent overload
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
"claude-marketplace": 50, "lobehub": 50}
all_results: list = []
source_counts: dict = {}
for src in sources:
sid = src.source_id()
if source != "all" and sid != source and sid != "official":
# Always include official source for the "first" placement
continue
try:
limit = _PER_SOURCE_LIMIT.get(sid, 50)
results = src.search("", limit=limit)
source_counts[sid] = len(results)
all_results.extend(results)
except Exception:
continue
if not all_results:
c.print("[dim]No skills found in the Skills Hub.[/]\n")
return
# Deduplicate by name, preferring higher trust
seen: dict = {}
for r in all_results:
rank = _TRUST_RANK.get(r.trust_level, 0)
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
seen[r.name] = r
deduped = list(seen.values())
# Sort: official first, then by trust level (desc), then alphabetically
deduped.sort(key=lambda r: (
-_TRUST_RANK.get(r.trust_level, 0),
r.source != "official",
r.name.lower(),
))
# Paginate
total = len(deduped)
total_pages = max(1, (total + page_size - 1) // page_size)
page = max(1, min(page, total_pages))
start = (page - 1) * page_size
end = min(start + page_size, total)
page_items = deduped[start:end]
# Count official vs other
official_count = sum(1 for r in deduped if r.source == "official")
# Build header
source_label = f"{source}" if source != "all" else "— all sources"
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
f" [dim]({total} skills, page {page}/{total_pages})[/]")
if official_count > 0 and page == 1:
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
c.print()
# Build table
table = Table(show_header=True, header_style="bold")
table.add_column("#", style="dim", width=4, justify="right")
table.add_column("Name", style="bold cyan", max_width=25)
table.add_column("Description", max_width=50)
table.add_column("Source", style="dim", width=12)
table.add_column("Trust", width=10)
for i, r in enumerate(page_items, start=start + 1):
trust_style = {"builtin": "bright_cyan", "trusted": "green",
"community": "yellow"}.get(r.trust_level, "dim")
trust_label = "★ official" if r.source == "official" else r.trust_level
desc = r.description[:50]
if len(r.description) > 50:
desc += "..."
table.add_row(
str(i),
r.name,
desc,
r.source,
f"[{trust_style}]{trust_label}[/]",
)
c.print(table)
# Navigation hints
nav_parts = []
if page > 1:
nav_parts.append(f"[cyan]--page {page - 1}[/] ← prev")
if page < total_pages:
nav_parts.append(f"[cyan]--page {page + 1}[/] → next")
if nav_parts:
c.print(f" {' | '.join(nav_parts)}")
# Source summary
if source == "all" and source_counts:
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
"hermes skills install <identifier> to install[/]\n")
def do_install(identifier: str, category: str = "", force: bool = False,
console: Optional[Console] = None) -> None:
"""Fetch, quarantine, scan, confirm, and install a skill."""
@@ -147,6 +273,12 @@ def do_install(identifier: str, category: str = "", force: bool = False,
c.print(f"[bold red]Error:[/] Could not fetch '{identifier}' from any source.\n")
return
# Auto-detect category for official skills (e.g. "official/autonomous-ai-agents/blackbox")
if bundle.source == "official" and not category:
id_parts = bundle.identifier.split("/") # ["official", "category", "skill"]
if len(id_parts) >= 3:
category = id_parts[1]
# Check if already installed
lock = HubLockFile()
existing = lock.get_installed(bundle.name)
@@ -177,18 +309,28 @@ def do_install(identifier: str, category: str = "", force: bool = False,
f"{len(result.findings)}_findings")
return
# Confirm with user — always show risk warning regardless of source
# Confirm with user — show appropriate warning based on source
if not force:
c.print()
c.print(Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
))
if bundle.source == "official":
c.print(Panel(
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
"It ships with hermes-agent but is not activated by default.\n"
"Installing will copy it to your skills directory where the agent can use it.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Official Skill",
border_style="bright_cyan",
))
else:
c.print(Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
))
c.print(f"[bold]Install '{bundle.name}'?[/]")
try:
answer = input("Confirm [y/N]: ").strip().lower()
@@ -237,13 +379,14 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
break
c.print()
trust_style = {"trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
trust_label = "official" if meta.source == "official" else meta.trust_level
info_lines = [
f"[bold]Name:[/] {meta.name}",
f"[bold]Description:[/] {meta.description}",
f"[bold]Source:[/] {meta.source}",
f"[bold]Trust:[/] [{trust_style}]{meta.trust_level}[/]",
f"[bold]Trust:[/] [{trust_style}]{trust_label}[/]",
f"[bold]Identifier:[/] {meta.identifier}",
]
if meta.tags:
@@ -297,8 +440,9 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No
if source_filter == "builtin" and hub_entry:
continue
trust_style = {"builtin": "blue", "trusted": "green", "community": "yellow"}.get(trust, "dim")
table.add_row(name, category, source_display, f"[{trust_style}]{trust}[/]")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(trust, "dim")
trust_label = "official" if source_display == "official" else trust
table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")
c.print(table)
c.print(f"[dim]{len(hub_installed)} hub-installed, "
@@ -658,7 +802,9 @@ def skills_command(args) -> None:
"""Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
action = getattr(args, "skills_action", None)
if action == "search":
if action == "browse":
do_browse(page=args.page, page_size=args.size, source=args.source)
elif action == "search":
do_search(args.query, source=args.source, limit=args.limit)
elif action == "install":
do_install(args.identifier, category=args.category, force=args.force)
@@ -692,7 +838,7 @@ def skills_command(args) -> None:
return
do_tap(tap_action, repo=repo)
else:
_console.print("Usage: hermes skills [search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
_console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
_console.print("Run 'hermes skills <command> --help' for details.\n")
@@ -732,7 +878,32 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
action = parts[0].lower()
args = parts[1:]
if action == "search":
if action == "browse":
page = 1
page_size = 20
source = "all"
i = 0
while i < len(args):
if args[i] == "--page" and i + 1 < len(args):
try:
page = int(args[i + 1])
except ValueError:
pass
i += 2
elif args[i] == "--size" and i + 1 < len(args):
try:
page_size = int(args[i + 1])
except ValueError:
pass
i += 2
elif args[i] == "--source" and i + 1 < len(args):
source = args[i + 1]
i += 2
else:
i += 1
do_browse(page=page, page_size=page_size, source=source, console=c)
elif action == "search":
if not args:
c.print("[bold red]Usage:[/] /skills search <query> [--source github] [--limit N]\n")
return
@@ -838,6 +1009,7 @@ def _print_skills_help(console: Console) -> None:
"""Print help for the /skills slash command."""
console.print(Panel(
"[bold]Skills Hub Commands:[/]\n\n"
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
" [cyan]search[/] <query> Search registries for skills\n"
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"

View File

@@ -79,8 +79,12 @@ def show_status(args):
"OpenRouter": "OPENROUTER_API_KEY",
"Anthropic": "ANTHROPIC_API_KEY",
"OpenAI": "OPENAI_API_KEY",
"Z.AI/GLM": "GLM_API_KEY",
"Kimi": "KIMI_API_KEY",
"MiniMax": "MINIMAX_API_KEY",
"MiniMax-CN": "MINIMAX_CN_API_KEY",
"Firecrawl": "FIRECRAWL_API_KEY",
"Browserbase": "BROWSERBASE_API_KEY",
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
"FAL": "FAL_KEY",
"Tinker": "TINKER_API_KEY",
"WandB": "WANDB_API_KEY",
@@ -128,7 +132,7 @@ def show_status(args):
f" {'OpenAI Codex':<12} {check_mark(codex_logged_in)} "
f"{'logged in' if codex_logged_in else 'not logged in (run: hermes model)'}"
)
codex_auth_file = codex_status.get("auth_file")
codex_auth_file = codex_status.get("auth_store")
if codex_auth_file:
print(f" Auth file: {codex_auth_file}")
codex_last_refresh = _format_iso_timestamp(codex_status.get("last_refresh"))
@@ -137,6 +141,28 @@ def show_status(args):
if codex_status.get("error") and not codex_logged_in:
print(f" Error: {codex_status.get('error')}")
# =========================================================================
# API-Key Providers
# =========================================================================
print()
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
apikey_providers = {
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
"Kimi / Moonshot": ("KIMI_API_KEY",),
"MiniMax": ("MINIMAX_API_KEY",),
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
}
for pname, env_vars in apikey_providers.items():
key_val = ""
for ev in env_vars:
key_val = get_env_value(ev) or ""
if key_val:
break
configured = bool(key_val)
label = "configured" if configured else "not configured (run: hermes model)"
print(f" {pname:<16} {check_mark(configured)} {label}")
# =========================================================================
# Terminal Configuration
# =========================================================================
@@ -163,6 +189,9 @@ def show_status(args):
elif terminal_env == "docker":
docker_image = os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim")
print(f" Docker Image: {docker_image}")
elif terminal_env == "daytona":
daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
print(f" Daytona Image: {daytona_image}")
sudo_password = os.getenv("SUDO_PASSWORD", "")
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")

View File

@@ -1,7 +1,10 @@
"""
Interactive tool configuration for Hermes Agent.
Unified tool configuration for Hermes Agent.
`hermes tools` and `hermes setup tools` both enter this module.
Select a platform → toggle toolsets on/off → for newly enabled tools
that need API keys, run through provider-aware configuration.
`hermes tools` — select a platform, then toggle toolsets on/off via checklist.
Saves per-platform tool configuration to ~/.hermes/config.yaml under
the `platform_toolsets` key.
"""
@@ -12,9 +15,63 @@ from typing import Dict, List, Set
import os
from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
from hermes_cli.config import (
load_config, save_config, get_env_value, save_env_value,
get_hermes_home,
)
from hermes_cli.colors import Colors, color
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
def _print_info(text: str):
print(color(f" {text}", Colors.DIM))
def _print_success(text: str):
print(color(f"{text}", Colors.GREEN))
def _print_warning(text: str):
print(color(f"{text}", Colors.YELLOW))
def _print_error(text: str):
print(color(f"{text}", Colors.RED))
def _prompt(question: str, default: str = None, password: bool = False) -> str:
if default:
display = f"{question} [{default}]: "
else:
display = f"{question}: "
try:
if password:
import getpass
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default or ""
except (KeyboardInterrupt, EOFError):
print()
return default or ""
def _prompt_yes_no(question: str, default: bool = True) -> bool:
default_str = "Y/n" if default else "y/N"
while True:
try:
value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
except (KeyboardInterrupt, EOFError):
print()
return default
if not value:
return default
if value in ('y', 'yes'):
return True
if value in ('n', 'no'):
return False
# ─── Toolset Registry ─────────────────────────────────────────────────────────
# Toolsets shown in the configurator, grouped for display.
# Each entry: (toolset_name, label, description)
# These map to keys in toolsets.py TOOLSETS dict.
@@ -49,6 +106,187 @@ PLATFORMS = {
}
# ─── Tool Categories (provider-aware configuration) ──────────────────────────
# Maps toolset keys to their provider options. When a toolset is newly enabled,
# we use this to show provider selection and prompt for the right API keys.
# Toolsets not in this map either need no config or use the simple fallback.
TOOL_CATEGORIES = {
"tts": {
"name": "Text-to-Speech",
"icon": "🔊",
"providers": [
{
"name": "Microsoft Edge TTS",
"tag": "Free - no API key needed",
"env_vars": [],
"tts_provider": "edge",
},
{
"name": "OpenAI TTS",
"tag": "Premium - high quality voices",
"env_vars": [
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
],
"tts_provider": "openai",
},
{
"name": "ElevenLabs",
"tag": "Premium - most natural voices",
"env_vars": [
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
],
"tts_provider": "elevenlabs",
},
],
},
"web": {
"name": "Web Search & Extract",
"icon": "🔍",
"providers": [
{
"name": "Firecrawl Cloud",
"tag": "Recommended - hosted service",
"env_vars": [
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
],
},
{
"name": "Firecrawl Self-Hosted",
"tag": "Free - run your own instance",
"env_vars": [
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
],
},
],
},
"image_gen": {
"name": "Image Generation",
"icon": "🎨",
"providers": [
{
"name": "FAL.ai",
"tag": "FLUX 2 Pro with auto-upscaling",
"env_vars": [
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
],
},
],
},
"browser": {
"name": "Browser Automation",
"icon": "🌐",
"providers": [
{
"name": "Local Browser",
"tag": "Free headless Chromium (no API key needed)",
"env_vars": [],
"post_setup": "browserbase", # Same npm install for agent-browser
},
{
"name": "Browserbase",
"tag": "Cloud browser with stealth & proxies",
"env_vars": [
{"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
{"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
],
"post_setup": "browserbase",
},
],
},
"homeassistant": {
"name": "Smart Home",
"icon": "🏠",
"providers": [
{
"name": "Home Assistant",
"tag": "REST API integration",
"env_vars": [
{"key": "HASS_TOKEN", "prompt": "Home Assistant Long-Lived Access Token"},
{"key": "HASS_URL", "prompt": "Home Assistant URL", "default": "http://homeassistant.local:8123"},
],
},
],
},
"rl": {
"name": "RL Training",
"icon": "🧪",
"requires_python": (3, 11),
"providers": [
{
"name": "Tinker / Atropos",
"tag": "RL training platform",
"env_vars": [
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
],
"post_setup": "rl_training",
},
],
},
}
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
# Used as a fallback for tools like vision/moa that just need an API key.
TOOLSET_ENV_REQUIREMENTS = {
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
}
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
def _run_post_setup(post_setup_key: str):
"""Run post-setup hooks for tools that need extra installation steps."""
import shutil
if post_setup_key == "browserbase":
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
if not node_modules.exists() and shutil.which("npm"):
_print_info(" Installing Node.js dependencies for browser tools...")
import subprocess
result = subprocess.run(
["npm", "install", "--silent"],
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
)
if result.returncode == 0:
_print_success(" Node.js dependencies installed")
else:
_print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
elif not node_modules.exists():
_print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")
elif post_setup_key == "rl_training":
try:
__import__("tinker_atropos")
except ImportError:
tinker_dir = PROJECT_ROOT / "tinker-atropos"
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
_print_info(" Installing tinker-atropos submodule...")
import subprocess
uv_bin = shutil.which("uv")
if uv_bin:
result = subprocess.run(
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
capture_output=True, text=True
)
else:
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
capture_output=True, text=True
)
if result.returncode == 0:
_print_success(" tinker-atropos installed")
else:
_print_warning(" tinker-atropos install failed - run manually:")
_print_info(' uv pip install -e "./tinker-atropos"')
else:
_print_warning(" tinker-atropos submodule not found - run:")
_print_info(" git submodule update --init --recursive")
_print_info(' uv pip install -e "./tinker-atropos"')
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
def _get_enabled_platforms() -> List[str]:
"""Return platform keys that are configured (have tokens or are CLI)."""
enabled = ["cli"]
@@ -70,7 +308,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
platform_toolsets = config.get("platform_toolsets", {})
toolset_names = platform_toolsets.get(platform)
if not toolset_names or not isinstance(toolset_names, list):
if toolset_names is None or not isinstance(toolset_names, list):
default_ts = PLATFORMS[platform]["default_toolset"]
toolset_names = [default_ts]
@@ -97,61 +335,117 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
save_config(config)
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu (arrow keys)."""
print(color(question, Colors.YELLOW))
try:
from simple_term_menu import TerminalMenu
menu = TerminalMenu(
[f" {c}" for c in choices],
cursor_index=default,
menu_cursor="",
menu_cursor_style=("fg_green", "bold"),
menu_highlight_style=("fg_green",),
cycle_cursor=True,
clear_screen=False,
)
idx = menu.show()
if idx is None:
sys.exit(0)
print()
return idx
except (ImportError, NotImplementedError):
for i, c in enumerate(choices):
marker = "" if i == default else ""
style = Colors.GREEN if i == default else ""
print(color(f" {marker} {c}", style) if style else f" {marker} {c}")
while True:
try:
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
if not val:
return default
idx = int(val) - 1
if 0 <= idx < len(choices):
return idx
except (ValueError, KeyboardInterrupt, EOFError):
print()
sys.exit(0)
def _toolset_has_keys(ts_key: str) -> bool:
"""Check if a toolset's required API keys are configured."""
# Check TOOL_CATEGORIES first (provider-aware)
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
for provider in cat["providers"]:
env_vars = provider.get("env_vars", [])
if not env_vars:
return True # Free provider (e.g., Edge TTS)
if all(get_env_value(v["key"]) for v in env_vars):
return True
return False
# Fallback to simple requirements
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return True
return all(get_env_value(var) for var, _ in requirements)
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
rendering bugs in tmux, iTerm, and other non-standard terminals."""
# Curses-based single-select — works in tmux, iTerm, and standard terminals
try:
import curses
result_holder = [default]
def _curses_menu(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
cursor = default
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
try:
stdscr.addnstr(0, 0, question, max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
except curses.error:
pass
for i, c in enumerate(choices):
y = i + 2
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {c}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
cursor = (cursor - 1) % len(choices)
elif key in (curses.KEY_DOWN, ord('j')):
cursor = (cursor + 1) % len(choices)
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = cursor
return
elif key in (27, ord('q')):
return
curses.wrapper(_curses_menu)
return result_holder[0]
except Exception:
pass
# Fallback: numbered input (Windows without curses, etc.)
print(color(question, Colors.YELLOW))
for i, c in enumerate(choices):
marker = "" if i == default else ""
style = Colors.GREEN if i == default else ""
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
while True:
try:
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
if not val:
return default
idx = int(val) - 1
if 0 <= idx < len(choices):
return idx
except (ValueError, KeyboardInterrupt, EOFError):
print()
return default
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
import platform as _platform
labels = []
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
suffix = ""
if not _toolset_has_keys(ts_key) and TOOLSET_ENV_REQUIREMENTS.get(ts_key):
suffix = " no API key"
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
suffix = " [no API key]"
labels.append(f"{ts_label} ({ts_desc}){suffix}")
pre_selected_indices = [
@@ -159,48 +453,8 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
if ts_key in enabled
]
# simple_term_menu multi-select has rendering bugs on macOS terminals,
# so we use a curses-based fallback there.
use_term_menu = _platform.system() != "Darwin"
if use_term_menu:
try:
from simple_term_menu import TerminalMenu
print(color(f"Tools for {platform_label}", Colors.YELLOW))
print(color(" SPACE to toggle, ENTER to confirm.", Colors.DIM))
print()
menu_items = [f" {label}" for label in labels]
menu = TerminalMenu(
menu_items,
multi_select=True,
show_multi_select_hint=False,
multi_select_cursor="[✓] ",
multi_select_select_on_accept=False,
multi_select_empty_ok=True,
preselected_entries=pre_selected_indices if pre_selected_indices else None,
menu_cursor="",
menu_cursor_style=("fg_green", "bold"),
menu_highlight_style=("fg_green",),
cycle_cursor=True,
clear_screen=False,
clear_menu_on_exit=False,
)
menu.show()
if menu.chosen_menu_entries is None:
return enabled
selected_indices = list(menu.chosen_menu_indices or [])
return {CONFIGURABLE_TOOLSETS[i][0] for i in selected_indices}
except (ImportError, NotImplementedError):
pass # fall through to curses/numbered fallback
# Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
# Used on macOS (where simple_term_menu ghosts) and as a fallback.
# simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
try:
import curses
selected = set(pre_selected_indices)
@@ -302,77 +556,294 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
return {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
# Map toolset keys to the env vars they require and where to get them
TOOLSET_ENV_REQUIREMENTS = {
"web": [("FIRECRAWL_API_KEY", "https://firecrawl.dev/")],
"browser": [("BROWSERBASE_API_KEY", "https://browserbase.com/"),
("BROWSERBASE_PROJECT_ID", None)],
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"image_gen": [("FAL_KEY", "https://fal.ai/")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"tts": [], # Edge TTS is free, no key needed
"rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
("WANDB_API_KEY", "https://wandb.ai/authorize")],
"homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
("HASS_URL", None)],
}
# ─── Provider-Aware Configuration ────────────────────────────────────────────
def _configure_toolset(ts_key: str, config: dict):
"""Configure a toolset - provider selection + API keys.
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
env var prompts for toolsets not in TOOL_CATEGORIES.
"""
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
_configure_tool_category(ts_key, cat, config)
else:
# Simple fallback for vision, moa, etc.
_configure_simple_requirements(ts_key)
def _check_and_prompt_requirements(newly_enabled: Set[str]):
"""Check if newly enabled toolsets have missing API keys and offer to set them up."""
for ts_key in sorted(newly_enabled):
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
continue
def _configure_tool_category(ts_key: str, cat: dict, config: dict):
"""Configure a tool category with provider selection."""
icon = cat.get("icon", "")
name = cat["name"]
providers = cat["providers"]
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
if not missing:
continue
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f"{ts_label} requires configuration:", Colors.YELLOW))
for var, url in missing:
if url:
print(color(f" {var}", Colors.CYAN) + color(f" ({url})", Colors.DIM))
else:
print(color(f" {var}", Colors.CYAN))
print()
try:
response = input(color(" Set up now? [Y/n] ", Colors.YELLOW)).strip().lower()
except (KeyboardInterrupt, EOFError):
# Check Python version requirement
if cat.get("requires_python"):
req = cat["requires_python"]
if sys.version_info < req:
print()
continue
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
_print_info(" Upgrade Python and reinstall to enable this tool.")
return
if response in ("", "y", "yes"):
for var, url in missing:
if url:
print(color(f" Get key at: {url}", Colors.DIM))
try:
import getpass
value = getpass.getpass(color(f" {var}: ", Colors.YELLOW))
except (KeyboardInterrupt, EOFError):
print()
break
if value.strip():
save_env_value(var, value.strip())
print(color(f" ✓ Saved", Colors.GREEN))
if len(providers) == 1:
# Single provider - configure directly
provider = providers[0]
print()
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
if provider.get("tag"):
_print_info(f" {provider['tag']}")
_configure_provider(provider, config)
else:
# Multiple providers - let user choose
print()
print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
print()
# Plain text labels only (no ANSI codes in menu items)
provider_choices = []
for p in providers:
tag = f" ({p['tag']})" if p.get("tag") else ""
configured = ""
env_vars = p.get("env_vars", [])
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
configured = " [active]"
elif not env_vars:
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
else:
print(color(f" Skipped", Colors.DIM))
configured = " [configured]"
provider_choices.append(f"{p['name']}{tag}{configured}")
# Detect current provider as default
default_idx = 0
for i, p in enumerate(providers):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
default_idx = i
break
env_vars = p.get("env_vars", [])
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
default_idx = i
break
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
_configure_provider(providers[provider_idx], config)
def _configure_provider(provider: dict, config: dict):
"""Configure a single provider - prompt for API keys and set config."""
env_vars = provider.get("env_vars", [])
# Set TTS provider in config if applicable
if provider.get("tts_provider"):
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
# Prompt for each required env var
all_configured = True
for var in env_vars:
existing = get_env_value(var["key"])
if existing:
_print_success(f" {var['key']}: already configured")
# Don't ask to update - this is a new enable flow.
# Reconfigure is handled separately.
else:
print(color(" Skipped — configure later with 'hermes setup'", Colors.DIM))
url = var.get("url", "")
if url:
_print_info(f" Get yours at: {url}")
default_val = var.get("default", "")
if default_val:
value = _prompt(f" {var.get('prompt', var['key'])}", default_val)
else:
value = _prompt(f" {var.get('prompt', var['key'])}", password=True)
if value:
save_env_value(var["key"], value)
_print_success(f" Saved")
else:
_print_warning(f" Skipped")
all_configured = False
# Run post-setup hooks if needed
if provider.get("post_setup") and all_configured:
_run_post_setup(provider["post_setup"])
if all_configured:
_print_success(f" {provider['name']} configured!")
def tools_command(args):
"""Entry point for `hermes tools`."""
def _configure_simple_requirements(ts_key: str):
"""Simple fallback for toolsets that just need env vars (no provider selection)."""
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
if not missing:
return
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
for var, url in missing:
if url:
_print_info(f" Get key at: {url}")
value = _prompt(f" {var}", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Saved")
else:
_print_warning(f" Skipped")
def _reconfigure_tool(config: dict):
"""Let user reconfigure an existing tool's provider or API key."""
# Build list of configurable tools that are currently set up
configurable = []
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
cat = TOOL_CATEGORIES.get(ts_key)
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
if cat or reqs:
if _toolset_has_keys(ts_key):
configurable.append((ts_key, ts_label))
if not configurable:
_print_info("No configured tools to reconfigure.")
return
choices = [label for _, label in configurable]
choices.append("Cancel")
idx = _prompt_choice(" Which tool would you like to reconfigure?", choices, len(choices) - 1)
if idx >= len(configurable):
return # Cancel
ts_key, ts_label = configurable[idx]
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
_configure_tool_category_for_reconfig(ts_key, cat, config)
else:
_reconfigure_simple_requirements(ts_key)
save_config(config)
def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
"""Reconfigure a tool category - provider selection + API key update."""
icon = cat.get("icon", "")
name = cat["name"]
providers = cat["providers"]
if len(providers) == 1:
provider = providers[0]
print()
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
_reconfigure_provider(provider, config)
else:
print()
print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
print()
provider_choices = []
for p in providers:
tag = f" ({p['tag']})" if p.get("tag") else ""
configured = ""
env_vars = p.get("env_vars", [])
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
configured = " [active]"
elif not env_vars:
configured = ""
else:
configured = " [configured]"
provider_choices.append(f"{p['name']}{tag}{configured}")
default_idx = 0
for i, p in enumerate(providers):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
default_idx = i
break
env_vars = p.get("env_vars", [])
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
default_idx = i
break
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
_reconfigure_provider(providers[provider_idx], config)
def _reconfigure_provider(provider: dict, config: dict):
"""Reconfigure a provider - update API keys."""
env_vars = provider.get("env_vars", [])
if provider.get("tts_provider"):
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
_print_success(f" TTS provider set to: {provider['tts_provider']}")
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
for var in env_vars:
existing = get_env_value(var["key"])
if existing:
_print_info(f" {var['key']}: configured ({existing[:8]}...)")
url = var.get("url", "")
if url:
_print_info(f" Get yours at: {url}")
default_val = var.get("default", "")
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
if value and value.strip():
save_env_value(var["key"], value.strip())
_print_success(f" Updated")
else:
_print_info(f" Kept current")
def _reconfigure_simple_requirements(ts_key: str):
"""Reconfigure simple env var requirements."""
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f" {ts_label}:", Colors.CYAN))
for var, url in requirements:
existing = get_env_value(var)
if existing:
_print_info(f" {var}: configured ({existing[:8]}...)")
if url:
_print_info(f" Get key at: {url}")
value = _prompt(f" {var} (Enter to keep current)", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Updated")
else:
_print_info(f" Kept current")
# ─── Main Entry Point ─────────────────────────────────────────────────────────
def tools_command(args=None):
"""Entry point for `hermes tools` and `hermes setup tools`."""
config = load_config()
enabled_platforms = _get_enabled_platforms()
print()
print(color("⚕ Hermes Tool Configuration", Colors.CYAN, Colors.BOLD))
print(color(" Enable or disable tools per platform.", Colors.DIM))
print(color(" Tools that need API keys will be configured when enabled.", Colors.DIM))
print()
# Build platform choices
@@ -380,22 +851,28 @@ def tools_command(args):
platform_keys = []
for pkey in enabled_platforms:
pinfo = PLATFORMS[pkey]
# Count currently enabled toolsets
current = _get_platform_tools(config, pkey)
count = len(current)
total = len(CONFIGURABLE_TOOLSETS)
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
platform_keys.append(pkey)
platform_choices.append("Done — save and exit")
platform_choices.append("Reconfigure an existing tool's provider or API key")
platform_choices.append("Done")
while True:
idx = _prompt_choice("Select a platform to configure:", platform_choices, default=0)
idx = _prompt_choice("Select an option:", platform_choices, default=0)
# "Done" selected
if idx == len(platform_keys):
if idx == len(platform_keys) + 1:
break
# "Reconfigure" selected
if idx == len(platform_keys):
_reconfigure_tool(config)
print()
continue
pkey = platform_keys[idx]
pinfo = PLATFORMS[pkey]
@@ -418,11 +895,15 @@ def tools_command(args):
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
print(color(f" - {label}", Colors.RED))
# Prompt for missing API keys on newly enabled toolsets
# Configure newly enabled toolsets that need API keys
if added:
_check_and_prompt_requirements(added)
for ts_key in sorted(added):
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
if not _toolset_has_keys(ts_key):
_configure_toolset(ts_key, config)
_save_platform_tools(config, pkey, new_enabled)
save_config(config)
print(color(f" ✓ Saved {pinfo['label']} configuration", Colors.GREEN))
else:
print(color(f" No changes to {pinfo['label']}", Colors.DIM))

View File

@@ -24,7 +24,7 @@ from typing import Dict, Any, List, Optional
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
SCHEMA_VERSION = 2
SCHEMA_VERSION = 4
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version (
@@ -46,6 +46,7 @@ CREATE TABLE IF NOT EXISTS sessions (
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
title TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
@@ -133,7 +134,33 @@ class SessionDB:
except sqlite3.OperationalError:
pass # Column already exists
cursor.execute("UPDATE schema_version SET version = 2")
if current_version < 3:
# v3: add title column to sessions
try:
cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
except sqlite3.OperationalError:
pass # Column already exists
cursor.execute("UPDATE schema_version SET version = 3")
if current_version < 4:
# v4: add unique index on title (NULLs allowed, only non-NULL must be unique)
try:
cursor.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
"ON sessions(title) WHERE title IS NOT NULL"
)
except sqlite3.OperationalError:
pass # Index already exists
cursor.execute("UPDATE schema_version SET version = 4")
# Unique title index — always ensure it exists (safe to run after migrations
# since the title column is guaranteed to exist at this point)
try:
cursor.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
"ON sessions(title) WHERE title IS NOT NULL"
)
except sqlite3.OperationalError:
pass # Index already exists
# FTS5 setup (separate because CREATE VIRTUAL TABLE can't be in executescript with IF NOT EXISTS reliably)
try:
@@ -219,6 +246,210 @@ class SessionDB:
row = cursor.fetchone()
return dict(row) if row else None
# Maximum length for session titles
MAX_TITLE_LENGTH = 100
@staticmethod
def sanitize_title(title: Optional[str]) -> Optional[str]:
"""Validate and sanitize a session title.
- Strips leading/trailing whitespace
- Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic
Unicode control chars (zero-width, RTL/LTR overrides, etc.)
- Collapses internal whitespace runs to single spaces
- Normalizes empty/whitespace-only strings to None
- Enforces MAX_TITLE_LENGTH
Returns the cleaned title string or None.
Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning.
"""
if not title:
return None
import re
# Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
# whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
# normalized to spaces by the whitespace collapsing step below
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
# Remove problematic Unicode control characters:
# - Zero-width chars (U+200B-U+200F, U+FEFF)
# - Directional overrides (U+202A-U+202E, U+2066-U+2069)
# - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
cleaned = re.sub(
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
'', cleaned,
)
# Collapse internal whitespace runs and strip
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
if not cleaned:
return None
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
raise ValueError(
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
)
return cleaned
def set_session_title(self, session_id: str, title: str) -> bool:
"""Set or update a session's title.
Returns True if session was found and title was set.
Raises ValueError if title is already in use by another session,
or if the title fails validation (too long, invalid characters).
Empty/whitespace-only strings are normalized to None (clearing the title).
"""
title = self.sanitize_title(title)
if title:
# Check uniqueness (allow the same session to keep its own title)
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?",
(title, session_id),
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = self._conn.execute(
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
)
self._conn.commit()
return cursor.rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]:
"""Get the title for a session, or None."""
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
return row["title"] if row else None
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
"""Look up a session by exact title. Returns session dict or None."""
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,)
)
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_by_title(self, title: str) -> Optional[str]:
"""Resolve a title to a session ID, preferring the latest in a lineage.
If the exact title exists, returns that session's ID.
If not, searches for "title #N" variants and returns the latest one.
If the exact title exists AND numbered variants exist, returns the
latest numbered variant (the most recent continuation).
"""
# First try exact match
exact = self.get_session_by_title(title)
# Also search for numbered variants: "title #2", "title #3", etc.
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
(f"{escaped} #%",),
)
numbered = cursor.fetchall()
if numbered:
# Return the most recent numbered variant
return numbered[0]["id"]
elif exact:
return exact["id"]
return None
def get_next_title_in_lineage(self, base_title: str) -> str:
"""Generate the next title in a lineage (e.g., "my session""my session #2").
Strips any existing " #N" suffix to find the base name, then finds
the highest existing number and increments.
"""
import re
# Strip existing #N suffix to find the true base
match = re.match(r'^(.*?) #(\d+)$', base_title)
if match:
base = match.group(1)
else:
base = base_title
# Find all existing numbered variants
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
(base, f"{escaped} #%"),
)
existing = [row["title"] for row in cursor.fetchall()]
if not existing:
return base # No conflict, use the base name as-is
# Find the highest number
max_num = 1 # The unnumbered original counts as #1
for t in existing:
m = re.match(r'^.* #(\d+)$', t)
if m:
max_num = max(max_num, int(m.group(1)))
return f"{base} #{max_num + 1}"
def list_sessions_rich(
self,
source: str = None,
limit: int = 20,
offset: int = 0,
) -> List[Dict[str, Any]]:
"""List sessions with preview (first user message) and last active timestamp.
Returns dicts with keys: id, source, model, title, started_at, ended_at,
message_count, preview (first 60 chars of first user message),
last_active (timestamp of last message).
Uses a single query with correlated subqueries instead of N+2 queries.
"""
source_clause = "WHERE s.source = ?" if source else ""
query = f"""
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active
FROM sessions s
{source_clause}
ORDER BY s.started_at DESC
LIMIT ? OFFSET ?
"""
params = (source, limit, offset) if source else (limit, offset)
cursor = self._conn.execute(query, params)
sessions = []
for row in cursor.fetchall():
s = dict(row)
# Build the preview from the raw substring
raw = s.pop("_preview_raw", "").strip()
if raw:
text = raw[:60]
s["preview"] = text + ("..." if len(raw) > 60 else "")
else:
s["preview"] = ""
sessions.append(s)
return sessions
# =========================================================================
# Message storage
# =========================================================================

119
hermes_time.py Normal file
View File

@@ -0,0 +1,119 @@
"""
Timezone-aware clock for Hermes.
Provides a single ``now()`` helper that returns a timezone-aware datetime
based on the user's configured IANA timezone (e.g. ``Asia/Kolkata``).
Resolution order:
1. ``HERMES_TIMEZONE`` environment variable
2. ``timezone`` key in ``~/.hermes/config.yaml``
3. Falls back to the server's local time (``datetime.now().astimezone()``)
Invalid timezone values log a warning and fall back safely — Hermes never
crashes due to a bad timezone string.
"""
import logging
import os
from datetime import datetime, timezone as _tz
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
try:
from zoneinfo import ZoneInfo
except ImportError:
# Python 3.8 fallback (shouldn't be needed — Hermes requires 3.9+)
from backports.zoneinfo import ZoneInfo # type: ignore[no-redef]
# Cached state — resolved once, reused on every call.
# Call reset_cache() to force re-resolution (e.g. after config changes).
_cached_tz: Optional[ZoneInfo] = None
_cached_tz_name: Optional[str] = None
_cache_resolved: bool = False
def _resolve_timezone_name() -> str:
"""Read the configured IANA timezone string (or empty string).
This does file I/O when falling through to config.yaml, so callers
should cache the result rather than calling on every ``now()``.
"""
# 1. Environment variable (highest priority — set by Supervisor, etc.)
tz_env = os.getenv("HERMES_TIMEZONE", "").strip()
if tz_env:
return tz_env
# 2. config.yaml ``timezone`` key
try:
import yaml
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
config_path = hermes_home / "config.yaml"
if config_path.exists():
with open(config_path) as f:
cfg = yaml.safe_load(f) or {}
tz_cfg = cfg.get("timezone", "")
if isinstance(tz_cfg, str) and tz_cfg.strip():
return tz_cfg.strip()
except Exception:
pass
return ""
def _get_zoneinfo(name: str) -> Optional[ZoneInfo]:
"""Validate and return a ZoneInfo, or None if invalid."""
if not name:
return None
try:
return ZoneInfo(name)
except (KeyError, Exception) as exc:
logger.warning(
"Invalid timezone '%s': %s. Falling back to server local time.",
name, exc,
)
return None
def get_timezone() -> Optional[ZoneInfo]:
"""Return the user's configured ZoneInfo, or None (meaning server-local).
Resolved once and cached. Call ``reset_cache()`` after config changes.
"""
global _cached_tz, _cached_tz_name, _cache_resolved
if not _cache_resolved:
_cached_tz_name = _resolve_timezone_name()
_cached_tz = _get_zoneinfo(_cached_tz_name)
_cache_resolved = True
return _cached_tz
def get_timezone_name() -> str:
"""Return the IANA name of the configured timezone, or empty string."""
global _cached_tz_name, _cache_resolved
if not _cache_resolved:
get_timezone() # populates cache
return _cached_tz_name or ""
def now() -> datetime:
"""
Return the current time as a timezone-aware datetime.
If a valid timezone is configured, returns wall-clock time in that zone.
Otherwise returns the server's local time (via ``astimezone()``).
"""
tz = get_timezone()
if tz is not None:
return datetime.now(tz)
# No timezone configured — use server-local (still tz-aware)
return datetime.now().astimezone()
def reset_cache() -> None:
"""Clear the cached timezone. Used by tests and after config changes."""
global _cached_tz, _cached_tz_name, _cache_resolved
_cached_tz = None
_cached_tz_name = None
_cache_resolved = False

View File

@@ -149,7 +149,7 @@ class MiniSWERunner:
def __init__(
self,
model: str = "anthropic/claude-sonnet-4-20250514",
model: str = "anthropic/claude-sonnet-4.6",
base_url: str = None,
api_key: str = None,
env_type: str = "local",
@@ -200,13 +200,7 @@ class MiniSWERunner:
else:
client_kwargs["base_url"] = "https://openrouter.ai/api/v1"
if base_url and "api.anthropic.com" in base_url.strip().lower():
raise ValueError(
"Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
"Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
"To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
"or any OpenAI-compatible proxy that wraps the Anthropic API."
)
# Handle API key - OpenRouter is the primary provider
if api_key:

View File

@@ -225,6 +225,18 @@ def get_tool_definitions(
# Ask the registry for schemas (only returns tools whose check_fn passes)
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
# Rebuild execute_code schema to only list sandbox tools that are actually
# enabled. Without this, the model sees "web_search is available in
# execute_code" even when the user disabled the web toolset (#560-discord).
if "execute_code" in tools_to_include:
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
dynamic_schema = build_execute_code_schema(sandbox_enabled)
for i, td in enumerate(filtered_tools):
if td.get("function", {}).get("name") == "execute_code":
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
break
if not quiet_mode:
if filtered_tools:
tool_names = [t["function"]["name"] for t in filtered_tools]

View File

@@ -0,0 +1,24 @@
# Optional Skills
Official skills maintained by Nous Research that are **not activated by default**.
These skills ship with the hermes-agent repository but are not copied to
`~/.hermes/skills/` during setup. They are discoverable via the Skills Hub:
```bash
hermes skills browse # browse all skills, official shown first
hermes skills browse --source official # browse only official optional skills
hermes skills search <query> # finds optional skills labeled "official"
hermes skills install <identifier> # copies to ~/.hermes/skills/ and activates
```
## Why optional?
Some skills are useful but not broadly needed by every user:
- **Niche integrations** — specific paid services, specialized tools
- **Experimental features** — promising but not yet proven
- **Heavyweight dependencies** — require significant setup (API keys, installs)
By keeping them optional, we keep the default skill set lean while still
providing curated, tested, official skills for users who want them.

View File

@@ -0,0 +1,2 @@
Optional autonomous AI agent integrations — external coding agent CLIs
that can be delegated to for independent coding tasks.

View File

@@ -0,0 +1,143 @@
---
name: blackbox
description: Delegate coding tasks to Blackbox AI CLI agent. Multi-model agent with built-in judge that runs tasks through multiple LLMs and picks the best result. Requires the blackbox CLI and a Blackbox AI API key.
version: 1.0.0
author: Hermes Agent (Nous Research)
license: MIT
metadata:
hermes:
tags: [Coding-Agent, Blackbox, Multi-Agent, Judge, Multi-Model]
related_skills: [claude-code, codex, hermes-agent]
---
# Blackbox CLI
Delegate coding tasks to [Blackbox AI](https://www.blackbox.ai/) via the Hermes terminal. Blackbox is a multi-model coding agent CLI that dispatches tasks to multiple LLMs (Claude, Codex, Gemini, Blackbox Pro) and uses a judge to select the best implementation.
The CLI is [open-source](https://github.com/blackboxaicode/cli) (GPL-3.0, TypeScript, forked from Gemini CLI) and supports interactive sessions, non-interactive one-shots, checkpointing, MCP, and vision model switching.
## Prerequisites
- Node.js 20+ installed
- Blackbox CLI installed: `npm install -g @blackboxai/cli`
- Or install from source:
```
git clone https://github.com/blackboxaicode/cli.git
cd cli && npm install && npm install -g .
```
- API key from [app.blackbox.ai/dashboard](https://app.blackbox.ai/dashboard)
- Configured: run `blackbox configure` and enter your API key
- Use `pty=true` in terminal calls — Blackbox CLI is an interactive terminal app
## One-Shot Tasks
```
terminal(command="blackbox --prompt 'Add JWT authentication with refresh tokens to the Express API'", workdir="/path/to/project", pty=true)
```
For quick scratch work:
```
terminal(command="cd $(mktemp -d) && git init && blackbox --prompt 'Build a REST API for todos with SQLite'", pty=true)
```
## Background Mode (Long Tasks)
For tasks that take minutes, use background mode so you can monitor progress:
```
# Start in background with PTY
terminal(command="blackbox --prompt 'Refactor the auth module to use OAuth 2.0'", workdir="~/project", background=true, pty=true)
# Returns session_id
# Monitor progress
process(action="poll", session_id="<id>")
process(action="log", session_id="<id>")
# Send input if Blackbox asks a question
process(action="submit", session_id="<id>", data="yes")
# Kill if needed
process(action="kill", session_id="<id>")
```
## Checkpoints & Resume
Blackbox CLI has built-in checkpoint support for pausing and resuming tasks:
```
# After a task completes, Blackbox shows a checkpoint tag
# Resume with a follow-up task:
terminal(command="blackbox --resume-checkpoint 'task-abc123-2026-03-06' --prompt 'Now add rate limiting to the endpoints'", workdir="~/project", pty=true)
```
## Session Commands
During an interactive session, use these commands:
| Command | Effect |
|---------|--------|
| `/compress` | Shrink conversation history to save tokens |
| `/clear` | Wipe history and start fresh |
| `/stats` | View current token usage |
| `Ctrl+C` | Cancel current operation |
## PR Reviews
Clone to a temp directory to avoid modifying the working tree:
```
terminal(command="REVIEW=$(mktemp -d) && git clone https://github.com/user/repo.git $REVIEW && cd $REVIEW && gh pr checkout 42 && blackbox --prompt 'Review this PR against main. Check for bugs, security issues, and code quality.'", pty=true)
```
## Parallel Work
Spawn multiple Blackbox instances for independent tasks:
```
terminal(command="blackbox --prompt 'Fix the login bug'", workdir="/tmp/issue-1", background=true, pty=true)
terminal(command="blackbox --prompt 'Add unit tests for auth'", workdir="/tmp/issue-2", background=true, pty=true)
# Monitor all
process(action="list")
```
## Multi-Model Mode
Blackbox's unique feature is running the same task through multiple models and judging the results. Configure which models to use via `blackbox configure` — select multiple providers to enable the Chairman/judge workflow where the CLI evaluates outputs from different models and picks the best one.
## Key Flags
| Flag | Effect |
|------|--------|
| `--prompt "task"` | Non-interactive one-shot execution |
| `--resume-checkpoint "tag"` | Resume from a saved checkpoint |
| `--yolo` | Auto-approve all actions and model switches |
| `blackbox session` | Start interactive chat session |
| `blackbox configure` | Change settings, providers, models |
| `blackbox info` | Display system information |
## Vision Support
Blackbox automatically detects images in input and can switch to multimodal analysis. VLM modes:
- `"once"` — Switch model for current query only
- `"session"` — Switch for entire session
- `"persist"` — Stay on current model (no switch)
## Token Limits
Control token usage via `.blackboxcli/settings.json`:
```json
{
"sessionTokenLimit": 32000
}
```
## Rules
1. **Always use `pty=true`** — Blackbox CLI is an interactive terminal app and will hang without a PTY
2. **Use `workdir`** — keep the agent focused on the right directory
3. **Background for long tasks** — use `background=true` and monitor with `process` tool
4. **Don't interfere** — monitor with `poll`/`log`, don't kill sessions because they're slow
5. **Report results** — after completion, check what changed and summarize for the user
6. **Credits cost money** — Blackbox uses a credit-based system; multi-model mode consumes credits faster
7. **Check prerequisites** — verify `blackbox` CLI is installed before attempting delegation

View File

@@ -0,0 +1,441 @@
---
name: qmd
description: Search personal knowledge bases, notes, docs, and meeting transcripts locally using qmd — a hybrid retrieval engine with BM25, vector search, and LLM reranking. Supports CLI and MCP integration.
version: 1.0.0
author: Hermes Agent + Teknium
license: MIT
platforms: [macos, linux]
metadata:
hermes:
tags: [Search, Knowledge-Base, RAG, Notes, MCP, Local-AI]
related_skills: [obsidian, native-mcp, arxiv]
---
# QMD — Query Markup Documents
Local, on-device search engine for personal knowledge bases. Indexes markdown
notes, meeting transcripts, documentation, and any text-based files, then
provides hybrid search combining keyword matching, semantic understanding, and
LLM-powered reranking — all running locally with no cloud dependencies.
Created by [Tobi Lütke](https://github.com/tobi/qmd). MIT licensed.
## When to Use
- User asks to search their notes, docs, knowledge base, or meeting transcripts
- User wants to find something across a large collection of markdown/text files
- User wants semantic search ("find notes about X concept") not just keyword grep
- User has already set up qmd collections and wants to query them
- User asks to set up a local knowledge base or document search system
- Keywords: "search my notes", "find in my docs", "knowledge base", "qmd"
## Prerequisites
### Node.js >= 22 (required)
```bash
# Check version
node --version # must be >= 22
# macOS — install or upgrade via Homebrew
brew install node@22
# Linux — use NodeSource or nvm
curl -fsSL https://deb.nodesource.com/setup_22.x | sudo -E bash -
sudo apt-get install -y nodejs
# or with nvm:
nvm install 22 && nvm use 22
```
### SQLite with Extension Support (macOS only)
macOS system SQLite lacks extension loading. Install via Homebrew:
```bash
brew install sqlite
```
### Install qmd
```bash
npm install -g @tobilu/qmd
# or with Bun:
bun install -g @tobilu/qmd
```
First run auto-downloads 3 local GGUF models (~2GB total):
| Model | Purpose | Size |
|-------|---------|------|
| embeddinggemma-300M-Q8_0 | Vector embeddings | ~300MB |
| qwen3-reranker-0.6b-q8_0 | Result reranking | ~640MB |
| qmd-query-expansion-1.7B | Query expansion | ~1.1GB |
### Verify Installation
```bash
qmd --version
qmd status
```
## Quick Reference
| Command | What It Does | Speed |
|---------|-------------|-------|
| `qmd search "query"` | BM25 keyword search (no models) | ~0.2s |
| `qmd vsearch "query"` | Semantic vector search (1 model) | ~3s |
| `qmd query "query"` | Hybrid + reranking (all 3 models) | ~2-3s warm, ~19s cold |
| `qmd get <docid>` | Retrieve full document content | instant |
| `qmd multi-get "glob"` | Retrieve multiple files | instant |
| `qmd collection add <path> --name <n>` | Add a directory as a collection | instant |
| `qmd context add <path> "description"` | Add context metadata to improve retrieval | instant |
| `qmd embed` | Generate/update vector embeddings | varies |
| `qmd status` | Show index health and collection info | instant |
| `qmd mcp` | Start MCP server (stdio) | persistent |
| `qmd mcp --http --daemon` | Start MCP server (HTTP, warm models) | persistent |
## Setup Workflow
### 1. Add Collections
Point qmd at directories containing your documents:
```bash
# Add a notes directory
qmd collection add ~/notes --name notes
# Add project docs
qmd collection add ~/projects/myproject/docs --name project-docs
# Add meeting transcripts
qmd collection add ~/meetings --name meetings
# List all collections
qmd collection list
```
### 2. Add Context Descriptions
Context metadata helps the search engine understand what each collection
contains. This significantly improves retrieval quality:
```bash
qmd context add qmd://notes "Personal notes, ideas, and journal entries"
qmd context add qmd://project-docs "Technical documentation for the main project"
qmd context add qmd://meetings "Meeting transcripts and action items from team syncs"
```
### 3. Generate Embeddings
```bash
qmd embed
```
This processes all documents in all collections and generates vector
embeddings. Re-run after adding new documents or collections.
### 4. Verify
```bash
qmd status # shows index health, collection stats, model info
```
## Search Patterns
### Fast Keyword Search (BM25)
Best for: exact terms, code identifiers, names, known phrases.
No models loaded — near-instant results.
```bash
qmd search "authentication middleware"
qmd search "handleError async"
```
### Semantic Vector Search
Best for: natural language questions, conceptual queries.
Loads embedding model (~3s first query).
```bash
qmd vsearch "how does the rate limiter handle burst traffic"
qmd vsearch "ideas for improving onboarding flow"
```
### Hybrid Search with Reranking (Best Quality)
Best for: important queries where quality matters most.
Uses all 3 models — query expansion, parallel BM25+vector, reranking.
```bash
qmd query "what decisions were made about the database migration"
```
### Structured Multi-Mode Queries
Combine different search types in a single query for precision:
```bash
# BM25 for exact term + vector for concept
qmd query $'lex: rate limiter\nvec: how does throttling work under load'
# With query expansion
qmd query $'expand: database migration plan\nlex: "schema change"'
```
### Query Syntax (lex/BM25 mode)
| Syntax | Effect | Example |
|--------|--------|---------|
| `term` | Prefix match | `perf` matches "performance" |
| `"phrase"` | Exact phrase | `"rate limiter"` |
| `-term` | Exclude term | `performance -sports` |
### HyDE (Hypothetical Document Embeddings)
For complex topics, write what you expect the answer to look like:
```bash
qmd query $'hyde: The migration plan involves three phases. First, we add the new columns without dropping the old ones. Then we backfill data. Finally we cut over and remove legacy columns.'
```
### Scoping to Collections
```bash
qmd search "query" --collection notes
qmd query "query" --collection project-docs
```
### Output Formats
```bash
qmd search "query" --json # JSON output (best for parsing)
qmd search "query" --limit 5 # Limit results
qmd get "#abc123" # Get by document ID
qmd get "path/to/file.md" # Get by file path
qmd get "file.md:50" -l 100 # Get specific line range
qmd multi-get "journals/*.md" --json # Batch retrieve by glob
```
## MCP Integration (Recommended)
qmd exposes an MCP server that provides search tools directly to
Hermes Agent via the native MCP client. This is the preferred
integration — once configured, the agent gets qmd tools automatically
without needing to load this skill.
### Option A: Stdio Mode (Simple)
Add to `~/.hermes/config.yaml`:
```yaml
mcp_servers:
qmd:
command: "qmd"
args: ["mcp"]
timeout: 30
connect_timeout: 45
```
This registers tools: `mcp_qmd_search`, `mcp_qmd_vsearch`,
`mcp_qmd_deep_search`, `mcp_qmd_get`, `mcp_qmd_status`.
**Tradeoff:** Models load on first search call (~19s cold start),
then stay warm for the session. Acceptable for occasional use.
### Option B: HTTP Daemon Mode (Fast, Recommended for Heavy Use)
Start the qmd daemon separately — it keeps models warm in memory:
```bash
# Start daemon (persists across agent restarts)
qmd mcp --http --daemon
# Runs on http://localhost:8181 by default
```
Then configure Hermes Agent to connect via HTTP:
```yaml
mcp_servers:
qmd:
url: "http://localhost:8181/mcp"
timeout: 30
```
**Tradeoff:** Uses ~2GB RAM while running, but every query is fast
(~2-3s). Best for users who search frequently.
### Keeping the Daemon Running
#### macOS (launchd)
```bash
cat > ~/Library/LaunchAgents/com.qmd.daemon.plist << 'EOF'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.qmd.daemon</string>
<key>ProgramArguments</key>
<array>
<string>qmd</string>
<string>mcp</string>
<string>--http</string>
<string>--daemon</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>/tmp/qmd-daemon.log</string>
<key>StandardErrorPath</key>
<string>/tmp/qmd-daemon.log</string>
</dict>
</plist>
EOF
launchctl load ~/Library/LaunchAgents/com.qmd.daemon.plist
```
#### Linux (systemd user service)
```bash
mkdir -p ~/.config/systemd/user
cat > ~/.config/systemd/user/qmd-daemon.service << 'EOF'
[Unit]
Description=QMD MCP Daemon
After=network.target
[Service]
ExecStart=qmd mcp --http --daemon
Restart=on-failure
RestartSec=10
Environment=PATH=/usr/local/bin:/usr/bin:/bin
[Install]
WantedBy=default.target
EOF
systemctl --user daemon-reload
systemctl --user enable --now qmd-daemon
systemctl --user status qmd-daemon
```
### MCP Tools Reference
Once connected, these tools are available as `mcp_qmd_*`:
| MCP Tool | Maps To | Description |
|----------|---------|-------------|
| `mcp_qmd_search` | `qmd search` | BM25 keyword search |
| `mcp_qmd_vsearch` | `qmd vsearch` | Semantic vector search |
| `mcp_qmd_deep_search` | `qmd query` | Hybrid search + reranking |
| `mcp_qmd_get` | `qmd get` | Retrieve document by ID or path |
| `mcp_qmd_status` | `qmd status` | Index health and stats |
The MCP tools accept structured JSON queries for multi-mode search:
```json
{
"searches": [
{"type": "lex", "query": "authentication middleware"},
{"type": "vec", "query": "how user login is verified"}
],
"collections": ["project-docs"],
"limit": 10
}
```
## CLI Usage (Without MCP)
When MCP is not configured, use qmd directly via terminal:
```
terminal(command="qmd query 'what was decided about the API redesign' --json", timeout=30)
```
For setup and management tasks, always use terminal:
```
terminal(command="qmd collection add ~/Documents/notes --name notes")
terminal(command="qmd context add qmd://notes 'Personal research notes and ideas'")
terminal(command="qmd embed")
terminal(command="qmd status")
```
## How the Search Pipeline Works
Understanding the internals helps choose the right search mode:
1. **Query Expansion** — A fine-tuned 1.7B model generates 2 alternative
queries. The original gets 2x weight in fusion.
2. **Parallel Retrieval** — BM25 (SQLite FTS5) and vector search run
simultaneously across all query variants.
3. **RRF Fusion** — Reciprocal Rank Fusion (k=60) merges results.
Top-rank bonus: #1 gets +0.05, #2-3 get +0.02.
4. **LLM Reranking** — qwen3-reranker scores top 30 candidates (0.0-1.0).
5. **Position-Aware Blending** — Ranks 1-3: 75% retrieval / 25% reranker.
Ranks 4-10: 60/40. Ranks 11+: 40/60 (trusts reranker more for long tail).
**Smart Chunking:** Documents are split at natural break points (headings,
code blocks, blank lines) targeting ~900 tokens with 15% overlap. Code
blocks are never split mid-block.
## Best Practices
1. **Always add context descriptions**`qmd context add` dramatically
improves retrieval accuracy. Describe what each collection contains.
2. **Re-embed after adding documents**`qmd embed` must be re-run when
new files are added to collections.
3. **Use `qmd search` for speed** — when you need fast keyword lookup
(code identifiers, exact names), BM25 is instant and needs no models.
4. **Use `qmd query` for quality** — when the question is conceptual or
the user needs the best possible results, use hybrid search.
5. **Prefer MCP integration** — once configured, the agent gets native
tools without needing to load this skill each time.
6. **Daemon mode for frequent users** — if the user searches their
knowledge base regularly, recommend the HTTP daemon setup.
7. **First query in structured search gets 2x weight** — put the most
important/certain query first when combining lex and vec.
## Troubleshooting
### "Models downloading on first run"
Normal — qmd auto-downloads ~2GB of GGUF models on first use.
This is a one-time operation.
### Cold start latency (~19s)
This happens when models aren't loaded in memory. Solutions:
- Use HTTP daemon mode (`qmd mcp --http --daemon`) to keep warm
- Use `qmd search` (BM25 only) when models aren't needed
- MCP stdio mode loads models on first search, stays warm for session
### macOS: "unable to load extension"
Install Homebrew SQLite: `brew install sqlite`
Then ensure it's on PATH before system SQLite.
### "No collections found"
Run `qmd collection add <path> --name <name>` to add directories,
then `qmd embed` to index them.
### Embedding model override (CJK/multilingual)
Set `QMD_EMBED_MODEL` environment variable for non-English content:
```bash
export QMD_EMBED_MODEL="your-multilingual-model"
```
## Data Storage
- **Index & vectors:** `~/.cache/qmd/index.sqlite`
- **Models:** Auto-downloaded to local cache on first run
- **No cloud dependencies** — everything runs locally
## References
- [GitHub: tobi/qmd](https://github.com/tobi/qmd)
- [QMD Changelog](https://github.com/tobi/qmd/blob/main/CHANGELOG.md)

View File

@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
[project]
name = "hermes-agent"
version = "0.1.0"
description = "AI agent with advanced tool-calling and toolsets"
description = "The self-improving AI agent — creates skills from experience, improves them during use, and runs anywhere"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [{ name = "Nous Research" }]
license = { text = "MIT" }
dependencies = [
@@ -39,6 +39,7 @@ dependencies = [
[project.optional-dependencies]
modal = ["swe-rex[modal]>=1.4.0"]
daytona = ["daytona>=0.148.0"]
dev = ["pytest", "pytest-asyncio"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
@@ -49,8 +50,10 @@ pty = ["ptyprocess>=0.7.0"]
honcho = ["honcho-ai>=2.0.1"]
mcp = ["mcp>=1.2.0"]
homeassistant = ["aiohttp>=3.9.0"]
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git"]
all = [
"hermes-agent[modal]",
"hermes-agent[daytona]",
"hermes-agent[messaging]",
"hermes-agent[cron]",
"hermes-agent[cli]",

View File

@@ -82,6 +82,8 @@ from agent.prompt_builder import (
from agent.model_metadata import (
fetch_model_metadata, get_model_context_length,
estimate_tokens_rough, estimate_messages_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error,
save_context_length,
)
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
@@ -97,6 +99,46 @@ from agent.trajectory import (
)
class IterationBudget:
"""Thread-safe shared iteration counter for parent and child agents.
Tracks total LLM-call iterations consumed across a parent agent and all
its subagents. A single ``IterationBudget`` is created by the parent
and passed to every child so they share the same cap.
``execute_code`` (programmatic tool calling) iterations are refunded via
:meth:`refund` so they don't eat into the budget.
"""
def __init__(self, max_total: int):
self.max_total = max_total
self._used = 0
self._lock = threading.Lock()
def consume(self) -> bool:
"""Try to consume one iteration. Returns True if allowed."""
with self._lock:
if self._used >= self.max_total:
return False
self._used += 1
return True
def refund(self) -> None:
"""Give back one iteration (e.g. for execute_code turns)."""
with self._lock:
if self._used > 0:
self._used -= 1
@property
def used(self) -> int:
return self._used
@property
def remaining(self) -> int:
with self._lock:
return max(0, self.max_total - self._used)
class AIAgent:
"""
AI Agent with tool calling capabilities.
@@ -112,7 +154,7 @@ class AIAgent:
provider: str = None,
api_mode: str = None,
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
max_iterations: int = 60, # Default tool-calling iterations
max_iterations: int = 90, # Default tool-calling iterations (shared with subagents)
tool_delay: float = 1.0,
enabled_toolsets: List[str] = None,
disabled_toolsets: List[str] = None,
@@ -140,6 +182,7 @@ class AIAgent:
skip_memory: bool = False,
session_db=None,
honcho_session_key: str = None,
iteration_budget: "IterationBudget" = None,
):
"""
Initialize the AI Agent.
@@ -150,7 +193,7 @@ class AIAgent:
provider (str): Provider identifier (optional; used for telemetry/routing hints)
api_mode (str): API mode override: "chat_completions" or "codex_responses"
model (str): Model name to use (default: "anthropic/claude-opus-4.6")
max_iterations (int): Maximum number of tool calling iterations (default: 60)
max_iterations (int): Maximum number of tool calling iterations (default: 90)
tool_delay (float): Delay between tool calls in seconds (default: 1.0)
enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
@@ -170,7 +213,7 @@ class AIAgent:
Provided by the platform layer (CLI or gateway). If None, the clarify tool returns an error.
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
If None, defaults to {"enabled": True, "effort": "medium"} for OpenRouter. Set to disable/customize reasoning.
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
Useful for injecting a few-shot example or priming the model's response style.
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
@@ -184,6 +227,9 @@ class AIAgent:
"""
self.model = model
self.max_iterations = max_iterations
# Shared iteration budget — parent creates, children inherit.
# Consumed by every LLM turn across parent + all subagents.
self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
self.tool_delay = tool_delay
self.save_trajectories = save_trajectories
self.verbose_logging = verbose_logging
@@ -207,13 +253,7 @@ class AIAgent:
self.provider = "openai-codex"
else:
self.api_mode = "chat_completions"
if base_url and "api.anthropic.com" in base_url.strip().lower():
raise ValueError(
"Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
"Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
"To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
"or any OpenAI-compatible proxy that wraps the Anthropic API."
)
self.tool_progress_callback = tool_progress_callback
self.clarify_callback = clarify_callback
self.step_callback = step_callback
@@ -241,7 +281,7 @@ class AIAgent:
# Model response configuration
self.max_tokens = max_tokens # None = use model default
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
self.reasoning_config = reasoning_config # None = use default (medium for OpenRouter)
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
@@ -343,6 +383,12 @@ class AIAgent:
"X-OpenRouter-Title": "Hermes Agent",
"X-OpenRouter-Categories": "productivity,cli-agent",
}
elif "api.kimi.com" in effective_base.lower():
# Kimi Code API requires a recognized coding-agent User-Agent
# (see https://github.com/MoonshotAI/kimi-cli)
client_kwargs["default_headers"] = {
"User-Agent": "KimiCLI/1.0",
}
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
try:
@@ -536,6 +582,7 @@ class AIAgent:
summary_target_tokens=500,
summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode,
base_url=self.base_url,
)
self.compression_enabled = compression_enabled
self._user_turn_count = 0
@@ -1360,7 +1407,8 @@ class AIAgent:
if context_files_prompt:
prompt_parts.append(context_files_prompt)
now = datetime.now()
from hermes_time import now as _hermes_now
now = _hermes_now()
prompt_parts.append(
f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
)
@@ -2015,6 +2063,49 @@ class AIAgent:
return True
def _try_refresh_nous_client_credentials(self, *, force: bool = True) -> bool:
if self.api_mode != "chat_completions" or self.provider != "nous":
return False
try:
from hermes_cli.auth import resolve_nous_runtime_credentials
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
force_mint=force,
)
except Exception as exc:
logger.debug("Nous credential refresh failed: %s", exc)
return False
api_key = creds.get("api_key")
base_url = creds.get("base_url")
if not isinstance(api_key, str) or not api_key.strip():
return False
if not isinstance(base_url, str) or not base_url.strip():
return False
self.api_key = api_key.strip()
self.base_url = base_url.strip().rstrip("/")
self._client_kwargs["api_key"] = self.api_key
self._client_kwargs["base_url"] = self.base_url
# Nous requests should not inherit OpenRouter-only attribution headers.
self._client_kwargs.pop("default_headers", None)
try:
self.client.close()
except Exception:
pass
try:
self.client = OpenAI(**self._client_kwargs)
except Exception as exc:
logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
return False
return True
def _interruptible_api_call(self, api_kwargs: dict):
"""
Run the API call in a background thread so the main conversation loop
@@ -2066,8 +2157,8 @@ class AIAgent:
if not instructions:
instructions = DEFAULT_AGENT_IDENTITY
# Resolve reasoning effort: config > default (xhigh)
reasoning_effort = "xhigh"
# Resolve reasoning effort: config > default (medium)
reasoning_effort = "medium"
reasoning_enabled = True
if self.reasoning_config and isinstance(self.reasoning_config, dict):
if self.reasoning_config.get("enabled") is False:
@@ -2133,7 +2224,7 @@ class AIAgent:
else:
extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
"effort": "medium"
}
# Nous Portal product attribution
@@ -2393,6 +2484,8 @@ class AIAgent:
if self._session_db:
try:
# Propagate title to the new session with auto-numbering
old_title = self._session_db.get_session_title(self.session_id)
self._session_db.end_session(self.session_id, "compression")
old_session_id = self.session_id
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
@@ -2402,6 +2495,13 @@ class AIAgent:
model=self.model,
parent_session_id=old_session_id,
)
# Auto-number the title for the continuation session
if old_title:
try:
new_title = self._session_db.get_next_title_in_lineage(old_title)
self._session_db.set_session_title(self.session_id, new_title)
except (ValueError, Exception) as e:
logger.debug("Could not propagate title on compression: %s", e)
self._session_db.update_system_prompt(self.session_id, new_system_prompt)
except Exception as e:
logger.debug("Session DB compression split failed: %s", e)
@@ -2528,7 +2628,6 @@ class AIAgent:
context=function_args.get("context"),
toolsets=function_args.get("toolsets"),
tasks=tasks_arg,
model=function_args.get("model"),
max_iterations=function_args.get("max_iterations"),
parent_agent=self,
)
@@ -2677,7 +2776,7 @@ class AIAgent:
else:
summary_extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
"effort": "medium"
}
if _is_nous:
summary_extra_body["tags"] = ["product=hermes-agent"]
@@ -2740,7 +2839,7 @@ class AIAgent:
"messages": api_messages,
}
if self.max_tokens is not None:
summary_kwargs["max_tokens"] = self.max_tokens
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
if summary_extra_body:
summary_kwargs["extra_body"] = summary_extra_body
@@ -2789,13 +2888,15 @@ class AIAgent:
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
effective_task_id = task_id or str(uuid.uuid4())
# Reset retry counters at the start of each conversation to prevent state leakage
# Reset retry counters and iteration budget at the start of each turn
# so subagent usage from a previous turn doesn't eat into the next one.
self._invalid_tool_retries = 0
self._invalid_json_retries = 0
self._empty_content_retries = 0
self._last_content_with_tools = None
self._turns_since_memory = 0
self._iters_since_skill = 0
self.iteration_budget = IterationBudget(self.max_iterations)
# Initialize conversation (copy to avoid mutating the caller's list)
messages = list(conversation_history) if conversation_history else []
@@ -2927,7 +3028,7 @@ class AIAgent:
# Clear any stale interrupt state at start
self.clear_interrupt()
while api_call_count < self.max_iterations:
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
# Check for interrupt request (e.g., user sent new message)
if self._interrupt_requested:
interrupted = True
@@ -2936,6 +3037,10 @@ class AIAgent:
break
api_call_count += 1
if not self.iteration_budget.consume():
if not self.quiet_mode:
print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)")
break
# Fire step_callback for gateway hooks (agent:step event)
if self.step_callback is not None:
@@ -3012,6 +3117,13 @@ class AIAgent:
if self._use_prompt_caching:
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
# Safety net: strip orphaned tool results / add stubs for missing
# results before sending to the API. The compressor handles this
# during compression, but orphans can also sneak in from session
# loading or manual message manipulation.
if hasattr(self, 'context_compressor') and self.context_compressor:
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
# Calculate approximate request size for logging
total_chars = sum(len(str(msg)) for msg in api_messages)
approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token
@@ -3040,9 +3152,13 @@ class AIAgent:
api_start_time = time.time()
retry_count = 0
max_retries = 6 # Increased to allow longer backoff periods
compression_attempts = 0
max_compression_attempts = 3
codex_auth_retry_attempted = False
nous_auth_retry_attempted = False
finish_reason = "stop"
response = None # Guard against UnboundLocalError if all retries fail
while retry_count < max_retries:
try:
@@ -3236,6 +3352,13 @@ class AIAgent:
}
self.context_compressor.update_from_response(usage_dict)
# Cache discovered context length after successful call
if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length
save_context_length(self.model, self.base_url, ctx)
print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
self.context_compressor._context_probed = False
self.session_prompt_tokens += prompt_tokens
self.session_completion_tokens += completion_tokens
self.session_total_tokens += total_tokens
@@ -3283,6 +3406,16 @@ class AIAgent:
if self._try_refresh_codex_client_credentials(force=True):
print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
continue
if (
self.api_mode == "chat_completions"
and self.provider == "nous"
and status_code == 401
and not nous_auth_retry_attempted
):
nous_auth_retry_attempted = True
if self._try_refresh_nous_client_credentials(force=True):
print(f"{self.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
continue
retry_count += 1
elapsed_time = time.time() - api_start_time
@@ -3321,7 +3454,19 @@ class AIAgent:
)
if is_payload_too_large:
print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
compression_attempts += 1
if compression_attempts > max_compression_attempts:
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.")
logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.")
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
"partial": True
}
print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
@@ -3330,6 +3475,7 @@ class AIAgent:
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
time.sleep(2) # Brief pause between compression retries
continue # Retry with compressed messages
else:
print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
@@ -3355,18 +3501,52 @@ class AIAgent:
])
if is_context_length_error:
print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...")
compressor = self.context_compressor
old_ctx = compressor.context_length
# Try to parse the actual limit from the error message
parsed_limit = parse_context_limit_from_error(error_msg)
if parsed_limit and parsed_limit < old_ctx:
new_ctx = parsed_limit
print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})")
else:
# Step down to the next probe tier
new_ctx = get_next_probe_tier(old_ctx)
if new_ctx and new_ctx < old_ctx:
compressor.context_length = new_ctx
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
compressor._context_probed = True
print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens")
else:
print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
compression_attempts += 1
if compression_attempts > max_compression_attempts:
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.")
logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
"api_calls": api_call_count,
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
"partial": True
}
print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
messages, system_message, approx_tokens=approx_tokens
)
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
time.sleep(2) # Brief pause between compression retries
continue # Retry with compressed messages or new tier
else:
# Can't compress further
# Can't compress further and already at minimum tier
print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")
@@ -3442,6 +3622,14 @@ class AIAgent:
if interrupted:
break
# Guard: if all retries exhausted without a successful response
# (e.g. repeated context-length errors that exhausted retry_count),
# the `response` variable is still None. Break out cleanly.
if response is None:
print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
self._persist_session(messages, conversation_history)
break
try:
if self.api_mode == "codex_responses":
assistant_message, finish_reason = self._normalize_codex_response(response)
@@ -3658,6 +3846,13 @@ class AIAgent:
self._log_msg_to_db(assistant_msg)
self._execute_tool_calls(assistant_message, messages, effective_task_id)
# Refund the iteration if the ONLY tool(s) called were
# execute_code (programmatic tool calling). These are
# cheap RPC-style calls that shouldn't eat the budget.
_tc_names = {tc.function.name for tc in assistant_message.tool_calls}
if _tc_names == {"execute_code"}:
self.iteration_budget.refund()
if self.compression_enabled and self.context_compressor.should_compress():
messages, active_system_prompt = self._compress_context(
@@ -3678,13 +3873,33 @@ class AIAgent:
# Check if response only has think block with no actual content after it
if not self._has_content_after_think_block(final_response):
# Track retries for empty-after-think responses
# If the previous turn already delivered real content alongside
# tool calls (e.g. "You're welcome!" + memory save), the model
# has nothing more to say. Use the earlier content immediately
# instead of wasting API calls on retries that won't help.
fallback = getattr(self, '_last_content_with_tools', None)
if fallback:
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
self._last_content_with_tools = None
self._empty_content_retries = 0
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_names = []
for tc in msg["tool_calls"]:
fn = tc.get("function", {})
tool_names.append(fn.get("name", "unknown"))
msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
break
final_response = self._strip_think_blocks(fallback).strip()
break
# No fallback available — this is a genuine empty response.
# Retry in case the model just had a bad generation.
if not hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
self._empty_content_retries += 1
# Show the reasoning/thinking content so the user can see
# what the model was thinking even though content is empty
reasoning_text = self._extract_reasoning(assistant_message)
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
if reasoning_text:
@@ -3840,7 +4055,12 @@ class AIAgent:
final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
break
if api_call_count >= self.max_iterations and final_response is None:
if final_response is None and (
api_call_count >= self.max_iterations
or self.iteration_budget.remaining <= 0
):
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)")
final_response = self._handle_max_iterations(messages, api_call_count)
# Determine if conversation completed successfully
@@ -3911,7 +4131,7 @@ def main(
Args:
query (str): Natural language query for the agent. Defaults to Python 3.13 example.
model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4-20250514.
model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4.6.
api_key (str): API key for authentication. Uses OPENROUTER_API_KEY env var if not provided.
base_url (str): Base URL for the model API. Defaults to https://openrouter.ai/api/v1
max_turns (int): Maximum number of API call iterations. Defaults to 10.

View File

@@ -829,6 +829,33 @@ install_node_deps() {
log_warn "npm install failed (browser tools may not work)"
}
log_success "Node.js dependencies installed"
# Install Playwright browser + system dependencies.
# Playwright's install-deps only supports apt/dnf/zypper natively.
# For Arch/Manjaro we install the system libs via pacman first.
log_info "Installing browser engine (Playwright Chromium)..."
case "$DISTRO" in
arch|manjaro)
if command -v pacman &> /dev/null; then
log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
sudo NEEDRESTART_MODE=a pacman -S --noconfirm --needed \
nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
elif [ "$(id -u)" -eq 0 ]; then
pacman -S --noconfirm --needed \
nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
else
log_warn "Cannot install browser deps without sudo. Run manually:"
log_warn " sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
fi
fi
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
;;
*)
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
;;
esac
log_success "Browser engine installed"
fi
# Install WhatsApp bridge dependencies

View File

@@ -0,0 +1,3 @@
---
description: Apple/macOS-specific skills — iMessage, Reminders, Notes, FindMy, and macOS automation. These skills only load on macOS systems.
---

View File

@@ -0,0 +1,88 @@
---
name: apple-notes
description: Manage Apple Notes via the memo CLI on macOS (create, view, search, edit).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [Notes, Apple, macOS, note-taking]
related_skills: [obsidian]
---
# Apple Notes
Use `memo` to manage Apple Notes directly from the terminal. Notes sync across all Apple devices via iCloud.
## Prerequisites
- **macOS** with Notes.app
- Install: `brew tap antoniorodr/memo && brew install antoniorodr/memo/memo`
- Grant Automation access to Notes.app when prompted (System Settings → Privacy → Automation)
## When to Use
- User asks to create, view, or search Apple Notes
- Saving information to Notes.app for cross-device access
- Organizing notes into folders
- Exporting notes to Markdown/HTML
## When NOT to Use
- Obsidian vault management → use the `obsidian` skill
- Bear Notes → separate app (not supported here)
- Quick agent-only notes → use the `memory` tool instead
## Quick Reference
### View Notes
```bash
memo notes # List all notes
memo notes -f "Folder Name" # Filter by folder
memo notes -s "query" # Search notes (fuzzy)
```
### Create Notes
```bash
memo notes -a # Interactive editor
memo notes -a "Note Title" # Quick add with title
```
### Edit Notes
```bash
memo notes -e # Interactive selection to edit
```
### Delete Notes
```bash
memo notes -d # Interactive selection to delete
```
### Move Notes
```bash
memo notes -m # Move note to folder (interactive)
```
### Export Notes
```bash
memo notes -ex # Export to HTML/Markdown
```
## Limitations
- Cannot edit notes containing images or attachments
- Interactive prompts require terminal access (use pty=true if needed)
- macOS only — requires Apple Notes.app
## Rules
1. Prefer Apple Notes when user wants cross-device sync (iPhone/iPad/Mac)
2. Use the `memory` tool for agent-internal notes that don't need to sync
3. Use the `obsidian` skill for Markdown-native knowledge management

View File

@@ -0,0 +1,96 @@
---
name: apple-reminders
description: Manage Apple Reminders via remindctl CLI (list, add, complete, delete).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [Reminders, tasks, todo, macOS, Apple]
---
# Apple Reminders
Use `remindctl` to manage Apple Reminders directly from the terminal. Tasks sync across all Apple devices via iCloud.
## Prerequisites
- **macOS** with Reminders.app
- Install: `brew install steipete/tap/remindctl`
- Grant Reminders permission when prompted
- Check: `remindctl status` / Request: `remindctl authorize`
## When to Use
- User mentions "reminder" or "Reminders app"
- Creating personal to-dos with due dates that sync to iOS
- Managing Apple Reminders lists
- User wants tasks to appear on their iPhone/iPad
## When NOT to Use
- Scheduling agent alerts → use the cronjob tool instead
- Calendar events → use Apple Calendar or Google Calendar
- Project task management → use GitHub Issues, Notion, etc.
- If user says "remind me" but means an agent alert → clarify first
## Quick Reference
### View Reminders
```bash
remindctl # Today's reminders
remindctl today # Today
remindctl tomorrow # Tomorrow
remindctl week # This week
remindctl overdue # Past due
remindctl all # Everything
remindctl 2026-01-04 # Specific date
```
### Manage Lists
```bash
remindctl list # List all lists
remindctl list Work # Show specific list
remindctl list Projects --create # Create list
remindctl list Work --delete # Delete list
```
### Create Reminders
```bash
remindctl add "Buy milk"
remindctl add --title "Call mom" --list Personal --due tomorrow
remindctl add --title "Meeting prep" --due "2026-02-15 09:00"
```
### Complete / Delete
```bash
remindctl complete 1 2 3 # Complete by ID
remindctl delete 4A83 --force # Delete by ID
```
### Output Formats
```bash
remindctl today --json # JSON for scripting
remindctl today --plain # TSV format
remindctl today --quiet # Counts only
```
## Date Formats
Accepted by `--due` and date filters:
- `today`, `tomorrow`, `yesterday`
- `YYYY-MM-DD`
- `YYYY-MM-DD HH:mm`
- ISO 8601 (`2026-01-04T12:34:56Z`)
## Rules
1. When user says "remind me", clarify: Apple Reminders (syncs to phone) vs agent cronjob alert
2. Always confirm reminder content and due date before creating
3. Use `--json` for programmatic parsing

View File

@@ -0,0 +1,131 @@
---
name: findmy
description: Track Apple devices and AirTags via FindMy.app on macOS using AppleScript and screen capture.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [FindMy, AirTag, location, tracking, macOS, Apple]
---
# Find My (Apple)
Track Apple devices and AirTags via the FindMy.app on macOS. Since Apple doesn't
provide a CLI for FindMy, this skill uses AppleScript to open the app and
screen capture to read device locations.
## Prerequisites
- **macOS** with Find My app and iCloud signed in
- Devices/AirTags already registered in Find My
- Screen Recording permission for terminal (System Settings → Privacy → Screen Recording)
- **Optional but recommended**: Install `peekaboo` for better UI automation:
`brew install steipete/tap/peekaboo`
## When to Use
- User asks "where is my [device/cat/keys/bag]?"
- Tracking AirTag locations
- Checking device locations (iPhone, iPad, Mac, AirPods)
- Monitoring pet or item movement over time (AirTag patrol routes)
## Method 1: AppleScript + Screenshot (Basic)
### Open FindMy and Navigate
```bash
# Open Find My app
osascript -e 'tell application "FindMy" to activate'
# Wait for it to load
sleep 3
# Take a screenshot of the Find My window
screencapture -w -o /tmp/findmy.png
```
Then use `vision_analyze` to read the screenshot:
```
vision_analyze(image_url="/tmp/findmy.png", question="What devices/items are shown and what are their locations?")
```
### Switch Between Tabs
```bash
# Switch to Devices tab
osascript -e '
tell application "System Events"
tell process "FindMy"
click button "Devices" of toolbar 1 of window 1
end tell
end tell'
# Switch to Items tab (AirTags)
osascript -e '
tell application "System Events"
tell process "FindMy"
click button "Items" of toolbar 1 of window 1
end tell
end tell'
```
## Method 2: Peekaboo UI Automation (Recommended)
If `peekaboo` is installed, use it for more reliable UI interaction:
```bash
# Open Find My
osascript -e 'tell application "FindMy" to activate'
sleep 3
# Capture and annotate the UI
peekaboo see --app "FindMy" --annotate --path /tmp/findmy-ui.png
# Click on a specific device/item by element ID
peekaboo click --on B3 --app "FindMy"
# Capture the detail view
peekaboo image --app "FindMy" --path /tmp/findmy-detail.png
```
Then analyze with vision:
```
vision_analyze(image_url="/tmp/findmy-detail.png", question="What is the location shown for this device/item? Include address and coordinates if visible.")
```
## Workflow: Track AirTag Location Over Time
For monitoring an AirTag (e.g., tracking a cat's patrol route):
```bash
# 1. Open FindMy to Items tab
osascript -e 'tell application "FindMy" to activate'
sleep 3
# 2. Click on the AirTag item (stay on page — AirTag only updates when page is open)
# 3. Periodically capture location
while true; do
screencapture -w -o /tmp/findmy-$(date +%H%M%S).png
sleep 300 # Every 5 minutes
done
```
Analyze each screenshot with vision to extract coordinates, then compile a route.
## Limitations
- FindMy has **no CLI or API** — must use UI automation
- AirTags only update location while the FindMy page is actively displayed
- Location accuracy depends on nearby Apple devices in the FindMy network
- Screen Recording permission required for screenshots
- AppleScript UI automation may break across macOS versions
## Rules
1. Keep FindMy app in the foreground when tracking AirTags (updates stop when minimized)
2. Use `vision_analyze` to read screenshot content — don't try to parse pixels
3. For ongoing tracking, use a cronjob to periodically capture and log locations
4. Respect privacy — only track devices/items the user owns

View File

@@ -0,0 +1,100 @@
---
name: imessage
description: Send and receive iMessages/SMS via the imsg CLI on macOS.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [iMessage, SMS, messaging, macOS, Apple]
---
# iMessage
Use `imsg` to read and send iMessage/SMS via macOS Messages.app.
## Prerequisites
- **macOS** with Messages.app signed in
- Install: `brew install steipete/tap/imsg`
- Grant Full Disk Access for terminal (System Settings → Privacy → Full Disk Access)
- Grant Automation permission for Messages.app when prompted
## When to Use
- User asks to send an iMessage or text message
- Reading iMessage conversation history
- Checking recent Messages.app chats
- Sending to phone numbers or Apple IDs
## When NOT to Use
- Telegram/Discord/Slack/WhatsApp messages → use the appropriate gateway channel
- Group chat management (adding/removing members) → not supported
- Bulk/mass messaging → always confirm with user first
## Quick Reference
### List Chats
```bash
imsg chats --limit 10 --json
```
### View History
```bash
# By chat ID
imsg history --chat-id 1 --limit 20 --json
# With attachments info
imsg history --chat-id 1 --limit 20 --attachments --json
```
### Send Messages
```bash
# Text only
imsg send --to "+14155551212" --text "Hello!"
# With attachment
imsg send --to "+14155551212" --text "Check this out" --file /path/to/image.jpg
# Force iMessage or SMS
imsg send --to "+14155551212" --text "Hi" --service imessage
imsg send --to "+14155551212" --text "Hi" --service sms
```
### Watch for New Messages
```bash
imsg watch --chat-id 1 --attachments
```
## Service Options
- `--service imessage` — Force iMessage (requires recipient has iMessage)
- `--service sms` — Force SMS (green bubble)
- `--service auto` — Let Messages.app decide (default)
## Rules
1. **Always confirm recipient and message content** before sending
2. **Never send to unknown numbers** without explicit user approval
3. **Verify file paths** exist before attaching
4. **Don't spam** — rate-limit yourself
## Example Workflow
User: "Text mom that I'll be late"
```bash
# 1. Find mom's chat
imsg chats --limit 20 --json | jq '.[] | select(.displayName | contains("Mom"))'
# 2. Confirm with user: "Found Mom at +1555123456. Send 'I'll be late' via iMessage?"
# 3. Send after confirmation
imsg send --to "+1555123456" --text "I'll be late"
```

View File

@@ -0,0 +1,76 @@
---
name: polymarket
description: Query Polymarket prediction market data — search markets, get prices, orderbooks, and price history. Read-only via public REST APIs, no API key needed.
version: 1.0.0
author: Hermes Agent + Teknium
tags: [polymarket, prediction-markets, market-data, trading]
---
# Polymarket — Prediction Market Data
Query prediction market data from Polymarket using their public REST APIs.
All endpoints are read-only and require zero authentication.
See `references/api-endpoints.md` for the full endpoint reference with curl examples.
## When to Use
- User asks about prediction markets, betting odds, or event probabilities
- User wants to know "what are the odds of X happening?"
- User asks about Polymarket specifically
- User wants market prices, orderbook data, or price history
- User asks to monitor or track prediction market movements
## Key Concepts
- **Events** contain one or more **Markets** (1:many relationship)
- **Markets** are binary outcomes with Yes/No prices between 0.00 and 1.00
- Prices ARE probabilities: price 0.65 means the market thinks 65% likely
- `outcomePrices` field: JSON-encoded array like `["0.80", "0.20"]`
- `clobTokenIds` field: JSON-encoded array of two token IDs [Yes, No] for price/book queries
- `conditionId` field: hex string used for price history queries
- Volume is in USDC (US dollars)
## Three Public APIs
1. **Gamma API** at `gamma-api.polymarket.com` — Discovery, search, browsing
2. **CLOB API** at `clob.polymarket.com` — Real-time prices, orderbooks, history
3. **Data API** at `data-api.polymarket.com` — Trades, open interest
## Typical Workflow
When a user asks about prediction market odds:
1. **Search** using the Gamma API public-search endpoint with their query
2. **Parse** the response — extract events and their nested markets
3. **Present** market question, current prices as percentages, and volume
4. **Deep dive** if asked — use clobTokenIds for orderbook, conditionId for history
## Presenting Results
Format prices as percentages for readability:
- outcomePrices `["0.652", "0.348"]` becomes "Yes: 65.2%, No: 34.8%"
- Always show the market question and probability
- Include volume when available
Example: `"Will X happen?" — 65.2% Yes ($1.2M volume)`
## Parsing Double-Encoded Fields
The Gamma API returns `outcomePrices`, `outcomes`, and `clobTokenIds` as JSON strings
inside JSON responses (double-encoded). When processing with Python, parse them with
`json.loads(market['outcomePrices'])` to get the actual array.
## Rate Limits
Generous — unlikely to hit for normal usage:
- Gamma: 4,000 requests per 10 seconds (general)
- CLOB: 9,000 requests per 10 seconds (general)
- Data: 1,000 requests per 10 seconds (general)
## Limitations
- This skill is read-only — it does not support placing trades
- Trading requires wallet-based crypto authentication (EIP-712 signatures)
- Some new markets may have empty price history
- Geographic restrictions apply to trading but read-only data is globally accessible

View File

@@ -0,0 +1,220 @@
# Polymarket API Endpoints Reference
All endpoints are public REST (GET), return JSON, and need no authentication.
## Gamma API — gamma-api.polymarket.com
### Search Markets
```
GET /public-search?q=QUERY
```
Response structure:
```json
{
"events": [
{
"id": "12345",
"title": "Event title",
"slug": "event-slug",
"volume": 1234567.89,
"markets": [
{
"question": "Will X happen?",
"outcomePrices": "[\"0.65\", \"0.35\"]",
"outcomes": "[\"Yes\", \"No\"]",
"clobTokenIds": "[\"TOKEN_YES\", \"TOKEN_NO\"]",
"conditionId": "0xabc...",
"volume": 500000
}
]
}
],
"pagination": {"hasMore": true, "totalResults": 100}
}
```
### List Events
```
GET /events?limit=N&active=true&closed=false&order=volume&ascending=false
```
Parameters:
- `limit` — max results (default varies)
- `offset` — pagination offset
- `active` — true/false
- `closed` — true/false
- `order` — sort field: `volume`, `createdAt`, `updatedAt`
- `ascending` — true/false
- `tag` — filter by tag slug
- `slug` — get specific event by slug
Response: array of event objects. Each event includes a `markets` array.
Event fields: `id`, `title`, `slug`, `description`, `volume`, `liquidity`,
`openInterest`, `active`, `closed`, `category`, `startDate`, `endDate`,
`markets` (array of market objects).
### List Markets
```
GET /markets?limit=N&active=true&closed=false&order=volume&ascending=false
```
Same filter parameters as events, plus:
- `slug` — get specific market by slug
Market fields: `id`, `question`, `conditionId`, `slug`, `description`,
`outcomes`, `outcomePrices`, `volume`, `liquidity`, `active`, `closed`,
`marketType`, `clobTokenIds`, `endDate`, `category`, `createdAt`.
Important: `outcomePrices`, `outcomes`, and `clobTokenIds` are JSON strings
(double-encoded). Parse with json.loads() in Python.
### List Tags
```
GET /tags
```
Returns array of tag objects: `id`, `label`, `slug`.
Use the `slug` value when filtering events/markets by tag.
---
## CLOB API — clob.polymarket.com
All CLOB price endpoints use `token_id` from the market's `clobTokenIds` field.
Index 0 = Yes outcome, Index 1 = No outcome.
### Current Price
```
GET /price?token_id=TOKEN_ID&side=buy
```
Response: `{"price": "0.650"}`
The `side` parameter: `buy` or `sell`.
### Midpoint Price
```
GET /midpoint?token_id=TOKEN_ID
```
Response: `{"mid": "0.645"}`
### Spread
```
GET /spread?token_id=TOKEN_ID
```
Response: `{"spread": "0.02"}`
### Orderbook
```
GET /book?token_id=TOKEN_ID
```
Response:
```json
{
"market": "condition_id",
"asset_id": "token_id",
"bids": [{"price": "0.64", "size": "500"}, ...],
"asks": [{"price": "0.66", "size": "300"}, ...],
"min_order_size": "5",
"tick_size": "0.01",
"last_trade_price": "0.65"
}
```
Bids and asks are sorted by price. Size is in shares (USDC-denominated).
### Price History
```
GET /prices-history?market=CONDITION_ID&interval=INTERVAL&fidelity=N
```
Parameters:
- `market` — the conditionId (hex string with 0x prefix)
- `interval` — time range: `all`, `1d`, `1w`, `1m`, `3m`, `6m`, `1y`
- `fidelity` — number of data points to return
Response:
```json
{
"history": [
{"t": 1709000000, "p": "0.55"},
{"t": 1709100000, "p": "0.58"}
]
}
```
`t` is Unix timestamp, `p` is price (probability).
Note: Very new markets may return empty history.
### CLOB Markets List
```
GET /markets?limit=N
```
Response:
```json
{
"data": [
{
"condition_id": "0xabc...",
"question": "Will X?",
"tokens": [
{"token_id": "123...", "outcome": "Yes", "price": 0.65},
{"token_id": "456...", "outcome": "No", "price": 0.35}
],
"active": true,
"closed": false
}
],
"next_cursor": "cursor_string",
"limit": 100,
"count": 1000
}
```
---
## Data API — data-api.polymarket.com
### Recent Trades
```
GET /trades?limit=N
GET /trades?market=CONDITION_ID&limit=N
```
Trade fields: `side` (BUY/SELL), `size`, `price`, `timestamp`,
`title`, `slug`, `outcome`, `transactionHash`, `conditionId`.
### Open Interest
```
GET /oi?market=CONDITION_ID
```
---
## Field Cross-Reference
To go from a Gamma market to CLOB data:
1. Get market from Gamma: has `clobTokenIds` and `conditionId`
2. Parse `clobTokenIds` (JSON string): `["YES_TOKEN", "NO_TOKEN"]`
3. Use YES_TOKEN with `/price`, `/book`, `/midpoint`, `/spread`
4. Use `conditionId` with `/prices-history` and Data API endpoints

View File

@@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""Polymarket CLI helper — query prediction market data.
Usage:
python3 polymarket.py search "bitcoin"
python3 polymarket.py trending [--limit 10]
python3 polymarket.py market <slug>
python3 polymarket.py event <slug>
python3 polymarket.py price <token_id>
python3 polymarket.py book <token_id>
python3 polymarket.py history <condition_id> [--interval all] [--fidelity 50]
python3 polymarket.py trades [--limit 10] [--market CONDITION_ID]
"""
import json
import sys
import urllib.request
import urllib.parse
import urllib.error
GAMMA = "https://gamma-api.polymarket.com"
CLOB = "https://clob.polymarket.com"
DATA = "https://data-api.polymarket.com"
def _get(url: str) -> dict | list:
"""GET request, return parsed JSON."""
req = urllib.request.Request(url, headers={"User-Agent": "hermes-agent/1.0"})
try:
with urllib.request.urlopen(req, timeout=15) as resp:
return json.loads(resp.read().decode())
except urllib.error.HTTPError as e:
print(f"HTTP {e.code}: {e.reason}", file=sys.stderr)
sys.exit(1)
except urllib.error.URLError as e:
print(f"Connection error: {e.reason}", file=sys.stderr)
sys.exit(1)
def _parse_json_field(val):
"""Parse double-encoded JSON fields (outcomePrices, outcomes, clobTokenIds)."""
if isinstance(val, str):
try:
return json.loads(val)
except (json.JSONDecodeError, TypeError):
return val
return val
def _fmt_pct(price_str: str) -> str:
"""Format price string as percentage."""
try:
return f"{float(price_str) * 100:.1f}%"
except (ValueError, TypeError):
return price_str
def _fmt_volume(vol) -> str:
"""Format volume as human-readable."""
try:
v = float(vol)
if v >= 1_000_000:
return f"${v / 1_000_000:.1f}M"
if v >= 1_000:
return f"${v / 1_000:.1f}K"
return f"${v:.0f}"
except (ValueError, TypeError):
return str(vol)
def _print_market(m: dict, indent: str = ""):
"""Print a market summary."""
question = m.get("question", "?")
prices = _parse_json_field(m.get("outcomePrices", "[]"))
outcomes = _parse_json_field(m.get("outcomes", "[]"))
vol = _fmt_volume(m.get("volume", 0))
closed = m.get("closed", False)
status = " [CLOSED]" if closed else ""
if isinstance(prices, list) and len(prices) >= 2:
outcome_labels = outcomes if isinstance(outcomes, list) else ["Yes", "No"]
price_str = " / ".join(
f"{outcome_labels[i]}: {_fmt_pct(prices[i])}"
for i in range(min(len(prices), len(outcome_labels)))
)
print(f"{indent}{question}{status}")
print(f"{indent} {price_str} | Volume: {vol}")
else:
print(f"{indent}{question}{status} | Volume: {vol}")
slug = m.get("slug", "")
if slug:
print(f"{indent} slug: {slug}")
def cmd_search(query: str):
"""Search for markets."""
q = urllib.parse.quote(query)
data = _get(f"{GAMMA}/public-search?q={q}")
events = data.get("events", [])
total = data.get("pagination", {}).get("totalResults", len(events))
print(f"Found {total} results for \"{query}\":\n")
for evt in events[:10]:
print(f"=== {evt['title']} ===")
print(f" Volume: {_fmt_volume(evt.get('volume', 0))} | slug: {evt.get('slug', '')}")
markets = evt.get("markets", [])
for m in markets[:5]:
_print_market(m, indent=" ")
if len(markets) > 5:
print(f" ... and {len(markets) - 5} more markets")
print()
def cmd_trending(limit: int = 10):
"""Show trending events by volume."""
events = _get(f"{GAMMA}/events?limit={limit}&active=true&closed=false&order=volume&ascending=false")
print(f"Top {len(events)} trending events:\n")
for i, evt in enumerate(events, 1):
print(f"{i}. {evt['title']}")
print(f" Volume: {_fmt_volume(evt.get('volume', 0))} | Markets: {len(evt.get('markets', []))}")
print(f" slug: {evt.get('slug', '')}")
markets = evt.get("markets", [])
for m in markets[:3]:
_print_market(m, indent=" ")
if len(markets) > 3:
print(f" ... and {len(markets) - 3} more markets")
print()
def cmd_market(slug: str):
"""Get market details by slug."""
markets = _get(f"{GAMMA}/markets?slug={urllib.parse.quote(slug)}")
if not markets:
print(f"No market found with slug: {slug}")
return
m = markets[0]
print(f"Market: {m.get('question', '?')}")
print(f"Status: {'CLOSED' if m.get('closed') else 'ACTIVE'}")
_print_market(m)
print(f"\n conditionId: {m.get('conditionId', 'N/A')}")
tokens = _parse_json_field(m.get("clobTokenIds", "[]"))
if isinstance(tokens, list):
outcomes = _parse_json_field(m.get("outcomes", "[]"))
for i, t in enumerate(tokens):
label = outcomes[i] if isinstance(outcomes, list) and i < len(outcomes) else f"Outcome {i}"
print(f" token ({label}): {t}")
desc = m.get("description", "")
if desc:
print(f"\n Description: {desc[:500]}")
def cmd_event(slug: str):
"""Get event details by slug."""
events = _get(f"{GAMMA}/events?slug={urllib.parse.quote(slug)}")
if not events:
print(f"No event found with slug: {slug}")
return
evt = events[0]
print(f"Event: {evt['title']}")
print(f"Volume: {_fmt_volume(evt.get('volume', 0))}")
print(f"Status: {'CLOSED' if evt.get('closed') else 'ACTIVE'}")
print(f"Markets: {len(evt.get('markets', []))}\n")
for m in evt.get("markets", []):
_print_market(m, indent=" ")
print()
def cmd_price(token_id: str):
"""Get current price for a token."""
buy = _get(f"{CLOB}/price?token_id={token_id}&side=buy")
mid = _get(f"{CLOB}/midpoint?token_id={token_id}")
spread = _get(f"{CLOB}/spread?token_id={token_id}")
print(f"Token: {token_id[:30]}...")
print(f" Buy price: {_fmt_pct(buy.get('price', '?'))}")
print(f" Midpoint: {_fmt_pct(mid.get('mid', '?'))}")
print(f" Spread: {spread.get('spread', '?')}")
def cmd_book(token_id: str):
"""Get orderbook for a token."""
book = _get(f"{CLOB}/book?token_id={token_id}")
bids = book.get("bids", [])
asks = book.get("asks", [])
last = book.get("last_trade_price", "?")
print(f"Orderbook for {token_id[:30]}...")
print(f"Last trade: {_fmt_pct(last)} | Tick size: {book.get('tick_size', '?')}")
print(f"\n Top bids ({len(bids)} total):")
# Show bids sorted by price descending (best bids first)
sorted_bids = sorted(bids, key=lambda x: float(x.get("price", 0)), reverse=True)
for b in sorted_bids[:10]:
print(f" {_fmt_pct(b['price']):>7} | Size: {float(b['size']):>10.2f}")
print(f"\n Top asks ({len(asks)} total):")
sorted_asks = sorted(asks, key=lambda x: float(x.get("price", 0)))
for a in sorted_asks[:10]:
print(f" {_fmt_pct(a['price']):>7} | Size: {float(a['size']):>10.2f}")
def cmd_history(condition_id: str, interval: str = "all", fidelity: int = 50):
"""Get price history for a market."""
data = _get(f"{CLOB}/prices-history?market={condition_id}&interval={interval}&fidelity={fidelity}")
history = data.get("history", [])
if not history:
print("No price history available for this market.")
return
print(f"Price history ({len(history)} points, interval={interval}):\n")
from datetime import datetime, timezone
for pt in history:
ts = datetime.fromtimestamp(pt["t"], tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
price = _fmt_pct(pt["p"])
bar = "" * int(float(pt["p"]) * 40)
print(f" {ts} {price:>7} {bar}")
def cmd_trades(limit: int = 10, market: str = None):
"""Get recent trades."""
url = f"{DATA}/trades?limit={limit}"
if market:
url += f"&market={market}"
trades = _get(url)
if not isinstance(trades, list):
print(f"Unexpected response: {trades}")
return
print(f"Recent trades ({len(trades)}):\n")
for t in trades:
side = t.get("side", "?")
price = _fmt_pct(t.get("price", "?"))
size = t.get("size", "?")
outcome = t.get("outcome", "?")
title = t.get("title", "?")[:50]
ts = t.get("timestamp", "")
print(f" {side:4} {price:>7} x{float(size):>8.2f} [{outcome}] {title}")
def main():
args = sys.argv[1:]
if not args or args[0] in ("-h", "--help", "help"):
print(__doc__)
return
cmd = args[0]
if cmd == "search" and len(args) >= 2:
cmd_search(" ".join(args[1:]))
elif cmd == "trending":
limit = 10
if "--limit" in args:
idx = args.index("--limit")
limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
cmd_trending(limit)
elif cmd == "market" and len(args) >= 2:
cmd_market(args[1])
elif cmd == "event" and len(args) >= 2:
cmd_event(args[1])
elif cmd == "price" and len(args) >= 2:
cmd_price(args[1])
elif cmd == "book" and len(args) >= 2:
cmd_book(args[1])
elif cmd == "history" and len(args) >= 2:
interval = "all"
fidelity = 50
if "--interval" in args:
idx = args.index("--interval")
interval = args[idx + 1] if idx + 1 < len(args) else "all"
if "--fidelity" in args:
idx = args.index("--fidelity")
fidelity = int(args[idx + 1]) if idx + 1 < len(args) else 50
cmd_history(args[1], interval, fidelity)
elif cmd == "trades":
limit = 10
market = None
if "--limit" in args:
idx = args.index("--limit")
limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
if "--market" in args:
idx = args.index("--market")
market = args[idx + 1] if idx + 1 < len(args) else None
cmd_trades(limit, market)
else:
print(f"Unknown command: {cmd}")
print(__doc__)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
hermes:
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---
# HuggingFace Accelerate - Unified Distributed Training
## Quick start
Accelerate simplifies distributed training to 4 lines of code.
**Installation**:
```bash
pip install accelerate
```
**Convert PyTorch script** (4 lines):
```python
import torch
+ from accelerate import Accelerator
+ accelerator = Accelerator()
model = torch.nn.Transformer()
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset)
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
```
**Run** (single command):
```bash
accelerate launch train.py
```
## Common workflows
### Workflow 1: From single GPU to multi-GPU
**Original script**:
```python
# train.py
import torch
model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for epoch in range(10):
for batch in dataloader:
batch = batch.to('cuda')
optimizer.zero_grad()
loss = model(batch).mean()
loss.backward()
optimizer.step()
```
**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator # +1
accelerator = Accelerator() # +2
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
for epoch in range(10):
for batch in dataloader:
# No .to('cuda') needed - automatic!
optimizer.zero_grad()
loss = model(batch).mean()
accelerator.backward(loss) # +4
optimizer.step()
```
**Configure** (interactive):
```bash
accelerate config
```
**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)
**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py
# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
train.py
```
### Workflow 2: Mixed precision training
**Enable FP16/BF16**:
```python
from accelerate import Accelerator
# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')
# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')
# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# Everything else is automatic!
for batch in dataloader:
with accelerator.autocast(): # Optional, done automatically
loss = model(batch)
accelerator.backward(loss)
```
### Workflow 3: DeepSpeed ZeRO integration
**Enable DeepSpeed ZeRO-2**:
```python
from accelerate import Accelerator
accelerator = Accelerator(
mixed_precision='bf16',
deepspeed_plugin={
"zero_stage": 2, # ZeRO-2
"offload_optimizer": False,
"gradient_accumulation_steps": 4
}
)
# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```
**deepspeed_config.json**:
```json
{
"fp16": {"enabled": false},
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {"device": "cpu"},
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8
}
}
```
**Launch**:
```bash
accelerate launch --config_file deepspeed_config.json train.py
```
### Workflow 4: FSDP (Fully Sharded Data Parallel)
**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
cpu_offload=False
)
accelerator = Accelerator(
mixed_precision='bf16',
fsdp_plugin=fsdp_plugin
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```
### Workflow 5: Gradient accumulation
**Accumulate gradients**:
```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
```
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
## When to use vs alternatives
**Use Accelerate when**:
- Want simplest distributed training
- Need single script for any hardware
- Use HuggingFace ecosystem
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- Need quick prototyping
**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere
**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks, high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction
## Common issues
**Issue: Wrong device placement**
Don't manually move to device:
```python
# WRONG
batch = batch.to('cuda')
# CORRECT
# Accelerate handles it automatically after prepare()
```
**Issue: Gradient accumulation not working**
Use context manager:
```python
# CORRECT
with accelerator.accumulate(model):
optimizer.zero_grad()
accelerator.backward(loss)
optimizer.step()
```
**Issue: Checkpointing in distributed**
Use accelerator methods:
```python
# Save only on main process
if accelerator.is_main_process:
accelerator.save_state('checkpoint/')
# Load on all processes
accelerator.load_state('checkpoint/')
```
**Issue: Different results with FSDP**
Ensure same random seed:
```python
from accelerate.utils import set_seed
set_seed(42)
```
## Advanced topics
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
## Hardware requirements
- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported
**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup
## Resources
- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries

View File

@@ -0,0 +1,453 @@
# Custom Plugins for Accelerate
## Overview
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
## Plugin Architecture
### Base Plugin Structure
```python
from accelerate.utils import DistributedDataParallelKwargs
from dataclasses import dataclass
@dataclass
class CustomPlugin:
"""Custom training plugin."""
# Plugin configuration
param1: int = 1
param2: str = "default"
def __post_init__(self):
# Validation logic
if self.param1 < 1:
raise ValueError("param1 must be >= 1")
```
### Using Custom Plugin
```python
from accelerate import Accelerator
# Create plugin
custom_plugin = CustomPlugin(param1=4, param2="value")
# Pass to Accelerator
accelerator = Accelerator(
custom_plugin=custom_plugin # Not a real parameter, example only
)
```
## Built-In Plugin Examples
### 1. GradScalerKwargs (FP16 Configuration)
```python
from accelerate.utils import GradScalerKwargs
# Configure gradient scaler for FP16
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16, # Initial loss scale
growth_factor=2.0, # Scale growth rate
backoff_factor=0.5, # Scale backoff rate
growth_interval=2000, # Steps between scale increases
enabled=True # Enable scaler
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
)
```
**Use case**: Fine-tune FP16 gradient scaling behavior
### 2. DistributedDataParallelKwargs
```python
from accelerate.utils import DistributedDataParallelKwargs
# Configure DDP behavior
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Gradient bucketing size
find_unused_parameters=False, # Find unused params (slower)
check_reduction=False, # Check gradient reduction
gradient_as_bucket_view=True, # Memory optimization
static_graph=False # Static computation graph
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs]
)
```
**Use case**: Optimize DDP performance for specific models
### 3. FP8RecipeKwargs (H100 FP8)
```python
from accelerate.utils import FP8RecipeKwargs
# Configure FP8 training (H100)
fp8_recipe = FP8RecipeKwargs(
backend="te", # TransformerEngine backend
margin=0, # Scaling margin
interval=1, # Scaling interval
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
amax_history_len=1024, # AMAX history length
amax_compute_algo="max" # AMAX computation algorithm
)
accelerator = Accelerator(
mixed_precision='fp8',
kwargs_handlers=[fp8_recipe]
)
```
**Use case**: Ultra-fast training on H100 GPUs
## Custom DeepSpeed Configuration
### ZeRO-3 with CPU Offload
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
# Custom DeepSpeed config
ds_plugin = DeepSpeedPlugin(
zero_stage=3, # ZeRO-3
offload_optimizer_device="cpu", # CPU offload optimizer
offload_param_device="cpu", # CPU offload parameters
zero3_init_flag=True, # ZeRO-3 initialization
zero3_save_16bit_model=True, # Save FP16 weights
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
### ZeRO-2 with NVMe Offload
```python
ds_plugin = DeepSpeedPlugin(
zero_stage=2,
offload_optimizer_device="nvme", # NVMe offload
offload_param_device="nvme",
nvme_path="/local_nvme", # NVMe mount path
)
```
### Custom JSON Config
```python
import json
# Load custom DeepSpeed config
with open('deepspeed_config.json', 'r') as f:
ds_config = json.load(f)
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```
**Example config** (`deepspeed_config.json`):
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}
```
## Custom FSDP Configuration
### FSDP with Custom Auto-Wrap Policy
```python
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
# Custom wrap policy (size-based)
wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=1e6 # Wrap layers with 1M+ params
)
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
mixed_precision_policy=None, # Use Accelerator's mixed precision
auto_wrap_policy=wrap_policy, # Custom wrapping
cpu_offload=False,
ignored_modules=None, # Modules to not wrap
state_dict_type="FULL_STATE_DICT", # Save format
optim_state_dict_config=None,
limit_all_gathers=False,
use_orig_params=True, # Use original param shapes
)
accelerator = Accelerator(
fsdp_plugin=fsdp_plugin,
mixed_precision='bf16'
)
```
### FSDP with Transformer Auto-Wrap
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
# Wrap at transformer block level
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy
)
```
## Creating Custom Training Strategy
### Example: Custom Gradient Accumulation
```python
from accelerate import Accelerator
class CustomGradientAccumulation:
def __init__(self, steps=4, adaptive=False):
self.steps = steps
self.adaptive = adaptive
self.current_step = 0
def should_sync(self, loss):
"""Decide whether to sync gradients."""
self.current_step += 1
# Adaptive: sync on high loss
if self.adaptive and loss > threshold:
self.current_step = 0
return True
# Regular: sync every N steps
if self.current_step >= self.steps:
self.current_step = 0
return True
return False
# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()
for batch in dataloader:
outputs = model(**batch)
loss = outputs.loss
# Scale loss
loss = loss / custom_accum.steps
accelerator.backward(loss)
# Conditional sync
if custom_accum.should_sync(loss.item()):
optimizer.step()
optimizer.zero_grad()
```
### Example: Custom Mixed Precision
```python
import torch
class CustomMixedPrecision:
"""Custom mixed precision with dynamic loss scaling."""
def __init__(self, init_scale=2**16, scale_window=2000):
self.scaler = torch.cuda.amp.GradScaler(
init_scale=init_scale,
growth_interval=scale_window
)
self.scale_history = []
def scale_loss(self, loss):
"""Scale loss for backward."""
return self.scaler.scale(loss)
def unscale_and_clip(self, optimizer, max_norm=1.0):
"""Unscale gradients and clip."""
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
optimizer.param_groups[0]['params'],
max_norm
)
def step(self, optimizer):
"""Optimizer step with scaler update."""
scale_before = self.scaler.get_scale()
self.scaler.step(optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
# Track scale changes
if scale_before != scale_after:
self.scale_history.append(scale_after)
# Usage
custom_mp = CustomMixedPrecision()
for batch in dataloader:
with torch.cuda.amp.autocast(dtype=torch.float16):
loss = model(**batch).loss
scaled_loss = custom_mp.scale_loss(loss)
scaled_loss.backward()
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
custom_mp.step(optimizer)
optimizer.zero_grad()
```
## Advanced: Custom Distributed Backend
### Custom AllReduce Strategy
```python
import torch.distributed as dist
class CustomAllReduce:
"""Custom all-reduce with compression."""
def __init__(self, compression_ratio=0.1):
self.compression_ratio = compression_ratio
def compress_gradients(self, tensor):
"""Top-k gradient compression."""
k = int(tensor.numel() * self.compression_ratio)
values, indices = torch.topk(tensor.abs().view(-1), k)
return values, indices
def all_reduce_compressed(self, tensor):
"""All-reduce with gradient compression."""
# Compress
values, indices = self.compress_gradients(tensor)
# All-reduce compressed gradients
dist.all_reduce(values, op=dist.ReduceOp.SUM)
# Decompress
tensor_compressed = torch.zeros_like(tensor).view(-1)
tensor_compressed[indices] = values / dist.get_world_size()
return tensor_compressed.view_as(tensor)
# Usage in training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)
for batch in dataloader:
loss = model(**batch).loss
loss.backward()
# Custom all-reduce
for param in model.parameters():
if param.grad is not None:
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
optimizer.step()
optimizer.zero_grad()
```
## Plugin Best Practices
### 1. Validation in `__post_init__`
```python
@dataclass
class CustomPlugin:
learning_rate: float = 1e-3
warmup_steps: int = 1000
def __post_init__(self):
# Validate parameters
if self.learning_rate <= 0:
raise ValueError("learning_rate must be positive")
if self.warmup_steps < 0:
raise ValueError("warmup_steps must be non-negative")
# Compute derived values
self.min_lr = self.learning_rate * 0.1
```
### 2. Compatibility Checks
```python
@dataclass
class CustomPlugin:
feature_enabled: bool = True
def is_compatible(self, accelerator):
"""Check if plugin is compatible with accelerator config."""
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
raise ValueError("Custom plugin not compatible with FP8")
return True
```
### 3. State Management
```python
@dataclass
class CustomPlugin:
counter: int = 0
history: list = None
def __post_init__(self):
if self.history is None:
self.history = []
def update_state(self, value):
"""Update plugin state during training."""
self.counter += 1
self.history.append(value)
```
## Resources
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu

View File

@@ -0,0 +1,489 @@
# Megatron Integration with Accelerate
## Overview
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts
## Setup
### Install Megatron-LM
```bash
# Clone Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .
# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
### Accelerate Configuration
```bash
accelerate config
```
**Questions**:
```
In which compute environment are you running?
> This machine
Which type of machine are you using?
> Multi-GPU
How many different machines will you use?
> 1
Do you want to use DeepSpeed/FSDP?
> No
Do you want to use Megatron-LM?
> Yes
What is the Tensor Parallelism degree? [1-8]
> 2
Do you want to enable Sequence Parallelism?
> No
What is the Pipeline Parallelism degree? [1-8]
> 2
What is the Data Parallelism degree? [1-8]
> 2
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
megatron_lm_gradient_clipping: 1.0
megatron_lm_learning_rate_decay_iters: 320000
megatron_lm_num_micro_batches: 1
megatron_lm_pp_degree: 2
megatron_lm_recompute_activations: true
megatron_lm_sequence_parallelism: false
megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
## Parallelism Strategies
### Tensor Parallelism (TP)
**Splits each transformer layer across GPUs**:
```python
# Layer split across 2 GPUs
# GPU 0: First half of attention heads
# GPU 1: Second half of attention heads
# Each GPU computes partial outputs
# All-reduce combines results
```
**TP degree recommendations**:
- **TP=1**: No tensor parallelism (single GPU per layer)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)
**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead per layer
### Pipeline Parallelism (PP)
**Splits model depth across GPUs**:
```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```
**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)
**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (slower interconnect OK)
**Drawbacks**:
- Pipeline bubbles (idle time)
- Requires micro-batching
### Data Parallelism (DP)
**Replicates model across GPU groups**:
```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```
**DP degree**:
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2
**Benefits**:
- Increases throughput
- Scales batch size
### Sequence Parallelism
**Splits long sequences across GPUs** (extends TP):
```python
# 8K sequence, TP=2, Sequence Parallel=True
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```
**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory
**Requirements**:
- Must use with TP > 1
- RoPE/ALiBi position encodings work best
## Accelerate Code Example
### Basic Setup
```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
# Configure Megatron
megatron_plugin = MegatronLMPlugin(
tp_degree=2, # Tensor parallelism degree
pp_degree=2, # Pipeline parallelism degree
num_micro_batches=4, # Micro-batches for pipeline
gradient_clipping=1.0, # Gradient clipping value
sequence_parallelism=False, # Enable sequence parallelism
recompute_activations=True, # Activation checkpointing
use_distributed_optimizer=True, # Distributed optimizer
custom_prepare_model_function=None, # Custom model prep
)
# Initialize accelerator
accelerator = Accelerator(
mixed_precision='bf16',
megatron_lm_plugin=megatron_plugin
)
# Prepare model and optimizer
model, optimizer, train_dataloader = accelerator.prepare(
model, optimizer, train_dataloader
)
# Training loop (same as DDP!)
for batch in train_dataloader:
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
```
### Full Training Script
```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel
def main():
# Megatron configuration
megatron_plugin = MegatronLMPlugin(
tp_degree=2,
pp_degree=2,
num_micro_batches=4,
gradient_clipping=1.0,
)
accelerator = Accelerator(
mixed_precision='bf16',
gradient_accumulation_steps=8,
megatron_lm_plugin=megatron_plugin
)
# Model
config = GPT2Config(
n_layer=24,
n_head=16,
n_embd=1024,
)
model = GPT2LMHeadModel(config)
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
# Prepare
model, optimizer, train_loader = accelerator.prepare(
model, optimizer, train_loader
)
# Training loop
for epoch in range(num_epochs):
for batch in train_loader:
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Save checkpoint
accelerator.wait_for_everyone()
accelerator.save_state(f'checkpoint-epoch-{epoch}')
if __name__ == '__main__':
main()
```
### Launch Command
```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
# Node 1
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 1 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
```
## Activation Checkpointing
**Reduces memory by recomputing activations**:
```python
megatron_plugin = MegatronLMPlugin(
recompute_activations=True, # Enable checkpointing
checkpoint_num_layers=1, # Checkpoint every N layers
distribute_checkpointed_activations=True, # Distribute across TP
partition_activations=True, # Partition in PP
check_for_nan_in_loss_and_grad=True, # Stability check
)
```
**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing
**Memory savings**: 30-50% with 10-15% slowdown
## Distributed Optimizer
**Shards optimizer state across DP ranks**:
```python
megatron_plugin = MegatronLMPlugin(
use_distributed_optimizer=True, # Enable sharded optimizer
)
```
**Benefits**:
- Reduces optimizer memory by DP degree
- Example: DP=4 → 4× less optimizer memory per GPU
**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training
## Performance Tuning
### Micro-Batch Size
```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # 16 micro-batches per pipeline
)
# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```
**Recommendations**:
- More micro-batches → less pipeline bubble
- Typical: 4-16 micro-batches
### Sequence Length
```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
tp_degree=4,
sequence_parallelism=True, # Required: TP > 1
)
# Enables sequences up to TP × normal limit
# Example: TP=4, 8K normal → 32K with sequence parallel
```
### GPU Topology
**NVLink required for TP**:
```bash
# Check NVLink topology
nvidia-smi topo -m
# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)
# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow, avoid TP across these)
```
**Recommendations**:
- **TP**: Within same node (NVLink)
- **PP**: Across nodes (slower interconnect OK)
- **DP**: Any topology
## Model Size Guidelines
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|--------------|
| 7B | 8 | 1 | 1 | 8 | 1 |
| 13B | 8 | 2 | 1 | 4 | 1 |
| 20B | 16 | 4 | 1 | 4 | 1 |
| 40B | 32 | 4 | 2 | 4 | 4 |
| 70B | 64 | 8 | 2 | 4 | 8 |
| 175B | 128 | 8 | 4 | 4 | 16 |
**Assumptions**: BF16, 2K sequence length, A100 80GB
## Checkpointing
### Save Checkpoint
```python
# Save full model state
accelerator.save_state('checkpoint-1000')
# Megatron saves separate files per rank
# checkpoint-1000/
# pytorch_model_tp_0_pp_0.bin
# pytorch_model_tp_0_pp_1.bin
# pytorch_model_tp_1_pp_0.bin
# pytorch_model_tp_1_pp_1.bin
# optimizer_tp_0_pp_0.bin
# ...
```
### Load Checkpoint
```python
# Resume training
accelerator.load_state('checkpoint-1000')
# Automatically loads correct shard per rank
```
### Convert to Standard PyTorch
```bash
# Merge Megatron checkpoint to single file
python merge_megatron_checkpoint.py \
--checkpoint-dir checkpoint-1000 \
--output pytorch_model.bin
```
## Common Issues
### Issue: OOM with Pipeline Parallelism
**Solution**: Increase micro-batches
```python
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # Increase from 4
)
```
### Issue: Slow Training
**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4 # Increase
pp_degree=2 # Decrease
```
**Check 2**: Micro-batch size too small
```python
num_micro_batches=8 # Increase
```
### Issue: NVLink Not Detected
```bash
# Verify NVLink
nvidia-smi nvlink -s
# If no NVLink, avoid TP > 1
# Use PP or DP instead
```
## Resources
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex

View File

@@ -0,0 +1,525 @@
# Accelerate Performance Tuning
## Profiling
### Basic Profiling
```python
from accelerate import Accelerator
import time
accelerator = Accelerator()
# Warmup
for _ in range(10):
batch = next(iter(dataloader))
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Profile training loop
start = time.time()
total_batches = 100
for i, batch in enumerate(dataloader):
if i >= total_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone() # Sync all processes
elapsed = time.time() - start
# Metrics
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```
### PyTorch Profiler Integration
```python
from torch.profiler import profile, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for i, batch in enumerate(dataloader):
if i >= 10: # Profile first 10 batches
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Print profiling results
print(prof.key_averages().table(
sort_by="cuda_time_total", row_limit=20
))
# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```
## Memory Optimization
### 1. Gradient Accumulation
**Problem**: Large batch size causes OOM
**Solution**: Accumulate gradients across micro-batches
```python
accelerator = Accelerator(gradient_accumulation_steps=8)
# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation logic
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
### 2. Gradient Checkpointing
**Enable in model**:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
use_cache=False # Required for gradient checkpointing
)
# Enable checkpointing
model.gradient_checkpointing_enable()
# Prepare with Accelerate
model = accelerator.prepare(model)
```
**Memory savings**: 30-50% with 10-15% slowdown
### 3. Mixed Precision
**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')
# Automatic mixed precision
for batch in dataloader:
outputs = model(**batch) # Forward in BF16
loss = outputs.loss
accelerator.backward(loss) # Backward in FP32
optimizer.step()
```
**FP16 (V100, older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16,
growth_interval=2000
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs]
)
```
**Memory savings**: 50% compared to FP32
### 4. CPU Offloading (DeepSpeed)
```python
from accelerate.utils import DeepSpeedPlugin
ds_plugin = DeepSpeedPlugin(
zero_stage=3,
offload_optimizer_device="cpu", # Offload optimizer to CPU
offload_param_device="cpu", # Offload parameters to CPU
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
**Trade-off**: 20-30% slower due to CPU-GPU transfers
### 5. Flash Attention
```python
# Install flash-attn
# pip install flash-attn
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
attn_implementation="flash_attention_2" # Enable Flash Attention 2
)
model = accelerator.prepare(model)
```
**Memory savings**: 50% for attention, 2× faster
**Requirements**: A100/H100, sequence length must be multiple of 128
## Communication Optimization
### 1. Gradient Bucketing (DDP)
```python
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Bucket size for gradient reduction
gradient_as_bucket_view=True, # Reduce memory copies
static_graph=False # Set True if model doesn't change
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```
**Recommended bucket sizes**:
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB
### 2. Find Unused Parameters
```python
# Only enable if model has unused parameters (slower!)
ddp_kwargs = DistributedDataParallelKwargs(
find_unused_parameters=True
)
```
**Use case**: Models with conditional branches (e.g., mixture of experts)
**Cost**: 10-20% slower
### 3. NCCL Tuning
```bash
# Set environment variables before launch
export NCCL_DEBUG=INFO # Debug info
export NCCL_IB_DISABLE=0 # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0 # Network interface
export NCCL_P2P_LEVEL=NVL # Use NVLink
accelerate launch train.py
```
**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within node)
- `PIX`: PCIe (fast, within node)
- `PHB`: PCIe host bridge (slow, cross-node)
## Data Loading Optimization
### 1. DataLoader Workers
```python
from torch.utils.data import DataLoader
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4, # Parallel data loading
pin_memory=True, # Pin memory for faster GPU transfer
prefetch_factor=2, # Prefetch batches per worker
persistent_workers=True # Keep workers alive between epochs
)
train_loader = accelerator.prepare(train_loader)
```
**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher for slow data loading)
### 2. Data Preprocessing
```python
from datasets import load_dataset
# Bad: Preprocess during training (slow)
dataset = load_dataset("openwebtext")
for batch in dataset:
tokens = tokenizer(batch['text']) # Slow!
...
# Good: Preprocess once, save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
lambda x: tokenizer(x['text']),
batched=True,
num_proc=8, # Parallel preprocessing
remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")
# Load preprocessed
dataset = load_from_disk("preprocessed_data")
```
### 3. Faster Tokenization
```python
import os
# Enable Rust-based tokenizers (10× faster)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"gpt2",
use_fast=True # Use fast Rust tokenizer
)
```
## Compilation (PyTorch 2.0+)
### Compile Model
```python
import torch
# Compile model for faster execution
model = torch.compile(
model,
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
fullgraph=False, # Compile entire graph (stricter)
dynamic=True # Support dynamic shapes
)
model = accelerator.prepare(model)
```
**Speedup**: 10-50% depending on model
**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Min overhead (best for small batches)
- `max-autotune`: Max performance (slow compile, best for production)
### Compilation Best Practices
```python
# Bad: Compile after prepare (won't work)
model = accelerator.prepare(model)
model = torch.compile(model) # Error!
# Good: Compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)
# Training loop
for batch in dataloader:
# First iteration: slow (compilation)
# Subsequent iterations: fast (compiled)
outputs = model(**batch)
...
```
## Benchmarking Different Strategies
### Script Template
```python
import time
import torch
from accelerate import Accelerator
def benchmark_strategy(strategy_name, accelerator_kwargs):
"""Benchmark a specific training strategy."""
accelerator = Accelerator(**accelerator_kwargs)
# Setup
model = create_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = create_dataloader()
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
# Warmup
for i, batch in enumerate(dataloader):
if i >= 10:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Benchmark
accelerator.wait_for_everyone()
torch.cuda.synchronize()
start = time.time()
num_batches = 100
for i, batch in enumerate(dataloader):
if i >= num_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone()
torch.cuda.synchronize()
elapsed = time.time() - start
# Metrics
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
if accelerator.is_main_process:
print(f"\n{strategy_name}:")
print(f" Throughput: {throughput:.2f} samples/sec")
print(f" Memory: {memory_used:.2f} GB")
print(f" Time: {elapsed:.2f} sec")
torch.cuda.reset_peak_memory_stats()
# Benchmark different strategies
strategies = [
("DDP + FP32", {}),
("DDP + BF16", {"mixed_precision": "bf16"}),
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
("FSDP", {"fsdp_plugin": fsdp_plugin}),
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]
for name, kwargs in strategies:
benchmark_strategy(name, kwargs)
```
## Performance Checklist
**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation
**For multi-node**:
- [ ] Check network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism
**Profiling**:
- [ ] Profile first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of iteration)
- [ ] Identify communication bottlenecks
## Common Performance Issues
### Issue: Low GPU Utilization (<80%)
**Cause 1**: Data loading bottleneck
```python
# Solution: Increase workers and prefetch
num_workers=8
prefetch_factor=4
```
**Cause 2**: Small batch size
```python
# Solution: Increase batch size or use gradient accumulation
batch_size=32 # Increase
gradient_accumulation_steps=4 # Or accumulate
```
### Issue: High Memory Usage
**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```
**Solution 2**: Reduce batch size, increase accumulation
```python
batch_size=8 # Reduce from 32
gradient_accumulation_steps=16 # Maintain effective batch
```
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```
### Issue: Slow Multi-GPU Training
**Cause**: Communication bottleneck
**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```
**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
```
**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```
## Resources
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention

View File

@@ -0,0 +1,567 @@
---
name: audiocraft-audio-generation
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
metadata:
hermes:
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
---
# AudioCraft: Audio Generation
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
## When to use AudioCraft
**Use AudioCraft when:**
- Need to generate music from text descriptions
- Creating sound effects and environmental audio
- Building music generation applications
- Need melody-conditioned music generation
- Want stereo audio output
- Require controllable music generation with style transfer
**Key features:**
- **MusicGen**: Text-to-music generation with melody conditioning
- **AudioGen**: Text-to-sound effects generation
- **EnCodec**: High-fidelity neural audio codec
- **Multiple model sizes**: Small (300M) to Large (3.3B)
- **Stereo support**: Full stereo audio generation
- **Style conditioning**: MusicGen-Style for reference-based generation
**Use alternatives instead:**
- **Stable Audio**: For longer commercial music generation
- **Bark**: For text-to-speech with music/sound effects
- **Riffusion**: For spectogram-based music generation
- **OpenAI Jukebox**: For raw audio generation with lyrics
## Quick start
### Installation
```bash
# From PyPI
pip install audiocraft
# From GitHub (latest)
pip install git+https://github.com/facebookresearch/audiocraft.git
# Or use HuggingFace Transformers
pip install transformers torch torchaudio
```
### Basic text-to-music (AudioCraft)
```python
import torchaudio
from audiocraft.models import MusicGen
# Load model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Set generation parameters
model.set_generation_params(
duration=8, # seconds
top_k=250,
temperature=1.0
)
# Generate from text
descriptions = ["happy upbeat electronic dance music with synths"]
wav = model.generate(descriptions)
# Save audio
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
```
### Using HuggingFace Transformers
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
# Load model and processor
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.to("cuda")
# Generate music
inputs = processor(
text=["80s pop track with bassy drums and synth"],
padding=True,
return_tensors="pt"
).to("cuda")
audio_values = model.generate(
**inputs,
do_sample=True,
guidance_scale=3,
max_new_tokens=256
)
# Save
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
```
### Text-to-sound with AudioGen
```python
from audiocraft.models import AudioGen
# Load AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5)
# Generate sound effects
descriptions = ["dog barking in a park with birds chirping"]
wav = model.generate(descriptions)
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
```
## Core concepts
### Architecture overview
```
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│ Text Encoder (T5) │
│ │ │
│ Text Embeddings │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ Transformer Decoder (LM) │
│ Auto-regressively generates audio tokens │
│ Using efficient token interleaving patterns │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ EnCodec Audio Decoder │
│ Converts tokens back to audio waveform │
└──────────────────────────────────────────────────────────────┘
```
### Model variants
| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
### Generation parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |
## MusicGen usage
### Text-to-music generation
```python
from audiocraft.models import MusicGen
import torchaudio
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Configure generation
model.set_generation_params(
duration=30, # Up to 30 seconds
top_k=250, # Sampling diversity
top_p=0.0, # 0 = use top_k only
temperature=1.0, # Creativity (higher = more varied)
cfg_coef=3.0 # Text adherence (higher = stricter)
)
# Generate multiple samples
descriptions = [
"epic orchestral soundtrack with strings and brass",
"chill lo-fi hip hop beat with jazzy piano",
"energetic rock song with electric guitar"
]
# Generate (returns [batch, channels, samples])
wav = model.generate(descriptions)
# Save each
for i, audio in enumerate(wav):
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
```
### Melody-conditioned generation
```python
from audiocraft.models import MusicGen
import torchaudio
# Load melody model
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=30)
# Load melody audio
melody, sr = torchaudio.load("melody.wav")
# Generate with melody conditioning
descriptions = ["acoustic guitar folk song"]
wav = model.generate_with_chroma(descriptions, melody, sr)
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
```
### Stereo generation
```python
from audiocraft.models import MusicGen
# Load stereo model
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
model.set_generation_params(duration=15)
descriptions = ["ambient electronic music with wide stereo panning"]
wav = model.generate(descriptions)
# wav shape: [batch, 2, samples] for stereo
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
```
### Audio continuation
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
# Load audio to continue
import torchaudio
audio, sr = torchaudio.load("intro.wav")
# Process with text and audio
inputs = processor(
audio=audio.squeeze().numpy(),
sampling_rate=sr,
text=["continue with a epic chorus"],
padding=True,
return_tensors="pt"
)
# Generate continuation
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
```
## MusicGen-Style usage
### Style-conditioned generation
```python
from audiocraft.models import MusicGen
# Load style model
model = MusicGen.get_pretrained('facebook/musicgen-style')
# Configure generation with style
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=5.0 # Style influence
)
# Configure style conditioner
model.set_style_conditioner_params(
eval_q=3, # RVQ quantizers (1-6)
excerpt_length=3.0 # Style excerpt length
)
# Load style reference
style_audio, sr = torchaudio.load("reference_style.wav")
# Generate with text + style
descriptions = ["upbeat dance track"]
wav = model.generate_with_style(descriptions, style_audio, sr)
```
### Style-only generation (no text)
```python
# Generate matching style without text prompt
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=None # Disable double CFG for style-only
)
wav = model.generate_with_style([None], style_audio, sr)
```
## AudioGen usage
### Sound effect generation
```python
from audiocraft.models import AudioGen
import torchaudio
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=10)
# Generate various sounds
descriptions = [
"thunderstorm with heavy rain and lightning",
"busy city traffic with car horns",
"ocean waves crashing on rocks",
"crackling campfire in forest"
]
wav = model.generate(descriptions)
for i, audio in enumerate(wav):
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
```
## EnCodec usage
### Audio compression
```python
from audiocraft.models import CompressionModel
import torch
import torchaudio
# Load EnCodec
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Load audio
wav, sr = torchaudio.load("audio.wav")
# Ensure correct sample rate
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Encode to tokens
with torch.no_grad():
encoded = model.encode(wav.unsqueeze(0))
codes = encoded[0] # Audio codes
# Decode back to audio
with torch.no_grad():
decoded = model.decode(codes)
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
```
## Common workflows
### Workflow 1: Music generation pipeline
```python
import torch
import torchaudio
from audiocraft.models import MusicGen
class MusicGenerator:
def __init__(self, model_name="facebook/musicgen-medium"):
self.model = MusicGen.get_pretrained(model_name)
self.sample_rate = 32000
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
self.model.set_generation_params(
duration=duration,
top_k=250,
temperature=temperature,
cfg_coef=cfg
)
with torch.no_grad():
wav = self.model.generate([prompt])
return wav[0].cpu()
def generate_batch(self, prompts, duration=30):
self.model.set_generation_params(duration=duration)
with torch.no_grad():
wav = self.model.generate(prompts)
return wav.cpu()
def save(self, audio, path):
torchaudio.save(path, audio, sample_rate=self.sample_rate)
# Usage
generator = MusicGenerator()
audio = generator.generate(
"epic cinematic orchestral music",
duration=30,
temperature=1.0
)
generator.save(audio, "epic_music.wav")
```
### Workflow 2: Sound design batch processing
```python
import json
from pathlib import Path
from audiocraft.models import AudioGen
import torchaudio
def batch_generate_sounds(sound_specs, output_dir):
"""
Generate multiple sounds from specifications.
Args:
sound_specs: list of {"name": str, "description": str, "duration": float}
output_dir: output directory path
"""
model = AudioGen.get_pretrained('facebook/audiogen-medium')
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
results = []
for spec in sound_specs:
model.set_generation_params(duration=spec.get("duration", 5))
wav = model.generate([spec["description"]])
output_path = output_dir / f"{spec['name']}.wav"
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
results.append({
"name": spec["name"],
"path": str(output_path),
"description": spec["description"]
})
return results
# Usage
sounds = [
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
]
results = batch_generate_sounds(sounds, "sound_effects/")
```
### Workflow 3: Gradio demo
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music(prompt, duration, temperature, cfg_coef):
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef
)
with torch.no_grad():
wav = model.generate([prompt])
# Save to temp file
path = "temp_output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate_music,
inputs=[
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Demo"
)
demo.launch()
```
## Performance optimization
### Memory optimization
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Clear cache between generations
torch.cuda.empty_cache()
# Generate shorter durations
model.set_generation_params(duration=10) # Instead of 30
# Use half precision
model = model.half()
```
### Batch processing efficiency
```python
# Process multiple prompts at once (more efficient)
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
wav = model.generate(descriptions) # Single batch
# Instead of
for desc in descriptions:
wav = model.generate([desc]) # Multiple batches (slower)
```
### GPU memory requirements
| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
## Common issues
| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/audiocraft
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen

View File

@@ -0,0 +1,666 @@
# AudioCraft Advanced Usage Guide
## Fine-tuning MusicGen
### Custom dataset preparation
```python
import os
import json
from pathlib import Path
import torchaudio
def prepare_dataset(audio_dir, output_dir, metadata_file):
"""
Prepare dataset for MusicGen fine-tuning.
Directory structure:
output_dir/
├── audio/
│ ├── 0001.wav
│ ├── 0002.wav
│ └── ...
└── metadata.json
"""
output_dir = Path(output_dir)
audio_output = output_dir / "audio"
audio_output.mkdir(parents=True, exist_ok=True)
# Load metadata (format: {"path": "...", "description": "..."})
with open(metadata_file) as f:
metadata = json.load(f)
processed = []
for idx, item in enumerate(metadata):
audio_path = Path(audio_dir) / item["path"]
# Load and resample to 32kHz
wav, sr = torchaudio.load(str(audio_path))
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Convert to mono if stereo
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
# Save processed audio
output_path = audio_output / f"{idx:04d}.wav"
torchaudio.save(str(output_path), wav, sample_rate=32000)
processed.append({
"path": str(output_path.relative_to(output_dir)),
"description": item["description"],
"duration": wav.shape[1] / 32000
})
# Save processed metadata
with open(output_dir / "metadata.json", "w") as f:
json.dump(processed, f, indent=2)
print(f"Processed {len(processed)} samples")
return processed
```
### Fine-tuning with dora
```bash
# AudioCraft uses dora for experiment management
# Install dora
pip install dora-search
# Clone AudioCraft
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft
# Create config for fine-tuning
cat > config/solver/musicgen/finetune.yaml << 'EOF'
defaults:
- musicgen/musicgen_base
- /model: lm/musicgen_lm
- /conditioner: cond_base
solver: musicgen
autocast: true
autocast_dtype: float16
optim:
epochs: 100
batch_size: 4
lr: 1e-4
ema: 0.999
optimizer: adamw
dataset:
batch_size: 4
num_workers: 4
train:
- dset: your_dataset
root: /path/to/dataset
valid:
- dset: your_dataset
root: /path/to/dataset
checkpoint:
save_every: 10
keep_every_states: null
EOF
# Run fine-tuning
dora run solver=musicgen/finetune
```
### LoRA fine-tuning
```python
from peft import LoraConfig, get_peft_model
from audiocraft.models import MusicGen
import torch
# Load base model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Get the language model component
lm = model.lm
# Configure LoRA
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
lora_dropout=0.05,
bias="none"
)
# Apply LoRA
lm = get_peft_model(lm, lora_config)
lm.print_trainable_parameters()
```
## Multi-GPU Training
### DataParallel
```python
import torch
import torch.nn as nn
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Wrap LM with DataParallel
if torch.cuda.device_count() > 1:
model.lm = nn.DataParallel(model.lm)
model.to("cuda")
```
### DistributedDataParallel
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def train(rank, world_size):
setup(rank, world_size)
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.lm = model.lm.to(rank)
model.lm = DDP(model.lm, device_ids=[rank])
# Training loop
# ...
dist.destroy_process_group()
```
## Custom Conditioning
### Adding new conditioners
```python
from audiocraft.modules.conditioners import BaseConditioner
import torch
class CustomConditioner(BaseConditioner):
"""Custom conditioner for additional control signals."""
def __init__(self, dim, output_dim):
super().__init__(dim, output_dim)
self.embed = torch.nn.Linear(dim, output_dim)
def forward(self, x):
return self.embed(x)
def tokenize(self, x):
# Tokenize input for conditioning
return x
# Use with MusicGen
from audiocraft.models.builders import get_lm_model
# Modify model config to include custom conditioner
# This requires editing the model configuration
```
### Melody conditioning internals
```python
from audiocraft.models import MusicGen
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
import torch
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Access chroma extractor
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
# Manual chroma extraction
def extract_chroma(audio, sr):
"""Extract chroma features from audio."""
import librosa
# Compute chroma
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
return torch.from_numpy(chroma).float()
# Use extracted chroma for conditioning
chroma = extract_chroma(melody_audio, sample_rate)
```
## EnCodec Deep Dive
### Custom compression settings
```python
from audiocraft.models import CompressionModel
import torch
# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}") # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")
# Encode with specific bandwidth
# Lower bandwidth = more compression, lower quality
encodec.set_target_bandwidth(6.0) # 6 kbps
audio = torch.randn(1, 1, 32000) # 1 second
encoded = encodec.encode(audio)
decoded = encodec.decode(encoded[0])
```
### Streaming encoding
```python
import torch
from audiocraft.models import CompressionModel
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
def encode_streaming(audio_stream, chunk_size=32000):
"""Encode audio in streaming fashion."""
all_codes = []
for chunk in audio_stream:
# Ensure chunk is right shape
if chunk.dim() == 1:
chunk = chunk.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
codes = encodec.encode(chunk)[0]
all_codes.append(codes)
return torch.cat(all_codes, dim=-1)
def decode_streaming(codes_stream, output_stream):
"""Decode codes in streaming fashion."""
for codes in codes_stream:
with torch.no_grad():
audio = encodec.decode(codes)
output_stream.write(audio.cpu().numpy())
```
## MultiBand Diffusion
### Using MBD for enhanced quality
```python
from audiocraft.models import MusicGen, MultiBandDiffusion
# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
model.set_generation_params(duration=10)
# Generate with standard decoder
descriptions = ["epic orchestral music"]
wav_standard = model.generate(descriptions)
# Generate tokens and use MBD decoder
with torch.no_grad():
# Get tokens
gen_tokens = model.generate_tokens(descriptions)
# Decode with MBD
wav_mbd = mbd.tokens_to_wav(gen_tokens)
# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
```
## API Server Deployment
### FastAPI server
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torchaudio
from audiocraft.models import MusicGen
import io
import base64
app = FastAPI()
# Load model at startup
model = None
@app.on_event("startup")
async def load_model():
global model
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
class GenerateRequest(BaseModel):
prompt: str
duration: float = 10.0
temperature: float = 1.0
cfg_coef: float = 3.0
class GenerateResponse(BaseModel):
audio_base64: str
sample_rate: int
duration: float
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
model.set_generation_params(
duration=min(request.duration, 30),
temperature=request.temperature,
cfg_coef=request.cfg_coef
)
with torch.no_grad():
wav = model.generate([request.prompt])
# Convert to bytes
buffer = io.BytesIO()
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
buffer.seek(0)
audio_base64 = base64.b64encode(buffer.read()).decode()
return GenerateResponse(
audio_base64=audio_base64,
sample_rate=32000,
duration=wav.shape[-1] / 32000
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
return {"status": "ok", "model_loaded": model is not None}
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
```
### Batch processing service
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen
class MusicGenService:
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
self.model = MusicGen.get_pretrained(model_name)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.lock = asyncio.Lock()
async def generate_async(self, prompt, duration=10):
"""Async generation with thread pool."""
loop = asyncio.get_event_loop()
def _generate():
with torch.no_grad():
self.model.set_generation_params(duration=duration)
return self.model.generate([prompt])
# Run in thread pool
wav = await loop.run_in_executor(self.executor, _generate)
return wav[0].cpu()
async def generate_batch_async(self, prompts, duration=10):
"""Process multiple prompts concurrently."""
tasks = [self.generate_async(p, duration) for p in prompts]
return await asyncio.gather(*tasks)
# Usage
service = MusicGenService()
async def main():
prompts = ["jazz piano", "rock guitar", "electronic beats"]
results = await service.generate_batch_async(prompts)
return results
```
## Integration Patterns
### LangChain tool
```python
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile
class MusicGeneratorTool(BaseTool):
name = "music_generator"
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
def __init__(self):
super().__init__()
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
self.model.set_generation_params(duration=15)
def _run(self, description: str) -> str:
with torch.no_grad():
wav = self.model.generate([description])
# Save to temp file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
return f"Generated music saved to: {f.name}"
async def _arun(self, description: str) -> str:
return self._run(description)
```
### Gradio with advanced controls
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
models = {}
def load_model(model_size):
if model_size not in models:
model_name = f"facebook/musicgen-{model_size}"
models[model_size] = MusicGen.get_pretrained(model_name)
return models[model_size]
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
model = load_model(model_size)
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef,
top_k=top_k
)
with torch.no_grad():
wav = model.generate([prompt])
# Save
path = "output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate,
inputs=[
gr.Textbox(label="Prompt", lines=3),
gr.Slider(1, 30, value=10, label="Duration (s)"),
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Advanced",
allow_flagging="never"
)
demo.launch(share=True)
```
## Audio Processing Pipeline
### Post-processing chain
```python
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
class AudioPostProcessor:
def __init__(self, sample_rate=32000):
self.sample_rate = sample_rate
def normalize(self, audio, target_db=-14.0):
"""Normalize audio to target loudness."""
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
def fade_in_out(self, audio, fade_duration=0.1):
"""Apply fade in/out."""
fade_samples = int(fade_duration * self.sample_rate)
# Create fade curves
fade_in = torch.linspace(0, 1, fade_samples)
fade_out = torch.linspace(1, 0, fade_samples)
# Apply fades
audio[..., :fade_samples] *= fade_in
audio[..., -fade_samples:] *= fade_out
return audio
def apply_reverb(self, audio, decay=0.5):
"""Apply simple reverb effect."""
impulse = torch.zeros(int(self.sample_rate * 0.5))
impulse[0] = 1.0
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
# Convolve
audio = torch.nn.functional.conv1d(
audio.unsqueeze(0),
impulse.unsqueeze(0).unsqueeze(0),
padding=len(impulse) // 2
).squeeze(0)
return audio
def process(self, audio):
"""Full processing pipeline."""
audio = self.normalize(audio)
audio = self.fade_in_out(audio)
return audio
# Usage with MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
wav = model.generate(["chill ambient music"])
processor = AudioPostProcessor()
wav_processed = processor.process(wav[0].cpu())
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
```
## Evaluation
### Audio quality metrics
```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric
from audiocraft.data.audio import audio_read
def evaluate_generation(audio_path, text_prompt):
"""Evaluate generated audio quality."""
# Load audio
wav, sr = audio_read(audio_path)
# CLAP consistency (text-audio alignment)
clap_metric = CLAPTextConsistencyMetric()
clap_score = clap_metric.compute(wav, [text_prompt])
return {
"clap_score": clap_score,
"duration": wav.shape[-1] / sr
}
# Batch evaluation
def evaluate_batch(generations):
"""Evaluate multiple generations."""
results = []
for gen in generations:
result = evaluate_generation(gen["path"], gen["prompt"])
result["prompt"] = gen["prompt"]
results.append(result)
# Aggregate
avg_clap = sum(r["clap_score"] for r in results) / len(results)
return {
"individual": results,
"average_clap": avg_clap
}
```
## Model Comparison
### MusicGen variants benchmark
| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
### Prompt engineering tips
```python
# Good prompts - specific and descriptive
good_prompts = [
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
"funky disco groove with slap bass, brass section, and rhythmic guitar"
]
# Bad prompts - too vague
bad_prompts = [
"nice music",
"song",
"good beat"
]
# Structure: [mood] [genre] with [instruments] at [tempo/style]
```

View File

@@ -0,0 +1,504 @@
# AudioCraft Troubleshooting Guide
## Installation Issues
### Import errors
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
**Solutions**:
```bash
# Install from PyPI
pip install audiocraft
# Or from GitHub
pip install git+https://github.com/facebookresearch/audiocraft.git
# Verify installation
python -c "from audiocraft.models import MusicGen; print('OK')"
```
### FFmpeg not found
**Error**: `RuntimeError: ffmpeg not found`
**Solutions**:
```bash
# Ubuntu/Debian
sudo apt-get install ffmpeg
# macOS
brew install ffmpeg
# Windows (using conda)
conda install -c conda-forge ffmpeg
# Verify
ffmpeg -version
```
### PyTorch CUDA mismatch
**Error**: `RuntimeError: CUDA error: no kernel image is available`
**Solutions**:
```bash
# Check CUDA version
nvcc --version
python -c "import torch; print(torch.version.cuda)"
# Install matching PyTorch
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.8
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
```
### xformers issues
**Error**: `ImportError: xformers` related errors
**Solutions**:
```bash
# Install xformers for memory efficiency
pip install xformers
# Or disable xformers
export AUDIOCRAFT_USE_XFORMERS=0
# In Python
import os
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
from audiocraft.models import MusicGen
```
## Model Loading Issues
### Out of memory during load
**Error**: `torch.cuda.OutOfMemoryError` during model loading
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Force CPU loading first
import torch
device = "cpu"
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
model = model.to("cuda")
# Use HuggingFace with device_map
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-small",
device_map="auto"
)
```
### Download failures
**Error**: Connection errors or incomplete downloads
**Solutions**:
```python
# Set cache directory
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
# Or for HuggingFace
os.environ["HF_HOME"] = "/path/to/hf_cache"
# Resume download
from huggingface_hub import snapshot_download
snapshot_download("facebook/musicgen-small", resume_download=True)
# Use local files
model = MusicGen.get_pretrained('/local/path/to/model')
```
### Wrong model type
**Error**: Loading wrong model for task
**Solutions**:
```python
# For text-to-music: use MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# For text-to-sound: use AudioGen
from audiocraft.models import AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
# For melody conditioning: use melody variant
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# For stereo: use stereo variant
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
## Generation Issues
### Empty or silent output
**Problem**: Generated audio is silent or very quiet
**Solutions**:
```python
import torch
# Check output
wav = model.generate(["upbeat music"])
print(f"Shape: {wav.shape}")
print(f"Max amplitude: {wav.abs().max().item()}")
print(f"Mean amplitude: {wav.abs().mean().item()}")
# If too quiet, normalize
def normalize_audio(audio, target_db=-14.0):
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
wav_normalized = normalize_audio(wav)
```
### Poor quality output
**Problem**: Generated music sounds bad or noisy
**Solutions**:
```python
# Use larger model
model = MusicGen.get_pretrained('facebook/musicgen-large')
# Adjust generation parameters
model.set_generation_params(
duration=15,
top_k=250, # Increase for more diversity
temperature=0.8, # Lower for more focused output
cfg_coef=4.0 # Increase for better text adherence
)
# Use better prompts
# Bad: "music"
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
# Try MultiBand Diffusion
from audiocraft.models import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = model.generate_tokens(["prompt"])
wav = mbd.tokens_to_wav(tokens)
```
### Generation too short
**Problem**: Audio shorter than expected
**Solutions**:
```python
# Check duration setting
model.set_generation_params(duration=30) # Set before generate
# Verify in generation
print(f"Duration setting: {model.generation_params}")
# Check output shape
wav = model.generate(["prompt"])
actual_duration = wav.shape[-1] / 32000
print(f"Actual duration: {actual_duration}s")
# Note: max duration is typically 30s
```
### Melody conditioning fails
**Error**: Issues with melody-conditioned generation
**Solutions**:
```python
import torchaudio
from audiocraft.models import MusicGen
# Load melody model (not base model)
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Load and prepare melody
melody, sr = torchaudio.load("melody.wav")
# Resample to model sample rate if needed
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
melody = resampler(melody)
# Ensure correct shape [batch, channels, samples]
if melody.dim() == 1:
melody = melody.unsqueeze(0).unsqueeze(0)
elif melody.dim() == 2:
melody = melody.unsqueeze(0)
# Convert stereo to mono
if melody.shape[1] > 1:
melody = melody.mean(dim=1, keepdim=True)
# Generate with melody
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
import torch
# Clear cache before generation
torch.cuda.empty_cache()
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10) # Instead of 30
# Generate one at a time
for prompt in prompts:
wav = model.generate([prompt])
save_audio(wav)
torch.cuda.empty_cache()
# Use CPU for very large generations
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
```
### Memory leak during batch processing
**Problem**: Memory grows over time
**Solutions**:
```python
import gc
import torch
def generate_with_cleanup(model, prompts):
results = []
for prompt in prompts:
with torch.no_grad():
wav = model.generate([prompt])
results.append(wav.cpu())
# Cleanup
del wav
gc.collect()
torch.cuda.empty_cache()
return results
# Use context manager
with torch.inference_mode():
wav = model.generate(["prompt"])
```
## Audio Format Issues
### Wrong sample rate
**Problem**: Audio plays at wrong speed
**Solutions**:
```python
import torchaudio
# MusicGen outputs at 32kHz
sample_rate = 32000
# AudioGen outputs at 16kHz
sample_rate = 16000
# Always use correct rate when saving
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
# Resample if needed
resampler = torchaudio.transforms.Resample(32000, 44100)
wav_resampled = resampler(wav)
```
### Stereo/mono mismatch
**Problem**: Wrong number of channels
**Solutions**:
```python
# Check model type
print(f"Audio channels: {wav.shape}")
# Mono: [batch, 1, samples]
# Stereo: [batch, 2, samples]
# Convert mono to stereo
if wav.shape[1] == 1:
wav_stereo = wav.repeat(1, 2, 1)
# Convert stereo to mono
if wav.shape[1] == 2:
wav_mono = wav.mean(dim=1, keepdim=True)
# Use stereo model for stereo output
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
### Clipping and distortion
**Problem**: Audio has clipping or distortion
**Solutions**:
```python
import torch
# Check for clipping
max_val = wav.abs().max().item()
print(f"Max amplitude: {max_val}")
# Normalize to prevent clipping
if max_val > 1.0:
wav = wav / max_val
# Apply soft clipping
def soft_clip(x, threshold=0.9):
return torch.tanh(x / threshold) * threshold
wav_clipped = soft_clip(wav)
# Lower temperature during generation
model.set_generation_params(temperature=0.7) # More controlled
```
## HuggingFace Transformers Issues
### Processor errors
**Error**: Issues with MusicgenProcessor
**Solutions**:
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Load matching processor and model
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
# Ensure inputs are on same device
inputs = processor(
text=["prompt"],
padding=True,
return_tensors="pt"
).to("cuda")
# Check processor configuration
print(processor.tokenizer)
print(processor.feature_extractor)
```
### Generation parameter errors
**Error**: Invalid generation parameters
**Solutions**:
```python
# HuggingFace uses different parameter names
audio_values = model.generate(
**inputs,
do_sample=True, # Enable sampling
guidance_scale=3.0, # CFG (not cfg_coef)
max_new_tokens=256, # Token limit (not duration)
temperature=1.0
)
# Calculate tokens from duration
# ~50 tokens per second
duration_seconds = 10
max_tokens = duration_seconds * 50
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
```
## Performance Issues
### Slow generation
**Problem**: Generation takes too long
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10)
# Use GPU
model.to("cuda")
# Enable flash attention if available
# (requires compatible hardware)
# Batch multiple prompts
prompts = ["prompt1", "prompt2", "prompt3"]
wav = model.generate(prompts) # Single batch is faster than loop
# Use compile (PyTorch 2.0+)
model.lm = torch.compile(model.lm)
```
### CPU fallback
**Problem**: Generation running on CPU instead of GPU
**Solutions**:
```python
import torch
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Explicitly move to GPU
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.to("cuda")
# Verify model device
print(f"Model device: {next(model.lm.parameters()).device}")
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2306.05284
### Reporting Issues
Include:
- Python version
- PyTorch version
- CUDA version
- AudioCraft version: `pip show audiocraft`
- Full error traceback
- Minimal reproducible code
- Hardware (GPU model, VRAM)

View File

@@ -0,0 +1,81 @@
---
name: code-review
description: Guidelines for performing thorough code reviews with security and quality focus
---
# Code Review Skill
Use this skill when reviewing code changes, pull requests, or auditing existing code.
## Review Checklist
### 1. Security First
- [ ] No hardcoded secrets, API keys, or credentials
- [ ] Input validation on all user-provided data
- [ ] SQL queries use parameterized statements (no string concatenation)
- [ ] File operations validate paths (no path traversal)
- [ ] Authentication/authorization checks present where needed
### 2. Error Handling
- [ ] All external calls (API, DB, file) have try/catch
- [ ] Errors are logged with context (but no sensitive data)
- [ ] User-facing errors are helpful but don't leak internals
- [ ] Resources are cleaned up in finally blocks or context managers
### 3. Code Quality
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
- [ ] Variable names are descriptive (no single letters except loops)
- [ ] No commented-out code left behind
- [ ] Complex logic has explanatory comments
- [ ] No duplicate code (DRY principle)
### 4. Testing Considerations
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
- [ ] Happy path and error paths both work
- [ ] New code has corresponding tests (if test suite exists)
## Review Response Format
When providing review feedback, structure it as:
```
## Summary
[1-2 sentence overall assessment]
## Critical Issues (Must Fix)
- Issue 1: [description + suggested fix]
- Issue 2: ...
## Suggestions (Nice to Have)
- Suggestion 1: [description]
## Questions
- [Any clarifying questions about intent]
```
## Common Patterns to Flag
### Python
```python
# Bad: SQL injection risk
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
# Good: Parameterized query
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
```
### JavaScript
```javascript
// Bad: XSS risk
element.innerHTML = userInput;
// Good: Safe text content
element.textContent = userInput;
```
## Tone Guidelines
- Be constructive, not critical
- Explain *why* something is an issue, not just *what*
- Offer solutions, not just problems
- Acknowledge good patterns you see

224
skills/mlops/faiss/SKILL.md Normal file
View File

@@ -0,0 +1,224 @@
---
name: faiss
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [faiss-cpu, faiss-gpu, numpy]
metadata:
hermes:
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
---
# FAISS - Efficient Similarity Search
Facebook AI's library for billion-scale vector similarity search.
## When to use FAISS
**Use FAISS when:**
- Need fast similarity search on large vector datasets (millions/billions)
- GPU acceleration required
- Pure vector similarity (no metadata filtering needed)
- High throughput, low latency critical
- Offline/batch processing of embeddings
**Metrics**:
- **31,700+ GitHub stars**
- Meta/Facebook AI Research
- **Handles billions of vectors**
- **C++** with Python bindings
**Use alternatives instead**:
- **Chroma/Pinecone**: Need metadata filtering
- **Weaviate**: Need full database features
- **Annoy**: Simpler, fewer features
## Quick start
### Installation
```bash
# CPU only
pip install faiss-cpu
# GPU support
pip install faiss-gpu
```
### Basic usage
```python
import faiss
import numpy as np
# Create sample data (1000 vectors, 128 dimensions)
d = 128
nb = 1000
vectors = np.random.random((nb, d)).astype('float32')
# Create index
index = faiss.IndexFlatL2(d) # L2 distance
index.add(vectors) # Add vectors
# Search
k = 5 # Find 5 nearest neighbors
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
print(f"Nearest neighbors: {indices}")
print(f"Distances: {distances}")
```
## Index types
### 1. Flat (exact search)
```python
# L2 (Euclidean) distance
index = faiss.IndexFlatL2(d)
# Inner product (cosine similarity if normalized)
index = faiss.IndexFlatIP(d)
# Slowest, most accurate
```
### 2. IVF (inverted file) - Fast approximate
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# IVF index with 100 clusters
nlist = 100
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10
distances, indices = index.search(query, k)
```
### 3. HNSW (Hierarchical NSW) - Best quality/speed
```python
# HNSW index
M = 32 # Number of connections per layer
index = faiss.IndexHNSWFlat(d, M)
# No training needed
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
### 4. Product Quantization - Memory efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers
nbits = 8
index = faiss.IndexPQ(d, m, nbits)
# Train and add
index.train(vectors)
index.add(vectors)
```
## Save and load
```python
# Save index
faiss.write_index(index, "large.index")
# Load index
index = faiss.read_index("large.index")
# Continue using
distances, indices = index.search(query, k)
```
## GPU acceleration
```python
# Single GPU
res = faiss.StandardGpuResources()
index_cpu = faiss.IndexFlatL2(d)
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Multi-GPU
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# 10-100× faster than CPU
```
## LangChain integration
```python
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
# Create FAISS vector store
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
# Save
vectorstore.save_local("faiss_index")
# Load
vectorstore = FAISS.load_local(
"faiss_index",
OpenAIEmbeddings(),
allow_dangerous_deserialization=True
)
# Search
results = vectorstore.similarity_search("query", k=5)
```
## LlamaIndex integration
```python
from llama_index.vector_stores.faiss import FaissVectorStore
import faiss
# Create FAISS index
d = 1536
faiss_index = faiss.IndexFlatL2(d)
vector_store = FaissVectorStore(faiss_index=faiss_index)
```
## Best practices
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
3. **Use GPU for large datasets** - 10-100× faster
4. **Save trained indices** - Training is expensive
5. **Tune nprobe/ef_search** - Balance speed/accuracy
6. **Monitor memory** - PQ for large datasets
7. **Batch queries** - Better GPU utilization
## Performance
| Index Type | Build Time | Search Time | Memory | Accuracy |
|------------|------------|-------------|--------|----------|
| Flat | Fast | Slow | High | 100% |
| IVF | Medium | Fast | Medium | 95-99% |
| HNSW | Slow | Fastest | High | 99% |
| PQ | Medium | Fast | Low | 90-95% |
## Resources
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **License**: MIT

View File

@@ -0,0 +1,280 @@
# FAISS Index Types Guide
Complete guide to choosing and using FAISS index types.
## Index selection guide
| Dataset Size | Index Type | Training | Accuracy | Speed |
|--------------|------------|----------|----------|-------|
| < 10K | Flat | No | 100% | Slow |
| 10K-1M | IVF | Yes | 95-99% | Fast |
| 1M-10M | HNSW | No | 99% | Fastest |
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
## Flat indices (exact search)
### IndexFlatL2 - L2 (Euclidean) distance
```python
import faiss
import numpy as np
d = 128 # Dimension
index = faiss.IndexFlatL2(d)
# Add vectors
vectors = np.random.random((1000, d)).astype('float32')
index.add(vectors)
# Search
k = 5
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset < 10,000 vectors
- Need 100% accuracy
- Serving as baseline
### IndexFlatIP - Inner product (cosine similarity)
```python
# For cosine similarity, normalize vectors first
import faiss
d = 128
index = faiss.IndexFlatIP(d)
# Normalize vectors (required for cosine similarity)
faiss.normalize_L2(vectors)
index.add(vectors)
# Search
faiss.normalize_L2(query)
distances, indices = index.search(query, k)
```
**Use when:**
- Need cosine similarity
- Recommendation systems
- Text embeddings
## IVF indices (inverted file)
### IndexIVFFlat - Cluster-based search
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# Create IVF index with 100 clusters
nlist = 100 # Number of clusters
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10 # Search 10 closest clusters
distances, indices = index.search(query, k)
```
**Parameters:**
- `nlist`: Number of clusters (√N to 4√N recommended)
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
**Use when:**
- Dataset 10K-1M vectors
- Need fast approximate search
- Can afford training time
### Tuning nprobe
```python
# Test different nprobe values
for nprobe in [1, 5, 10, 20, 50]:
index.nprobe = nprobe
distances, indices = index.search(query, k)
# Measure recall/speed trade-off
```
**Guidelines:**
- `nprobe=1`: Fastest, ~50% recall
- `nprobe=10`: Good balance, ~95% recall
- `nprobe=nlist`: Exact search (same as Flat)
## HNSW indices (graph-based)
### IndexHNSWFlat - Hierarchical NSW
```python
# HNSW index
M = 32 # Number of connections per layer (16-64)
index = faiss.IndexHNSWFlat(d, M)
# Optional: Set ef_construction (build time parameter)
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
# Add vectors (no training needed!)
index.add(vectors)
# Search
index.hnsw.efSearch = 16 # Search time parameter
distances, indices = index.search(query, k)
```
**Parameters:**
- `M`: Connections per layer (16-64, default 32)
- `efConstruction`: Build quality (40-200, higher = better)
- `efSearch`: Search quality (16-512, higher = more accurate)
**Use when:**
- Need best quality approximate search
- Can afford higher memory (more connections)
- Dataset 1M-10M vectors
## PQ indices (product quantization)
### IndexPQ - Memory-efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers (divides d)
nbits = 8 # Bits per subquantizer
index = faiss.IndexPQ(d, m, nbits)
# Train (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
**Parameters:**
- `m`: Subquantizers (d must be divisible by m)
- `nbits`: Bits per code (8 or 16)
**Memory savings:**
- Original: d × 4 bytes (float32)
- PQ: m bytes
- Compression ratio: 4d/m
**Use when:**
- Limited memory
- Large datasets (> 10M vectors)
- Can accept ~90-95% accuracy
### IndexIVFPQ - IVF + PQ combined
```python
# Best for very large datasets
nlist = 4096
m = 8
nbits = 8
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
# Train
index.train(vectors)
index.add(vectors)
# Search
index.nprobe = 32
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset > 10M vectors
- Need fast search + low memory
- Can accept 90-95% accuracy
## GPU indices
### Single GPU
```python
import faiss
# Create CPU index
index_cpu = faiss.IndexFlatL2(d)
# Move to GPU
res = faiss.StandardGpuResources() # GPU resources
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Use normally
index_gpu.add(vectors)
distances, indices = index_gpu.search(query, k)
```
### Multi-GPU
```python
# Use all available GPUs
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# Or specific GPUs
gpus = [0, 1, 2, 3] # Use GPUs 0-3
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
```
**Speedup:**
- Single GPU: 10-50× faster than CPU
- Multi-GPU: Near-linear scaling
## Index factory
```python
# Easy index creation with string descriptors
index = faiss.index_factory(d, "IVF100,Flat")
index = faiss.index_factory(d, "HNSW32")
index = faiss.index_factory(d, "IVF4096,PQ8")
# Train and use
index.train(vectors)
index.add(vectors)
```
**Common descriptors:**
- `"Flat"`: Exact search
- `"IVF100,Flat"`: IVF with 100 clusters
- `"HNSW32"`: HNSW with M=32
- `"IVF4096,PQ8"`: IVF + PQ compression
## Performance comparison
### Search speed (1M vectors, k=10)
| Index | Build Time | Search Time | Memory | Recall |
|-------|------------|-------------|--------|--------|
| Flat | 0s | 50ms | 512 MB | 100% |
| IVF100 | 5s | 2ms | 512 MB | 95% |
| HNSW32 | 60s | 1ms | 1GB | 99% |
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
*CPU (16 cores), 128-dim vectors*
## Best practices
1. **Start with Flat** - Baseline for comparison
2. **Use IVF for medium datasets** - Good balance
3. **Use HNSW for best quality** - If memory allows
4. **Add PQ for memory savings** - Large datasets
5. **GPU for > 100K vectors** - 10-50× speedup
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
7. **Train on representative data** - Better clustering
8. **Save trained indices** - Avoid retraining
## Resources
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **Paper**: https://arxiv.org/abs/1702.08734

View File

@@ -0,0 +1,370 @@
---
name: optimizing-attention-flash
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [flash-attn, torch, transformers]
metadata:
hermes:
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
---
# Flash Attention - Fast Memory-Efficient Attention
## Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```
**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
## Common workflows
### Workflow 1: Enable in existing PyTorch model
Copy this checklist:
```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```
**Step 1: Check PyTorch version**
```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```
If <2.2, upgrade:
```bash
pip install --upgrade torch
```
**Step 2: Enable Flash Attention backend**
Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
```
**Step 3: Verify speedup with profiling**
```python
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```
Expected: 2-4x speedup for sequences >512 tokens.
**Step 4: Test accuracy matches baseline**
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
### Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```
**Step 1: Install flash-attn library**
```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```
**Step 2: Modify attention code**
```python
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
```
**Step 3: Enable advanced features**
Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
```
Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)
```
**Step 4: Benchmark performance**
```python
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```
**Step 1: Verify H100 GPU**
```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```
**Step 2: Install flash-attn with FP8 support**
```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```
**Step 3: Convert inputs to FP8**
```python
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
**Step 4: Run with FP8 attention**
```python
from flash_attn import flash_attn_func
# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```
## When to use vs alternatives
**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
## Common issues
**Issue: ImportError: cannot import flash_attn**
Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```
Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```
**Issue: Slower than expected (no speedup)**
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup
Check sequence length is sufficient.
**Issue: RuntimeError: CUDA error**
Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```
Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
**Issue: Accuracy degradation**
Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16) # Or torch.bfloat16
```
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
## Advanced topics
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
## Hardware requirements
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support
**Not supported**: V100 (Volta), CPU inference
## Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

View File

@@ -0,0 +1,215 @@
# Performance Benchmarks
## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
## Speed comparisons across GPUs
### A100 80GB (Ampere)
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
### H100 80GB (Hopper)
**Forward pass time** (milliseconds, same config):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
### A10G 24GB (Ampere)
**Forward pass time** (milliseconds, batch=4):
| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |
## Memory usage analysis
### GPU memory consumption (batch=8, heads=32, dim=64)
**Standard attention memory**:
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
**Flash Attention 2 memory**:
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
### Memory scaling comparison
**Llama 2 7B model memory** (float16, batch=1):
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|-------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |
### Training memory (Llama 2 7B, batch=4)
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |
## Scaling with sequence length
### Computational complexity
**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)
**Flash Attention**:
- Time: O(N² × d) (same, but with better constants)
- Memory: O(N × d) (linear!)
### Empirical scaling (A100, batch=1, heads=32, dim=64)
**Time per token (milliseconds)**:
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
**Observation**: Speedup increases quadratically with sequence length!
### Memory per token (MB)
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
**Observation**: Flash Attention memory per token is constant!
## Training vs inference performance
### Training (forward + backward, Llama 2 7B, A100)
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |
### Inference (generation, Llama 2 7B, A100)
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
## Flash Attention versions comparison
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |
**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)
### Features by version
| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |
## Real-world model benchmarks
### Llama 2 models (A100 80GB, batch=4, seq=2048)
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
### GPT-style models (seq=1024)
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
## Recommendations by use case
**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup
**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction
**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)
**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware

View File

@@ -0,0 +1,293 @@
# HuggingFace Transformers Integration
## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues
## Enabling Flash Attention in Transformers
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
**Simple enable for any supported model**:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
```
**Install requirements**:
```bash
pip install transformers>=4.36
pip install flash-attn --no-build-isolation
```
## Supported model architectures
As of Transformers 4.40:
**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM
**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper
**Check support**:
```python
from transformers import AutoConfig
config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```
## Configuration examples
### Llama 2 with Flash Attention
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention for long context
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for long context
device_map="auto",
max_position_embeddings=32768 # Extended context
)
# Process long document (32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-tuning with Flash Attention
```python
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=3,
fp16=True, # Must match model dtype
optim="adamw_torch_fused" # Fast optimizer
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
trainer.train()
```
### Multi-GPU training
```python
from transformers import AutoModelForCausalLM
import torch
# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto", # Automatic multi-GPU placement
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
)
```
## Performance comparisons
### Memory usage (Llama 2 7B, batch=1)
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|-------------------|-------------------|-----------|
| 512 | 1.2 GB | 0.9 GB | 25% |
| 2048 | 3.8 GB | 1.4 GB | 63% |
| 8192 | 14.2 GB | 3.2 GB | 77% |
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
### Speed (tokens/sec, A100 80GB)
| Model | Standard | Flash Attn 2 | Speedup |
|-------|----------|--------------|---------|
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
### Training throughput (samples/sec)
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------|------------|----------|--------------|---------|
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
## Troubleshooting model-specific issues
### Issue: Model doesn't support Flash Attention
Check support list above. If not supported, use PyTorch SDPA as fallback:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="sdpa", # PyTorch native (still faster)
torch_dtype=torch.float16
)
```
### Issue: CUDA out of memory during loading
Reduce memory footprint:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto",
max_memory={0: "18GB"}, # Reserve memory for KV cache
low_cpu_mem_usage=True
)
```
### Issue: Slower inference than expected
Ensure dtype matches:
```python
# Model and inputs must both be float16/bfloat16
model = model.to(torch.float16)
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
for k, v in inputs.items()}
```
### Issue: Different outputs vs standard attention
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
```python
# Compare outputs
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16)
model_flash = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
with torch.no_grad():
out_standard = model_standard(**inputs).logits
out_flash = model_flash(**inputs).logits
diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
```
### Issue: ImportError during model loading
Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```
Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="eager", # Standard PyTorch
torch_dtype=torch.float16
)
```
## Best practices
1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
**Example with all best practices**:
```python
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for training
device_map="auto",
low_cpu_mem_usage=True
)
# Enable gradient checkpointing for memory
model.gradient_checkpointing_enable()
# Training with optimizations
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
bf16=True, # Match model dtype
optim="adamw_torch_fused",
gradient_checkpointing=True
)
```

430
skills/mlops/gguf/SKILL.md Normal file
View File

@@ -0,0 +1,430 @@
---
name: gguf-quantization
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [llama-cpp-python>=0.2.0]
metadata:
hermes:
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
---
# GGUF - Quantization Format for llama.cpp
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
## When to use GGUF
**Use GGUF when:**
- Deploying on consumer hardware (laptops, desktops)
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
- Need CPU inference without GPU requirements
- Want flexible quantization (Q2_K to Q8_0)
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
**Key advantages:**
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
- **No Python runtime**: Pure C/C++ inference
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
- **imatrix**: Importance matrix for better low-bit quality
**Use alternatives instead:**
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
- **HQQ**: Fast calibration-free quantization for HuggingFace
- **bitsandbytes**: Simple integration with transformers library
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
## Quick start
### Installation
```bash
# Clone llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
# Build (CPU)
make
# Build with CUDA (NVIDIA)
make GGML_CUDA=1
# Build with Metal (Apple Silicon)
make GGML_METAL=1
# Install Python bindings (optional)
pip install llama-cpp-python
```
### Convert model to GGUF
```bash
# Install requirements
pip install -r requirements.txt
# Convert HuggingFace model to GGUF (FP16)
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
# Or specify output type
python convert_hf_to_gguf.py ./path/to/model \
--outfile model-f16.gguf \
--outtype f16
```
### Quantize model
```bash
# Basic quantization to Q4_K_M
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
# Quantize with importance matrix (better quality)
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
### Run inference
```bash
# CLI inference
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
# Interactive mode
./llama-cli -m model-q4_k_m.gguf --interactive
# With GPU offload
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
```
## Quantization types
### K-quant methods (recommended)
| Type | Bits | Size (7B) | Quality | Use Case |
|------|------|-----------|---------|----------|
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
### Legacy methods
| Type | Description |
|------|-------------|
| Q4_0 | 4-bit, basic |
| Q4_1 | 4-bit with delta |
| Q5_0 | 5-bit, basic |
| Q5_1 | 5-bit with delta |
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
## Conversion workflows
### Workflow 1: HuggingFace to GGUF
```bash
# 1. Download model
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
# 2. Convert to GGUF (FP16)
python convert_hf_to_gguf.py ./llama-3.1-8b \
--outfile llama-3.1-8b-f16.gguf \
--outtype f16
# 3. Quantize
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
# 4. Test
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
```
### Workflow 2: With importance matrix (better quality)
```bash
# 1. Convert to GGUF
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
# 2. Create calibration text (diverse samples)
cat > calibration.txt << 'EOF'
The quick brown fox jumps over the lazy dog.
Machine learning is a subset of artificial intelligence.
Python is a popular programming language.
# Add more diverse text samples...
EOF
# 3. Generate importance matrix
./llama-imatrix -m model-f16.gguf \
-f calibration.txt \
--chunk 512 \
-o model.imatrix \
-ngl 35 # GPU layers if available
# 4. Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf \
model-q4_k_m.gguf \
Q4_K_M
```
### Workflow 3: Multiple quantizations
```bash
#!/bin/bash
MODEL="llama-3.1-8b-f16.gguf"
IMATRIX="llama-3.1-8b.imatrix"
# Generate imatrix once
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
# Create multiple quantizations
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
done
```
## Python usage
### llama-cpp-python
```python
from llama_cpp import Llama
# Load model
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096, # Context window
n_gpu_layers=35, # GPU offload (0 for CPU only)
n_threads=8 # CPU threads
)
# Generate
output = llm(
"What is machine learning?",
max_tokens=256,
temperature=0.7,
stop=["</s>", "\n\n"]
)
print(output["choices"][0]["text"])
```
### Chat completion
```python
from llama_cpp import Llama
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
chat_format="llama-3" # Or "chatml", "mistral", etc.
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Python?"}
]
response = llm.create_chat_completion(
messages=messages,
max_tokens=256,
temperature=0.7
)
print(response["choices"][0]["message"]["content"])
```
### Streaming
```python
from llama_cpp import Llama
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
# Stream tokens
for chunk in llm(
"Explain quantum computing:",
max_tokens=256,
stream=True
):
print(chunk["choices"][0]["text"], end="", flush=True)
```
## Server mode
### Start OpenAI-compatible server
```bash
# Start server
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096
# Or with Python bindings
python -m llama_cpp.server \
--model model-q4_k_m.gguf \
--n_gpu_layers 35 \
--host 0.0.0.0 \
--port 8080
```
### Use with OpenAI client
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="not-needed"
)
response = client.chat.completions.create(
model="local-model",
messages=[{"role": "user", "content": "Hello!"}],
max_tokens=256
)
print(response.choices[0].message.content)
```
## Hardware optimization
### Apple Silicon (Metal)
```bash
# Build with Metal
make clean && make GGML_METAL=1
# Run with Metal acceleration
./llama-cli -m model.gguf -ngl 99 -p "Hello"
# Python with Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload all layers
n_threads=1 # Metal handles parallelism
)
```
### NVIDIA CUDA
```bash
# Build with CUDA
make clean && make GGML_CUDA=1
# Run with CUDA
./llama-cli -m model.gguf -ngl 35 -p "Hello"
# Specify GPU
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
```
### CPU optimization
```bash
# Build with AVX2/AVX512
make clean && make
# Run with optimal threads
./llama-cli -m model.gguf -t 8 -p "Hello"
# Python CPU config
llm = Llama(
model_path="model.gguf",
n_gpu_layers=0, # CPU only
n_threads=8, # Match physical cores
n_batch=512 # Batch size for prompt processing
)
```
## Integration with tools
### Ollama
```bash
# Create Modelfile
cat > Modelfile << 'EOF'
FROM ./model-q4_k_m.gguf
TEMPLATE """{{ .System }}
{{ .Prompt }}"""
PARAMETER temperature 0.7
PARAMETER num_ctx 4096
EOF
# Create Ollama model
ollama create mymodel -f Modelfile
# Run
ollama run mymodel "Hello!"
```
### LM Studio
1. Place GGUF file in `~/.cache/lm-studio/models/`
2. Open LM Studio and select the model
3. Configure context length and GPU offload
4. Start inference
### text-generation-webui
```bash
# Place in models folder
cp model-q4_k_m.gguf text-generation-webui/models/
# Start with llama.cpp loader
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
```
## Best practices
1. **Use K-quants**: Q4_K_M offers best quality/size balance
2. **Use imatrix**: Always use importance matrix for Q4 and below
3. **GPU offload**: Offload as many layers as VRAM allows
4. **Context length**: Start with 4096, increase if needed
5. **Thread count**: Match physical CPU cores, not logical
6. **Batch size**: Increase n_batch for faster prompt processing
## Common issues
**Model loads slowly:**
```bash
# Use mmap for faster loading
./llama-cli -m model.gguf --mmap
```
**Out of memory:**
```bash
# Reduce GPU layers
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
# Or use smaller quantization
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
**Poor quality at low bits:**
```bash
# Always use imatrix for Q4 and below
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
## References
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
## Resources
- **Repository**: https://github.com/ggml-org/llama.cpp
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
- **Pre-quantized Models**: https://huggingface.co/TheBloke
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
- **License**: MIT

View File

@@ -0,0 +1,504 @@
# GGUF Advanced Usage Guide
## Speculative Decoding
### Draft Model Approach
```bash
# Use smaller model as draft for faster generation
./llama-speculative \
-m large-model-q4_k_m.gguf \
-md draft-model-q4_k_m.gguf \
-p "Write a story about AI" \
-n 500 \
--draft 8 # Draft tokens before verification
```
### Self-Speculative Decoding
```bash
# Use same model with different context for speculation
./llama-cli -m model-q4_k_m.gguf \
--lookup-cache-static lookup.bin \
--lookup-cache-dynamic lookup-dynamic.bin \
-p "Hello world"
```
## Batched Inference
### Process Multiple Prompts
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
n_batch=512 # Larger batch for parallel processing
)
prompts = [
"What is Python?",
"Explain machine learning.",
"Describe neural networks."
]
# Process in batch (each prompt gets separate context)
for prompt in prompts:
output = llm(prompt, max_tokens=100)
print(f"Q: {prompt}")
print(f"A: {output['choices'][0]['text']}\n")
```
### Server Batching
```bash
# Start server with batching
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096 \
--parallel 4 # Concurrent requests
--cont-batching # Continuous batching
```
## Custom Model Conversion
### Convert with Vocabulary Modifications
```python
# custom_convert.py
import sys
sys.path.insert(0, './llama.cpp')
from convert_hf_to_gguf import main
from gguf import GGUFWriter
# Custom conversion with modified vocab
def convert_with_custom_vocab(model_path, output_path):
# Load and modify tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Add special tokens if needed
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
tokenizer.add_special_tokens(special_tokens)
tokenizer.save_pretrained(model_path)
# Then run standard conversion
main([model_path, "--outfile", output_path])
```
### Convert Specific Architecture
```bash
# For Mistral-style models
python convert_hf_to_gguf.py ./mistral-model \
--outfile mistral-f16.gguf \
--outtype f16
# For Qwen models
python convert_hf_to_gguf.py ./qwen-model \
--outfile qwen-f16.gguf \
--outtype f16
# For Phi models
python convert_hf_to_gguf.py ./phi-model \
--outfile phi-f16.gguf \
--outtype f16
```
## Advanced Quantization
### Mixed Quantization
```bash
# Quantize different layer types differently
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
--allow-requantize \
--leave-output-tensor
```
### Quantization with Token Embeddings
```bash
# Keep embeddings at higher precision
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
--token-embedding-type f16
```
### IQ Quantization (Importance-aware)
```bash
# Ultra-low bit quantization with importance
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
```
## Memory Optimization
### Memory Mapping
```python
from llama_cpp import Llama
# Use memory mapping for large models
llm = Llama(
model_path="model-q4_k_m.gguf",
use_mmap=True, # Memory map the model
use_mlock=False, # Don't lock in RAM
n_gpu_layers=35
)
```
### Partial GPU Offload
```python
# Calculate layers to offload based on VRAM
import subprocess
def get_free_vram_gb():
result = subprocess.run(
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
capture_output=True, text=True
)
return int(result.stdout.strip()) / 1024
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
free_vram = get_free_vram_gb()
layers_to_offload = int(free_vram / 0.5)
llm = Llama(
model_path="model-q4_k_m.gguf",
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
)
```
### KV Cache Optimization
```python
from llama_cpp import Llama
# Optimize KV cache for long contexts
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=8192, # Large context
n_gpu_layers=35,
type_k=1, # Q8_0 for K cache (1)
type_v=1, # Q8_0 for V cache (1)
# Or use Q4_0 (2) for more compression
)
```
## Context Management
### Context Shifting
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35
)
# Handle long conversations with context shifting
conversation = []
max_history = 10
def chat(user_message):
conversation.append({"role": "user", "content": user_message})
# Keep only recent history
if len(conversation) > max_history * 2:
conversation = conversation[-max_history * 2:]
response = llm.create_chat_completion(
messages=conversation,
max_tokens=256
)
assistant_message = response["choices"][0]["message"]["content"]
conversation.append({"role": "assistant", "content": assistant_message})
return assistant_message
```
### Save and Load State
```bash
# Save state to file
./llama-cli -m model.gguf \
-p "Once upon a time" \
--save-session session.bin \
-n 100
# Load and continue
./llama-cli -m model.gguf \
--load-session session.bin \
-p " and they lived" \
-n 100
```
## Grammar Constrained Generation
### JSON Output
```python
from llama_cpp import Llama, LlamaGrammar
# Define JSON grammar
json_grammar = LlamaGrammar.from_string('''
root ::= object
object ::= "{" ws pair ("," ws pair)* "}" ws
pair ::= string ":" ws value
value ::= string | number | object | array | "true" | "false" | "null"
array ::= "[" ws value ("," ws value)* "]" ws
string ::= "\\"" [^"\\\\]* "\\""
number ::= [0-9]+
ws ::= [ \\t\\n]*
''')
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
output = llm(
"Output a JSON object with name and age:",
grammar=json_grammar,
max_tokens=100
)
print(output["choices"][0]["text"])
```
### Custom Grammar
```python
# Grammar for specific format
answer_grammar = LlamaGrammar.from_string('''
root ::= "Answer: " letter "\\n" "Explanation: " explanation
letter ::= [A-D]
explanation ::= [a-zA-Z0-9 .,!?]+
''')
output = llm(
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
grammar=answer_grammar,
max_tokens=100
)
```
## LoRA Integration
### Load LoRA Adapter
```bash
# Apply LoRA at runtime
./llama-cli -m base-model-q4_k_m.gguf \
--lora lora-adapter.gguf \
--lora-scale 1.0 \
-p "Hello!"
```
### Multiple LoRA Adapters
```bash
# Stack multiple adapters
./llama-cli -m base-model.gguf \
--lora adapter1.gguf --lora-scale 0.5 \
--lora adapter2.gguf --lora-scale 0.5 \
-p "Hello!"
```
### Python LoRA Usage
```python
from llama_cpp import Llama
llm = Llama(
model_path="base-model-q4_k_m.gguf",
lora_path="lora-adapter.gguf",
lora_scale=1.0,
n_gpu_layers=35
)
```
## Embedding Generation
### Extract Embeddings
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
embedding=True, # Enable embedding mode
n_gpu_layers=35
)
# Get embeddings
embeddings = llm.embed("This is a test sentence.")
print(f"Embedding dimension: {len(embeddings)}")
```
### Batch Embeddings
```python
texts = [
"Machine learning is fascinating.",
"Deep learning uses neural networks.",
"Python is a programming language."
]
embeddings = [llm.embed(text) for text in texts]
# Calculate similarity
import numpy as np
def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
sim = cosine_similarity(embeddings[0], embeddings[1])
print(f"Similarity: {sim:.4f}")
```
## Performance Tuning
### Benchmark Script
```python
import time
from llama_cpp import Llama
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
llm = Llama(
model_path=model_path,
n_gpu_layers=35,
n_ctx=2048,
verbose=False
)
# Warmup
llm(prompt, max_tokens=10)
# Benchmark
times = []
for _ in range(n_runs):
start = time.time()
output = llm(prompt, max_tokens=n_tokens)
elapsed = time.time() - start
times.append(elapsed)
avg_time = sum(times) / len(times)
tokens_per_sec = n_tokens / avg_time
print(f"Model: {model_path}")
print(f"Avg time: {avg_time:.2f}s")
print(f"Tokens/sec: {tokens_per_sec:.1f}")
return tokens_per_sec
# Compare quantizations
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
```
### Optimal Configuration Finder
```python
def find_optimal_config(model_path, target_vram_gb=8):
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
from llama_cpp import Llama
import gc
best_config = None
best_speed = 0
for n_gpu_layers in range(0, 50, 5):
for n_batch in [128, 256, 512, 1024]:
try:
gc.collect()
llm = Llama(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_batch=n_batch,
n_ctx=2048,
verbose=False
)
# Quick benchmark
start = time.time()
llm("Hello", max_tokens=50)
speed = 50 / (time.time() - start)
if speed > best_speed:
best_speed = speed
best_config = {
"n_gpu_layers": n_gpu_layers,
"n_batch": n_batch,
"speed": speed
}
del llm
gc.collect()
except Exception as e:
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
break
return best_config
```
## Multi-GPU Setup
### Distribute Across GPUs
```bash
# Split model across multiple GPUs
./llama-cli -m large-model.gguf \
--tensor-split 0.5,0.5 \
-ngl 60 \
-p "Hello!"
```
### Python Multi-GPU
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
from llama_cpp import Llama
llm = Llama(
model_path="large-model-q4_k_m.gguf",
n_gpu_layers=60,
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
)
```
## Custom Builds
### Build with All Optimizations
```bash
# Clean build with all CPU optimizations
make clean
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
# With CUDA and cuBLAS
make clean
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
# With specific CUDA architecture
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
```
### CMake Build
```bash
mkdir build && cd build
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
cmake --build . --config Release -j
```

View File

@@ -0,0 +1,442 @@
# GGUF Troubleshooting Guide
## Installation Issues
### Build Fails
**Error**: `make: *** No targets specified and no makefile found`
**Fix**:
```bash
# Ensure you're in llama.cpp directory
cd llama.cpp
make
```
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
**Fix**:
```bash
# Install CUDA toolkit
# Ubuntu
sudo apt install nvidia-cuda-toolkit
# Or set CUDA path
export CUDA_PATH=/usr/local/cuda
export PATH=$CUDA_PATH/bin:$PATH
make GGML_CUDA=1
```
### Python Bindings Issues
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
**Fix**:
```bash
# Install build dependencies
pip install cmake scikit-build-core
# For CUDA support
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
# For Metal (macOS)
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
```
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
**Fix**:
```bash
# Add CUDA libraries to path
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Or reinstall with correct CUDA version
pip uninstall llama-cpp-python
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
```
## Conversion Issues
### Model Not Supported
**Error**: `KeyError: 'model.embed_tokens.weight'`
**Fix**:
```bash
# Check model architecture
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
# Use appropriate conversion script
# For most models:
python convert_hf_to_gguf.py ./model --outfile model.gguf
# For older models, check if legacy script needed
```
### Vocabulary Mismatch
**Error**: `RuntimeError: Vocabulary size mismatch`
**Fix**:
```python
# Ensure tokenizer matches model
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("./model")
model = AutoModelForCausalLM.from_pretrained("./model")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model vocab size: {model.config.vocab_size}")
# If mismatch, resize embeddings before conversion
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained("./model-fixed")
```
### Out of Memory During Conversion
**Error**: `torch.cuda.OutOfMemoryError` during conversion
**Fix**:
```bash
# Use CPU for conversion
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
# Or use low memory mode
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
```
## Quantization Issues
### Wrong Output File Size
**Problem**: Quantized file is larger than expected
**Check**:
```bash
# Verify quantization type
./llama-cli -m model.gguf --verbose
# Expected sizes for 7B model:
# Q4_K_M: ~4.1 GB
# Q5_K_M: ~4.8 GB
# Q8_0: ~7.2 GB
# F16: ~13.5 GB
```
### Quantization Crashes
**Error**: `Segmentation fault` during quantization
**Fix**:
```bash
# Increase stack size
ulimit -s unlimited
# Or use less threads
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
```
### Poor Quality After Quantization
**Problem**: Model outputs gibberish after quantization
**Solutions**:
1. **Use importance matrix**:
```bash
# Generate imatrix with good calibration data
./llama-imatrix -m model-f16.gguf \
-f wiki_sample.txt \
--chunk 512 \
-o model.imatrix
# Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
2. **Try higher precision**:
```bash
# Use Q5_K_M or Q6_K instead of Q4
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
```
3. **Check original model**:
```bash
# Test FP16 version first
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
```
## Inference Issues
### Slow Generation
**Problem**: Generation is slower than expected
**Solutions**:
1. **Enable GPU offload**:
```bash
./llama-cli -m model.gguf -ngl 35 -p "Hello"
```
2. **Optimize batch size**:
```python
llm = Llama(
model_path="model.gguf",
n_batch=512, # Increase for faster prompt processing
n_gpu_layers=35
)
```
3. **Use appropriate threads**:
```bash
# Match physical cores, not logical
./llama-cli -m model.gguf -t 8 -p "Hello"
```
4. **Enable Flash Attention** (if supported):
```bash
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
```
### Out of Memory
**Error**: `CUDA out of memory` or system freeze
**Solutions**:
1. **Reduce GPU layers**:
```python
# Start low and increase
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
```
2. **Use smaller quantization**:
```bash
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
3. **Reduce context length**:
```python
llm = Llama(
model_path="model.gguf",
n_ctx=2048, # Reduce from 4096
n_gpu_layers=35
)
```
4. **Quantize KV cache**:
```python
llm = Llama(
model_path="model.gguf",
type_k=2, # Q4_0 for K cache
type_v=2, # Q4_0 for V cache
n_gpu_layers=35
)
```
### Garbage Output
**Problem**: Model outputs random characters or nonsense
**Diagnose**:
```python
# Check model loading
llm = Llama(model_path="model.gguf", verbose=True)
# Test with simple prompt
output = llm("1+1=", max_tokens=5, temperature=0)
print(output)
```
**Solutions**:
1. **Check model integrity**:
```bash
# Verify GGUF file
./llama-cli -m model.gguf --verbose 2>&1 | head -50
```
2. **Use correct chat format**:
```python
llm = Llama(
model_path="model.gguf",
chat_format="llama-3" # Match your model: chatml, mistral, etc.
)
```
3. **Check temperature**:
```python
# Use lower temperature for deterministic output
output = llm("Hello", max_tokens=50, temperature=0.1)
```
### Token Issues
**Error**: `RuntimeError: unknown token` or encoding errors
**Fix**:
```python
# Ensure UTF-8 encoding
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
output = llm(prompt, max_tokens=50)
```
## Server Issues
### Connection Refused
**Error**: `Connection refused` when accessing server
**Fix**:
```bash
# Bind to all interfaces
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
# Check if port is in use
lsof -i :8080
```
### Server Crashes Under Load
**Problem**: Server crashes with multiple concurrent requests
**Solutions**:
1. **Limit parallelism**:
```bash
./llama-server -m model.gguf \
--parallel 2 \
-c 4096 \
--cont-batching
```
2. **Add request timeout**:
```bash
./llama-server -m model.gguf --timeout 300
```
3. **Monitor memory**:
```bash
watch -n 1 nvidia-smi # For GPU
watch -n 1 free -h # For RAM
```
### API Compatibility Issues
**Problem**: OpenAI client not working with server
**Fix**:
```python
from openai import OpenAI
# Use correct base URL format
client = OpenAI(
base_url="http://localhost:8080/v1", # Include /v1
api_key="not-needed"
)
# Use correct model name
response = client.chat.completions.create(
model="local", # Or the actual model name
messages=[{"role": "user", "content": "Hello"}]
)
```
## Apple Silicon Issues
### Metal Not Working
**Problem**: Metal acceleration not enabled
**Check**:
```bash
# Verify Metal support
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
```
**Fix**:
```bash
# Rebuild with Metal
make clean
make GGML_METAL=1
# Python bindings
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
```
### Incorrect Memory Usage on M1/M2
**Problem**: Model uses too much unified memory
**Fix**:
```python
# Offload all layers for Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload everything
n_threads=1 # Metal handles parallelism
)
```
## Debugging
### Enable Verbose Output
```bash
# CLI verbose mode
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
# Python verbose
llm = Llama(model_path="model.gguf", verbose=True)
```
### Check Model Metadata
```bash
# View GGUF metadata
./llama-cli -m model.gguf --verbose 2>&1 | head -100
```
### Validate GGUF File
```python
import struct
def validate_gguf(filepath):
with open(filepath, 'rb') as f:
magic = f.read(4)
if magic != b'GGUF':
print(f"Invalid magic: {magic}")
return False
version = struct.unpack('<I', f.read(4))[0]
print(f"GGUF version: {version}")
tensor_count = struct.unpack('<Q', f.read(8))[0]
metadata_count = struct.unpack('<Q', f.read(8))[0]
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
return True
validate_gguf("model.gguf")
```
## Getting Help
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
3. **Reddit**: r/LocalLLaMA
### Reporting Issues
Include:
- llama.cpp version/commit hash
- Build command used
- Model name and quantization
- Full error message/stack trace
- Hardware: CPU/GPU model, RAM, VRAM
- OS version
- Minimal reproduction steps

View File

@@ -0,0 +1,97 @@
# GRPO/RL Training Skill
**Expert-level guidance for Group Relative Policy Optimization with TRL**
## 📁 Skill Structure
```
grpo-rl-training/
├── SKILL.md # Main skill documentation (READ THIS FIRST)
├── README.md # This file
├── templates/
│ └── basic_grpo_training.py # Production-ready training template
└── examples/
└── reward_functions_library.py # 20+ reward function examples
```
## 🚀 Quick Start
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
4. **Modify for your use case** - Adapt dataset, rewards, and config
## 💡 What's Inside
### SKILL.md (Main Documentation)
- Core GRPO concepts and algorithm fundamentals
- Complete implementation workflow (dataset → rewards → training → deployment)
- 10+ reward function examples with code
- Hyperparameter tuning guide
- Training insights (loss behavior, metrics, debugging)
- Troubleshooting guide
- Production best practices
### Templates
- **basic_grpo_training.py**: Minimal, production-ready training script
- Uses Qwen 2.5 1.5B Instruct
- 3 reward functions (format + correctness)
- LoRA for efficient training
- Fully documented and ready to run
### Examples
- **reward_functions_library.py**: 20+ battle-tested reward functions
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
- Format rewards (XML, JSON, strict/soft)
- Length rewards (ideal length, min/max)
- Style rewards (reasoning quality, citations, repetition penalty)
- Combined rewards (multi-objective optimization)
- Preset collections for common tasks
## 📖 Usage for Agents
When this skill is loaded in your agent's context:
1. **Always read SKILL.md first** before implementing
2. **Start simple** - Use length-based reward to validate setup
3. **Build incrementally** - Add one reward function at a time
4. **Reference examples** - Copy patterns from reward_functions_library.py
5. **Monitor training** - Watch reward metrics (not loss!)
## 🎯 Common Use Cases
| Task Type | Recommended Rewards | Template |
|-----------|---------------------|----------|
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
## ⚠️ Critical Reminders
- **Loss goes UP during training** - This is normal (it's KL divergence)
- **Use 3-5 reward functions** - Single rewards often fail
- **Test rewards before training** - Debug each function independently
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
- **Start with num_generations=4-8** - Scale up if GPU allows
## 🔗 External Resources
- [TRL Documentation](https://huggingface.co/docs/trl)
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
## 📝 Version
**v1.0.0** - Initial release (January 2025)
## 👨‍💻 Maintained By
Orchestra Research
For questions or improvements, see https://orchestra.com
---
**License:** MIT
**Last Updated:** January 2025

View File

@@ -0,0 +1,575 @@
---
name: grpo-rl-training
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
metadata:
hermes:
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
---
# GRPO/RL Training with TRL
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
## When to Use This Skill
Use GRPO training when you need to:
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
- **Align models to domain-specific behaviors** without labeled preference data
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
**Do NOT use GRPO for:**
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- When you already have high-quality preference pairs (use DPO/PPO instead)
---
## Core Concepts
### 1. GRPO Algorithm Fundamentals
**Key Mechanism:**
- Generates **multiple completions** for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group
**Critical Difference from PPO:**
- No separate reward model needed
- More sample-efficient (learns from within-group comparisons)
- Simpler to implement and debug
**Mathematical Intuition:**
```
For each prompt p:
1. Generate N completions: {c₁, c₂, ..., cₙ}
2. Compute rewards: {r₁, r₂, ..., rₙ}
3. Learn to increase probability of high-reward completions
relative to low-reward ones in the same group
```
### 2. Reward Function Design Philosophy
**Golden Rules:**
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
2. **Scale rewards appropriately** - Higher weight = stronger signal
3. **Use incremental rewards** - Partial credit for partial compliance
4. **Test rewards independently** - Debug each reward function in isolation
**Reward Function Types:**
| Type | Use Case | Example Weight |
|------|----------|----------------|
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
| **Format** | Strict structure enforcement | 0.5-1.0 |
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
---
## Implementation Workflow
### Step 1: Dataset Preparation
**Critical Requirements:**
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns
**Example Structure:**
```python
from datasets import load_dataset, Dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
def prepare_dataset(raw_data):
"""
Transform raw data into GRPO-compatible format.
Returns: Dataset with columns:
- 'prompt': List[Dict] with role/content (system + user messages)
- 'answer': str (ground truth, optional but recommended)
"""
return raw_data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_answer(x['raw_answer'])
})
```
**Pro Tips:**
- Use one-shot or few-shot examples in system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out)
### Step 2: Reward Function Implementation
**Template Structure:**
```python
def reward_function_name(
prompts, # List[List[Dict]]: Original prompts
completions, # List[List[Dict]]: Model generations
answer=None, # Optional: Ground truth from dataset
**kwargs # Additional dataset columns
) -> list[float]:
"""
Evaluate completions and return rewards.
Returns: List of floats (one per completion)
"""
# Extract completion text
responses = [comp[0]['content'] for comp in completions]
# Compute rewards
rewards = []
for response in responses:
score = compute_score(response)
rewards.append(score)
return rewards
```
**Example 1: Correctness Reward (Math/Coding)**
```python
def correctness_reward(prompts, completions, answer, **kwargs):
"""Reward correct answers with high score."""
responses = [comp[0]['content'] for comp in completions]
extracted = [extract_final_answer(r) for r in responses]
return [2.0 if ans == gt else 0.0
for ans, gt in zip(extracted, answer)]
```
**Example 2: Format Reward (Structured Output)**
```python
import re
def format_reward(completions, **kwargs):
"""Reward XML-like structured format."""
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
responses = [comp[0]['content'] for comp in completions]
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
for r in responses]
```
**Example 3: Incremental Format Reward (Partial Credit)**
```python
def incremental_format_reward(completions, **kwargs):
"""Award partial credit for format compliance."""
responses = [comp[0]['content'] for comp in completions]
rewards = []
for r in responses:
score = 0.0
if '<reasoning>' in r:
score += 0.25
if '</reasoning>' in r:
score += 0.25
if '<answer>' in r:
score += 0.25
if '</answer>' in r:
score += 0.25
# Penalize extra text after closing tag
if r.count('</answer>') == 1:
extra_text = r.split('</answer>')[-1].strip()
score -= len(extra_text) * 0.001
rewards.append(score)
return rewards
```
**Critical Insight:**
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
### Step 3: Training Configuration
**Memory-Optimized Config (Small GPU)**
```python
from trl import GRPOConfig
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
# Learning rate
learning_rate=5e-6, # Lower = more stable
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
# Batch settings
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # Effective batch = 4
# GRPO-specific
num_generations=8, # Group size: 8-16 recommended
max_prompt_length=256,
max_completion_length=512,
# Training duration
num_train_epochs=1,
max_steps=None, # Or set fixed steps (e.g., 500)
# Optimization
bf16=True, # Faster on A100/H100
optim="adamw_8bit", # Memory-efficient optimizer
max_grad_norm=0.1,
# Logging
logging_steps=1,
save_steps=100,
report_to="wandb", # Or "none" for no logging
)
```
**High-Performance Config (Large GPU)**
```python
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
learning_rate=1e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
num_generations=16, # Larger groups = better signal
max_prompt_length=512,
max_completion_length=1024,
num_train_epochs=1,
bf16=True,
use_vllm=True, # Fast generation with vLLM
logging_steps=10,
)
```
**Critical Hyperparameters:**
| Parameter | Impact | Tuning Advice |
|-----------|--------|---------------|
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
### Step 4: Model Setup and Training
**Standard Setup (Transformers)**
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer
# Load model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 2-3x faster
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Optional: LoRA for parameter-efficient training
peft_config = LoraConfig(
r=16, # Rank (higher = more capacity)
lora_alpha=32, # Scaling factor (typically 2*r)
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
# Initialize trainer
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
incremental_format_reward,
format_reward,
correctness_reward,
],
args=training_args,
train_dataset=dataset,
peft_config=peft_config, # Remove for full fine-tuning
)
# Train
trainer.train()
# Save
trainer.save_model("final_model")
```
**Unsloth Setup (2-3x Faster)**
```python
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="google/gemma-3-1b-it",
max_seq_length=1024,
load_in_4bit=True,
fast_inference=True,
max_lora_rank=32,
)
model = FastLanguageModel.get_peft_model(
model,
r=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=32,
use_gradient_checkpointing="unsloth",
)
# Rest is identical to standard setup
trainer = GRPOTrainer(model=model, ...)
trainer.train()
```
---
## Critical Training Insights
### 1. Loss Behavior (EXPECTED PATTERN)
- **Loss starts near 0 and INCREASES during training**
- This is CORRECT - loss measures KL divergence from initial policy
- Model is learning (diverging from original behavior to optimize rewards)
- Monitor reward metrics instead of loss for progress
### 2. Reward Tracking
Key metrics to watch:
- `reward`: Average across all completions
- `reward_std`: Diversity within groups (should remain > 0)
- `kl`: KL divergence from reference (should grow moderately)
**Healthy Training Pattern:**
```
Step Reward Reward_Std KL
100 0.5 0.3 0.02
200 0.8 0.25 0.05
300 1.2 0.2 0.08 ← Good progression
400 1.5 0.15 0.12
```
**Warning Signs:**
- Reward std → 0 (model collapsing to single response)
- KL exploding (> 0.5) (diverging too much, reduce LR)
- Reward stuck (reward functions too harsh or model capacity issue)
### 3. Common Pitfalls and Solutions
| Problem | Symptom | Solution |
|---------|---------|----------|
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
| **No learning** | Flat rewards | Check reward function logic, increase LR |
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
---
## Advanced Patterns
### 1. Multi-Stage Training
For complex tasks, train in stages:
```python
# Stage 1: Format compliance (epochs=1)
trainer_stage1 = GRPOTrainer(
model=model,
reward_funcs=[incremental_format_reward, format_reward],
...
)
trainer_stage1.train()
# Stage 2: Correctness (epochs=1)
trainer_stage2 = GRPOTrainer(
model=model,
reward_funcs=[format_reward, correctness_reward],
...
)
trainer_stage2.train()
```
### 2. Adaptive Reward Scaling
```python
class AdaptiveReward:
def __init__(self, base_reward_func, initial_weight=1.0):
self.func = base_reward_func
self.weight = initial_weight
def __call__(self, *args, **kwargs):
rewards = self.func(*args, **kwargs)
return [r * self.weight for r in rewards]
def adjust_weight(self, success_rate):
"""Increase weight if model struggling, decrease if succeeding."""
if success_rate < 0.3:
self.weight *= 1.2
elif success_rate > 0.8:
self.weight *= 0.9
```
### 3. Custom Dataset Integration
```python
def load_custom_knowledge_base(csv_path):
"""Example: School communication platform docs."""
import pandas as pd
df = pd.read_csv(csv_path)
dataset = Dataset.from_pandas(df).map(lambda x: {
'prompt': [
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': x['expert_answer']
})
return dataset
```
---
## Deployment and Inference
### Save and Merge LoRA
```python
# Merge LoRA adapters into base model
if hasattr(trainer.model, 'merge_and_unload'):
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("production_model")
tokenizer.save_pretrained("production_model")
```
### Inference Example
```python
from transformers import pipeline
generator = pipeline(
"text-generation",
model="production_model",
tokenizer=tokenizer
)
result = generator(
[
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': "What is 15 + 27?"}
],
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9
)
print(result[0]['generated_text'])
```
---
## Best Practices Checklist
**Before Training:**
- [ ] Validate dataset format (prompts as List[Dict])
- [ ] Test reward functions on sample data
- [ ] Calculate expected max_prompt_length from data
- [ ] Choose appropriate num_generations based on GPU memory
- [ ] Set up logging (wandb recommended)
**During Training:**
- [ ] Monitor reward progression (should increase)
- [ ] Check reward_std (should stay > 0.1)
- [ ] Watch for OOM errors (reduce batch size if needed)
- [ ] Sample generations every 50-100 steps
- [ ] Validate format compliance on holdout set
**After Training:**
- [ ] Merge LoRA weights if using PEFT
- [ ] Test on diverse prompts
- [ ] Compare to baseline model
- [ ] Document reward weights and hyperparameters
- [ ] Save reproducibility config
---
## Troubleshooting Guide
### Debugging Workflow
1. **Isolate reward functions** - Test each independently
2. **Check data distribution** - Ensure diversity in prompts
3. **Reduce complexity** - Start with single reward, add gradually
4. **Monitor generations** - Print samples every N steps
5. **Validate extraction logic** - Ensure answer parsing works
### Quick Fixes
```python
# Debug reward function
def debug_reward(completions, **kwargs):
responses = [comp[0]['content'] for comp in completions]
for i, r in enumerate(responses[:2]): # Print first 2
print(f"Response {i}: {r[:200]}...")
return [1.0] * len(responses) # Dummy rewards
# Test without training
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
trainer.generate_completions(dataset[:1]) # Generate without updating
```
---
## References and Resources
**Official Documentation:**
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
- Unsloth Docs: https://docs.unsloth.ai/
**Example Repositories:**
- Open R1 Implementation: https://github.com/huggingface/open-r1
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
**Recommended Reading:**
- Progressive Disclosure Pattern for agent instructions
- Reward shaping in RL (Ng et al.)
- LoRA paper (Hu et al., 2021)
---
## Usage Instructions for Agents
When this skill is loaded:
1. **Read this entire file** before implementing GRPO training
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
3. **Use the templates** in `templates/` directory as starting points
4. **Reference examples** in `examples/` for task-specific implementations
5. **Follow the workflow** sequentially (don't skip steps)
6. **Debug incrementally** - add one reward function at a time
**Critical Reminders:**
- Always use multiple reward functions (3-5 is optimal)
- Monitor reward metrics, not loss
- Test reward functions before training
- Start small (num_generations=4), scale up gradually
- Save checkpoints frequently (every 100 steps)
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.

View File

@@ -0,0 +1,228 @@
"""
Basic GRPO Training Template
=============================
A minimal, production-ready template for GRPO training with TRL.
Adapt this for your specific task by modifying:
1. Dataset loading (get_dataset function)
2. Reward functions (reward_*_func)
3. System prompt (SYSTEM_PROMPT)
4. Hyperparameters (GRPOConfig)
"""
import torch
import re
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig
# ==================== CONFIGURATION ====================
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = "outputs/grpo-model"
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 512
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
# ==================== DATASET ====================
def get_dataset(split="train"):
"""
Load and prepare your dataset.
Returns: Dataset with columns:
- 'prompt': List[Dict] with role/content
- 'answer': str (ground truth, optional)
"""
# Example: GSM8K math dataset
data = load_dataset('openai/gsm8k', 'main')[split]
def process_example(x):
# Extract ground truth answer
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
return {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': answer
}
return data.map(process_example)
# ==================== HELPER FUNCTIONS ====================
def extract_xml_tag(text: str, tag: str) -> str:
"""Extract content between XML tags."""
pattern = f'<{tag}>(.*?)</{tag}>'
match = re.search(pattern, text, re.DOTALL)
return match.group(1).strip() if match else ""
def extract_answer(text: str) -> str:
"""Extract the final answer from structured output."""
return extract_xml_tag(text, 'answer')
# ==================== REWARD FUNCTIONS ====================
def correctness_reward_func(prompts, completions, answer, **kwargs):
"""
Reward correct answers.
Weight: 2.0 (highest priority)
"""
responses = [comp[0]['content'] for comp in completions]
extracted = [extract_answer(r) for r in responses]
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
def format_reward_func(completions, **kwargs):
"""
Reward proper XML format.
Weight: 0.5
"""
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
responses = [comp[0]['content'] for comp in completions]
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
def incremental_format_reward_func(completions, **kwargs):
"""
Incremental reward for partial format compliance.
Weight: up to 0.5
"""
responses = [comp[0]['content'] for comp in completions]
rewards = []
for r in responses:
score = 0.0
if '<reasoning>' in r:
score += 0.125
if '</reasoning>' in r:
score += 0.125
if '<answer>' in r:
score += 0.125
if '</answer>' in r:
score += 0.125
# Penalize extra content after closing tag
if '</answer>' in r:
extra = r.split('</answer>')[-1].strip()
score -= len(extra) * 0.001
rewards.append(score)
return rewards
# ==================== MODEL SETUP ====================
def setup_model_and_tokenizer():
"""Load model and tokenizer with optimizations."""
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer
def get_peft_config():
"""LoRA configuration for parameter-efficient training."""
return LoraConfig(
r=16,
lora_alpha=32,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
# ==================== TRAINING ====================
def main():
"""Main training function."""
# Load data
print("Loading dataset...")
dataset = get_dataset()
print(f"Dataset size: {len(dataset)}")
# Setup model
print("Loading model...")
model, tokenizer = setup_model_and_tokenizer()
# Training configuration
training_args = GRPOConfig(
output_dir=OUTPUT_DIR,
run_name="grpo-training",
# Learning rate
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
# Batch settings
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
# GRPO specific
num_generations=8,
max_prompt_length=MAX_PROMPT_LENGTH,
max_completion_length=MAX_COMPLETION_LENGTH,
# Training duration
num_train_epochs=1,
# Optimization
bf16=True,
optim="adamw_8bit",
max_grad_norm=0.1,
# Logging
logging_steps=1,
save_steps=100,
report_to="wandb", # Change to "none" to disable logging
)
# Initialize trainer
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
incremental_format_reward_func,
format_reward_func,
correctness_reward_func,
],
args=training_args,
train_dataset=dataset,
peft_config=get_peft_config(),
)
# Train
print("Starting training...")
trainer.train()
# Save final model
print(f"Saving model to {OUTPUT_DIR}/final")
trainer.save_model(f"{OUTPUT_DIR}/final")
print("Training complete!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,575 @@
---
name: guidance
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [guidance, transformers]
metadata:
hermes:
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
---
# Guidance: Constrained LLM Generation
## When to Use This Skill
Use Guidance when you need to:
- **Control LLM output syntax** with regex or grammars
- **Guarantee valid JSON/XML/code** generation
- **Reduce latency** vs traditional prompting approaches
- **Enforce structured formats** (dates, emails, IDs, etc.)
- **Build multi-step workflows** with Pythonic control flow
- **Prevent invalid outputs** through grammatical constraints
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
## Installation
```bash
# Base installation
pip install guidance
# With specific backends
pip install guidance[transformers] # Hugging Face models
pip install guidance[llama_cpp] # llama.cpp models
```
## Quick Start
### Basic Example: Structured Generation
```python
from guidance import models, gen
# Load model (supports OpenAI, Transformers, llama.cpp)
lm = models.OpenAI("gpt-4")
# Generate with constraints
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
print(result["capital"]) # "Paris"
```
### With Anthropic Claude
```python
from guidance import models, gen, system, user, assistant
# Configure Claude
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use context managers for chat format
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=20)
```
## Core Concepts
### 1. Context Managers
Guidance uses Pythonic context managers for chat-style interactions.
```python
from guidance import system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# System message
with system():
lm += "You are a JSON generation expert."
# User message
with user():
lm += "Generate a person object with name and age."
# Assistant response
with assistant():
lm += gen("response", max_tokens=100)
print(lm["response"])
```
**Benefits:**
- Natural chat flow
- Clear role separation
- Easy to read and maintain
### 2. Constrained Generation
Guidance ensures outputs match specified patterns using regex or grammars.
#### Regex Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to valid email format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# Constrain to date format (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Constrain to phone number
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
print(lm["email"]) # Guaranteed valid email
print(lm["date"]) # Guaranteed YYYY-MM-DD format
```
**How it works:**
- Regex converted to grammar at token level
- Invalid tokens filtered during generation
- Model can only produce matching outputs
#### Selection Constraints
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to specific choices
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
# Multiple-choice selection
lm += "Best answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
print(lm["sentiment"]) # One of: positive, negative, neutral
print(lm["answer"]) # One of: A, B, C, or D
```
### 3. Token Healing
Guidance automatically "heals" token boundaries between prompt and generation.
**Problem:** Tokenization creates unnatural boundaries.
```python
# Without token healing
prompt = "The capital of France is "
# Last token: " is "
# First generated token might be " Par" (with leading space)
# Result: "The capital of France is Paris" (double space!)
```
**Solution:** Guidance backs up one token and regenerates.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Result: "The capital of France is Paris" (correct spacing)
```
**Benefits:**
- Natural text boundaries
- No awkward spacing issues
- Better model performance (sees natural token sequences)
### 4. Grammar-Based Generation
Define complex structures using context-free grammars.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# JSON grammar (simplified)
json_grammar = """
{
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
"age": <gen age regex="[0-9]+" max_tokens=3>,
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
}
"""
# Generate valid JSON
lm += gen("person", grammar=json_grammar)
print(lm["person"]) # Guaranteed valid JSON structure
```
**Use cases:**
- Complex structured outputs
- Nested data structures
- Programming language syntax
- Domain-specific languages
### 5. Guidance Functions
Create reusable generation patterns with the `@guidance` decorator.
```python
from guidance import guidance, gen, models
@guidance
def generate_person(lm):
"""Generate a person with name and age."""
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
return lm
# Use the function
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_person(lm)
print(lm["name"])
print(lm["age"])
```
**Stateful Functions:**
```python
@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for i in range(max_rounds):
# Thought
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
# Action
lm += "\nAction: " + select(list(tools.keys()), name="action")
# Execute tool
tool_result = tools[lm["action"]]()
lm += f"\nObservation: {tool_result}\n\n"
# Check if done
lm += "Done? " + select(["Yes", "No"], name="done")
if lm["done"] == "Yes":
break
# Final answer
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
return lm
```
## Backend Configuration
### Anthropic Claude
```python
from guidance import models
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
)
```
### OpenAI
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key" # Or set OPENAI_API_KEY env var
)
```
### Local Models (Transformers)
```python
from guidance.models import Transformers
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda" # Or "cpu"
)
```
### Local Models (llama.cpp)
```python
from guidance.models import LlamaCpp
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35
)
```
## Common Patterns
### Pattern 1: JSON Generation
```python
from guidance import models, gen, system, user, assistant
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You generate valid JSON."
with user():
lm += "Generate a user profile with name, age, and email."
with assistant():
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
}"""
print(lm) # Valid JSON guaranteed
```
### Pattern 2: Classification
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is amazing! I love it."
lm += f"Text: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
```
### Pattern 3: Multi-Step Reasoning
```python
from guidance import models, gen, guidance
@guidance
def chain_of_thought(lm, question):
"""Generate answer with step-by-step reasoning."""
lm += f"Question: {question}\n\n"
# Generate multiple reasoning steps
for i in range(3):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "What is 15% of 200?")
print(lm["answer"])
```
### Pattern 4: ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question):
"""ReAct agent with tool use."""
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for: {query}",
}
lm += f"Question: {question}\n\n"
for round in range(5):
# Thought
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
# Action selection
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
if lm["action"] == "answer":
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
# Execute tool
if lm["action"] in tools:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is 25 * 4 + 10?")
print(lm["answer"])
```
### Pattern 5: Data Extraction
```python
from guidance import models, gen, guidance
@guidance
def extract_entities(lm, text):
"""Extract structured entities from text."""
lm += f"Text: {text}\n\n"
# Extract person
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
# Extract organization
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
# Extract date
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
# Extract location
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
return lm
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
print(f"Person: {lm['person']}")
print(f"Organization: {lm['organization']}")
print(f"Date: {lm['date']}")
print(f"Location: {lm['location']}")
```
## Best Practices
### 1. Use Regex for Format Validation
```python
# ✅ Good: Regex ensures valid format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# ❌ Bad: Free generation may produce invalid emails
lm += "Email: " + gen("email", max_tokens=50)
```
### 2. Use select() for Fixed Categories
```python
# ✅ Good: Guaranteed valid category
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
# ❌ Bad: May generate typos or invalid values
lm += "Status: " + gen("status", max_tokens=20)
```
### 3. Leverage Token Healing
```python
# Token healing is enabled by default
# No special action needed - just concatenate naturally
lm += "The capital is " + gen("capital") # Automatic healing
```
### 4. Use stop Sequences
```python
# ✅ Good: Stop at newline for single-line outputs
lm += "Name: " + gen("name", stop="\n")
# ❌ Bad: May generate multiple lines
lm += "Name: " + gen("name", max_tokens=50)
```
### 5. Create Reusable Functions
```python
# ✅ Good: Reusable pattern
@guidance
def generate_person(lm):
lm += "Name: " + gen("name", stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
return lm
# Use multiple times
lm = generate_person(lm)
lm += "\n\n"
lm = generate_person(lm)
```
### 6. Balance Constraints
```python
# ✅ Good: Reasonable constraints
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
# ❌ Too strict: May fail or be very slow
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
```
## Comparison to Alternatives
| Feature | Guidance | Instructor | Outlines | LMQL |
|---------|----------|------------|----------|------|
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
| Learning Curve | Low | Low | Medium | High |
**When to choose Guidance:**
- Need regex/grammar constraints
- Want token healing
- Building complex workflows with control flow
- Using local models (Transformers, llama.cpp)
- Prefer Pythonic syntax
**When to choose alternatives:**
- Instructor: Need Pydantic validation with automatic retrying
- Outlines: Need JSON schema validation
- LMQL: Prefer declarative query syntax
## Performance Characteristics
**Latency Reduction:**
- 30-50% faster than traditional prompting for constrained outputs
- Token healing reduces unnecessary regeneration
- Grammar constraints prevent invalid token generation
**Memory Usage:**
- Minimal overhead vs unconstrained generation
- Grammar compilation cached after first use
- Efficient token filtering at inference time
**Token Efficiency:**
- Prevents wasted tokens on invalid outputs
- No need for retry loops
- Direct path to valid outputs
## Resources
- **Documentation**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Discord**: Community support available
## See Also
- `references/constraints.md` - Comprehensive regex and grammar patterns
- `references/backends.md` - Backend-specific configuration
- `references/examples.md` - Production-ready examples

View File

@@ -0,0 +1,554 @@
# Backend Configuration Guide
Complete guide to configuring Guidance with different LLM backends.
## Table of Contents
- API-Based Models (Anthropic, OpenAI)
- Local Models (Transformers, llama.cpp)
- Backend Comparison
- Performance Tuning
- Advanced Configuration
## API-Based Models
### Anthropic Claude
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Reads ANTHROPIC_API_KEY from environment
# Explicit API key
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# Claude 3.5 Sonnet (Latest, recommended)
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Claude 3.7 Sonnet (Fast, cost-effective)
lm = models.Anthropic("claude-sonnet-3.7-20250219")
# Claude 3 Opus (Most capable)
lm = models.Anthropic("claude-3-opus-20240229")
# Claude 3.5 Haiku (Fastest, cheapest)
lm = models.Anthropic("claude-3-5-haiku-20241022")
```
#### Configuration Options
```python
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key",
max_tokens=4096, # Max tokens to generate
temperature=0.7, # Sampling temperature (0-1)
top_p=0.9, # Nucleus sampling
timeout=30, # Request timeout (seconds)
max_retries=3 # Retry failed requests
)
```
#### With Context Managers
```python
from guidance import models, system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=50)
print(lm)
```
### OpenAI
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.OpenAI("gpt-4o")
# Reads OPENAI_API_KEY from environment
# Explicit API key
lm = models.OpenAI(
model="gpt-4o",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# GPT-4o (Latest, multimodal)
lm = models.OpenAI("gpt-4o")
# GPT-4o Mini (Fast, cost-effective)
lm = models.OpenAI("gpt-4o-mini")
# GPT-4 Turbo
lm = models.OpenAI("gpt-4-turbo")
# GPT-3.5 Turbo (Cheapest)
lm = models.OpenAI("gpt-3.5-turbo")
```
#### Configuration Options
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key",
max_tokens=2048,
temperature=0.7,
top_p=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
timeout=30
)
```
#### Chat Format
```python
from guidance import models, gen
lm = models.OpenAI("gpt-4o-mini")
# OpenAI uses chat format
lm += [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
# Generate response
lm += gen(max_tokens=50)
```
### Azure OpenAI
```python
from guidance import models
lm = models.AzureOpenAI(
model="gpt-4o",
azure_endpoint="https://your-resource.openai.azure.com/",
api_key="your-azure-api-key",
api_version="2024-02-15-preview",
deployment_name="your-deployment-name"
)
```
## Local Models
### Transformers (Hugging Face)
#### Basic Setup
```python
from guidance.models import Transformers
# Load model from Hugging Face
lm = Transformers("microsoft/Phi-4-mini-instruct")
```
#### GPU Configuration
```python
# Use GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Use specific GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda:0" # GPU 0
)
# Use CPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cpu"
)
```
#### Advanced Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda",
torch_dtype="float16", # Use FP16 (faster, less memory)
load_in_8bit=True, # 8-bit quantization
max_memory={0: "20GB"}, # GPU memory limit
offload_folder="./offload" # Offload to disk if needed
)
```
#### Popular Models
```python
# Phi-4 (Microsoft)
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
# Llama 3 (Meta)
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
# Mistral (Mistral AI)
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Qwen (Alibaba)
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
# Gemma (Google)
lm = Transformers("google/gemma-2-9b-it")
```
#### Generation Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Configure generation
from guidance import gen
result = lm + gen(
max_tokens=100,
temperature=0.7,
top_p=0.9,
top_k=50,
repetition_penalty=1.1
)
```
### llama.cpp
#### Basic Setup
```python
from guidance.models import LlamaCpp
# Load GGUF model
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096 # Context window
)
```
#### GPU Configuration
```python
# Use GPU acceleration
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35, # Offload 35 layers to GPU
n_threads=8 # CPU threads for remaining layers
)
# Full GPU offload
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=-1 # Offload all layers
)
```
#### Advanced Configuration
```python
lm = LlamaCpp(
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
n_ctx=8192, # Context window (tokens)
n_gpu_layers=35, # GPU layers
n_threads=8, # CPU threads
n_batch=512, # Batch size for prompt processing
use_mmap=True, # Memory-map the model file
use_mlock=False, # Lock model in RAM
seed=42, # Random seed
verbose=False # Suppress verbose output
)
```
#### Quantized Models
```python
# Q4_K_M (4-bit, recommended for most cases)
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
# Q5_K_M (5-bit, better quality)
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
# Q8_0 (8-bit, high quality)
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
# F16 (16-bit float, highest quality)
lm = LlamaCpp("/path/to/model.F16.gguf")
```
#### Popular GGUF Models
```python
# Llama 3.1
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
# Mistral
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
# Phi-4
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
```
## Backend Comparison
### Feature Matrix
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|---------|-----------|--------|--------------|-----------|
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
| Cost | $$$ | $$$ | Free | Free |
| Latency | Low | Low | Medium | Low |
| Setup Difficulty | Easy | Easy | Medium | Medium |
### Performance Characteristics
**Anthropic Claude:**
- **Latency**: 200-500ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $3-15 per 1M input tokens
- **Best for**: Production systems, high-quality outputs
**OpenAI:**
- **Latency**: 200-400ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $0.15-30 per 1M input tokens
- **Best for**: Cost-sensitive production, gpt-4o-mini
**Transformers:**
- **Latency**: 50-200ms (local inference)
- **Throughput**: GPU-dependent (10-100 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Privacy-sensitive, high-volume, experimentation
**llama.cpp:**
- **Latency**: 30-150ms (local inference)
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Edge deployment, Apple Silicon, CPU inference
### Memory Requirements
**Transformers (FP16):**
- 7B model: ~14GB GPU VRAM
- 13B model: ~26GB GPU VRAM
- 70B model: ~140GB GPU VRAM (multi-GPU)
**llama.cpp (Q4_K_M):**
- 7B model: ~4.5GB RAM
- 13B model: ~8GB RAM
- 70B model: ~40GB RAM
**Optimization Tips:**
- Use quantized models (Q4_K_M) for lower memory
- Use GPU offloading for faster inference
- Use CPU inference for smaller models (<7B)
## Performance Tuning
### API Models (Anthropic, OpenAI)
#### Reduce Latency
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use lower max_tokens (faster response)
lm += gen(max_tokens=100) # Instead of 1000
# Use streaming (perceived latency reduction)
for chunk in lm.stream(gen(max_tokens=500)):
print(chunk, end="", flush=True)
```
#### Reduce Cost
```python
# Use cheaper models
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
# Reduce context size
# - Keep prompts concise
# - Avoid large few-shot examples
# - Use max_tokens limits
```
### Local Models (Transformers, llama.cpp)
#### Optimize GPU Usage
```python
from guidance.models import Transformers
# Use FP16 for 2x speedup
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
torch_dtype="float16"
)
# Use 8-bit quantization for 4x memory reduction
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
load_in_8bit=True
)
# Use flash attention (requires flash-attn package)
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
use_flash_attention_2=True
)
```
#### Optimize llama.cpp
```python
from guidance.models import LlamaCpp
# Maximize GPU layers
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1 # All layers on GPU
)
# Optimize batch size
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_batch=512, # Larger batch = faster prompt processing
n_gpu_layers=-1
)
# Use Metal (Apple Silicon)
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1, # Use Metal GPU acceleration
use_mmap=True
)
```
#### Batch Processing
```python
# Process multiple requests efficiently
requests = [
"What is 2+2?",
"What is the capital of France?",
"What is photosynthesis?"
]
# Bad: Sequential processing
for req in requests:
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm += req + gen(max_tokens=50)
# Good: Reuse loaded model
lm = Transformers("microsoft/Phi-4-mini-instruct")
for req in requests:
lm += req + gen(max_tokens=50)
```
## Advanced Configuration
### Custom Model Configurations
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from guidance.models import Transformers
# Load custom model
tokenizer = AutoTokenizer.from_pretrained("your-model")
model = AutoModelForCausalLM.from_pretrained(
"your-model",
device_map="auto",
torch_dtype="float16"
)
# Use with Guidance
lm = Transformers(model=model, tokenizer=tokenizer)
```
### Environment Variables
```bash
# API keys
export ANTHROPIC_API_KEY="sk-ant-..."
export OPENAI_API_KEY="sk-..."
# Transformers cache
export HF_HOME="/path/to/cache"
export TRANSFORMERS_CACHE="/path/to/cache"
# GPU selection
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
```
### Debugging
```python
# Enable verbose logging
import logging
logging.basicConfig(level=logging.DEBUG)
# Check backend info
lm = models.Anthropic("claude-sonnet-4-5-20250929")
print(f"Model: {lm.model_name}")
print(f"Backend: {lm.backend}")
# Check GPU usage (Transformers)
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
print(f"Device: {lm.device}")
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
```
## Resources
- **Anthropic Docs**: https://docs.anthropic.com
- **OpenAI Docs**: https://platform.openai.com/docs
- **Hugging Face Models**: https://huggingface.co/models
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
- **GGUF Models**: https://huggingface.co/models?library=gguf

View File

@@ -0,0 +1,674 @@
# Comprehensive Constraint Patterns
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
## Table of Contents
- Regex Constraints
- Grammar-Based Generation
- Token Healing
- Selection Constraints
- Complex Patterns
- Performance Optimization
## Regex Constraints
### Basic Patterns
#### Numeric Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Integer (positive)
lm += "Age: " + gen("age", regex=r"[0-9]+")
# Integer (with negatives)
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
# Float (positive)
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
# Float (with negatives and optional decimals)
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
# Percentage (0-100)
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
# Range (1-5 stars)
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
```
#### Text Constraints
```python
# Alphabetic only
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
# Alphabetic with spaces
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
# Alphanumeric
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
# Capitalized words
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
# Lowercase only
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
# Specific length
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
```
#### Date and Time Constraints
```python
# Date (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Date (MM/DD/YYYY)
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
# Time (HH:MM)
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
# Time (HH:MM:SS)
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
# ISO 8601 datetime
lm += "Timestamp: " + gen(
"timestamp",
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
)
# Year (YYYY)
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
# Month name
lm += "Month: " + gen(
"month",
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
)
```
#### Contact Information
```python
# Email
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
)
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
# Phone (international format)
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
# ZIP code (US)
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
# Postal code (Canada)
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
# URL
lm += "URL: " + gen(
"url",
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
)
```
### Advanced Patterns
#### JSON Field Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# String field with quotes
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
# Numeric field (no quotes)
lm += '"age": ' + gen("age", regex=r"[0-9]+")
# Boolean field
lm += '"active": ' + gen("active", regex=r"(true|false)")
# Null field
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
# Array of strings
lm += '"tags": [' + gen(
"tags",
regex=r'"[a-z]+"(, "[a-z]+")*'
) + ']'
# Complete JSON object
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
"age": """ + gen("age", regex=r"[0-9]+") + """,
"email": """ + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + """
}"""
```
#### Code Patterns
```python
# Python variable name
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
# Python function name
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
# Hex color code
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
# UUID
lm += "UUID: " + gen(
"uuid",
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
)
# Git commit hash (short)
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
# Semantic version
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
# IP address (IPv4)
lm += "IP: " + gen(
"ip",
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
)
```
#### Domain-Specific Patterns
```python
# Credit card number
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
# Social Security Number (US)
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
# ISBN-13
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
# License plate (US)
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
# Currency amount
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
# Percentage with decimal
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
```
## Grammar-Based Generation
### JSON Grammar
```python
from guidance import models, gen, guidance
@guidance
def json_object(lm):
"""Generate valid JSON object."""
lm += "{\n"
# Name field (required)
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
# Age field (required)
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
# Email field (required)
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + ",\n"
# Active field (required, boolean)
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_object(lm)
print(lm) # Valid JSON guaranteed
```
### Nested JSON Grammar
```python
@guidance
def nested_json(lm):
"""Generate nested JSON structure."""
lm += "{\n"
# User object
lm += ' "user": {\n'
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
lm += " },\n"
# Address object
lm += ' "address": {\n'
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
lm += " }\n"
lm += "}"
return lm
```
### Array Grammar
```python
@guidance
def json_array(lm, count=3):
"""Generate JSON array with fixed count."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
```
### XML Grammar
```python
@guidance
def xml_document(lm):
"""Generate valid XML document."""
lm += '<?xml version="1.0"?>\n'
lm += "<person>\n"
# Name element
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
# Age element
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
# Email element
lm += " <email>" + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
) + "</email>\n"
lm += "</person>"
return lm
```
### CSV Grammar
```python
@guidance
def csv_row(lm):
"""Generate CSV row."""
lm += gen("name", regex=r"[A-Za-z ]+") + ","
lm += gen("age", regex=r"[0-9]+") + ","
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
@guidance
def csv_document(lm, rows=5):
"""Generate complete CSV."""
# Header
lm += "Name,Age,Email\n"
# Rows
for i in range(rows):
lm = csv_row(lm)
if i < rows - 1:
lm += "\n"
return lm
```
## Token Healing
### How Token Healing Works
**Problem:** Tokenization creates unnatural boundaries.
```python
# Example without token healing
prompt = "The capital of France is "
# Tokenization: ["The", " capital", " of", " France", " is", " "]
# Model sees last token: " "
# First generated token might include leading space: " Paris"
# Result: "The capital of France is Paris" (double space)
```
**Solution:** Guidance backs up and regenerates the last token.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Process:
# 1. Back up to token before " is "
# 2. Regenerate " is" + "capital" together
# 3. Result: "The capital of France is Paris" (correct)
```
### Token Healing Examples
#### Natural Continuations
```python
# Before token healing
lm += "The function name is get" + gen("rest")
# Might generate: "The function name is get User" (space before User)
# With token healing
lm += "The function name is get" + gen("rest")
# Generates: "The function name is getUser" (correct camelCase)
```
#### Code Generation
```python
# Function name completion
lm += "def calculate_" + gen("rest", stop="(")
# Token healing ensures smooth connection: "calculate_total"
# Variable name completion
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
```
#### Domain-Specific Terms
```python
# Medical terms
lm += "The patient has hyper" + gen("condition")
# Token healing helps: "hypertension" (not "hyper tension")
# Technical terms
lm += "Using micro" + gen("tech")
# Token healing helps: "microservices" (not "micro services")
```
### Disabling Token Healing
```python
# Disable token healing if needed (rare)
lm += gen("text", token_healing=False)
```
## Selection Constraints
### Basic Selection
```python
from guidance import models, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Simple selection
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
# Boolean selection
lm += "Approved: " + select(["Yes", "No"], name="approved")
# Multiple choice
lm += "Answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
```
### Conditional Selection
```python
from guidance import models, select, gen, guidance
@guidance
def conditional_fields(lm):
"""Generate fields conditionally based on type."""
lm += "Type: " + select(["person", "company"], name="type")
if lm["type"] == "person":
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
else:
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
return lm
```
### Repeated Selection
```python
@guidance
def multiple_selections(lm):
"""Select multiple items."""
lm += "Select 3 colors:\n"
colors = ["red", "blue", "green", "yellow", "purple"]
for i in range(3):
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
return lm
```
## Complex Patterns
### Pattern 1: Structured Forms
```python
@guidance
def user_form(lm):
"""Generate structured user form."""
lm += "=== User Registration ===\n\n"
# Name (alphabetic only)
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Age (numeric)
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
# Email (validated format)
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
# Account type (selection)
lm += "Account Type: " + select(
["Standard", "Premium", "Enterprise"],
name="account_type"
) + "\n"
# Active status (boolean)
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
return lm
```
### Pattern 2: Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entities with constraints."""
lm += f"Text: {text}\n\n"
# Person name (alphabetic)
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization (alphanumeric with spaces)
lm += "Organization: " + gen(
"organization",
regex=r"[A-Za-z0-9 ]+",
stop="\n"
) + "\n"
# Date (YYYY-MM-DD format)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
# Location (alphabetic with spaces)
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Amount (currency)
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
return lm
```
### Pattern 3: Code Generation
```python
@guidance
def generate_python_function(lm):
"""Generate Python function with constraints."""
# Function name (valid Python identifier)
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
# Parameter name
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
# Function body (constrained to valid Python)
lm += " return " + gen("return_value", stop="\n") + "\n"
return lm
```
### Pattern 4: Hierarchical Data
```python
@guidance
def org_chart(lm):
"""Generate organizational chart."""
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
# CEO
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
# Departments
for dept in ["Engineering", "Sales", "Marketing"]:
lm += f"\n{dept} Department:\n"
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
return lm
```
## Performance Optimization
### Best Practices
#### 1. Use Specific Patterns
```python
# ✅ Good: Specific pattern
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
# ❌ Bad: Overly broad pattern
lm += gen("age", regex=r"[0-9]+") # Slower
```
#### 2. Limit Max Tokens
```python
# ✅ Good: Reasonable limit
lm += gen("name", max_tokens=30)
# ❌ Bad: No limit
lm += gen("name") # May generate forever
```
#### 3. Use stop Sequences
```python
# ✅ Good: Stop at newline
lm += gen("line", stop="\n")
# ❌ Bad: Rely on max_tokens
lm += gen("line", max_tokens=100)
```
#### 4. Cache Compiled Grammars
```python
# Grammars are cached automatically after first use
# No manual caching needed
@guidance
def reusable_pattern(lm):
"""This grammar is compiled once and cached."""
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
# First call: compiles grammar
lm = reusable_pattern(lm)
# Subsequent calls: uses cached grammar (fast)
lm = reusable_pattern(lm)
```
#### 5. Avoid Overlapping Constraints
```python
# ✅ Good: Clear constraints
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
# ❌ Bad: Conflicting constraints
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
```
### Performance Benchmarks
**Regex vs Free Generation:**
- Simple regex (digits): ~1.2x slower than free gen
- Complex regex (email): ~1.5x slower than free gen
- Grammar-based: ~2x slower than free gen
**But:**
- 100% valid outputs (vs ~70% with free gen + validation)
- No retry loops needed
- Overall faster end-to-end for structured outputs
**Optimization Tips:**
- Use regex for critical fields only
- Use `select()` for small fixed sets (fastest)
- Use `stop` sequences when possible (faster than max_tokens)
- Cache compiled grammars by reusing functions
## Resources
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
- **Guidance Docs**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance

View File

@@ -0,0 +1,767 @@
# Production-Ready Examples
Real-world examples of using Guidance for structured generation, agents, and workflows.
## Table of Contents
- JSON Generation
- Data Extraction
- Classification Systems
- Agent Systems
- Multi-Step Workflows
- Code Generation
- Production Tips
## JSON Generation
### Basic JSON
```python
from guidance import models, gen, guidance
@guidance
def generate_user(lm):
"""Generate valid user JSON."""
lm += "{\n"
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += "}"
return lm
# Use it
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += "Generate a user profile:\n"
lm = generate_user(lm)
print(lm)
# Output: Valid JSON guaranteed
```
### Nested JSON
```python
@guidance
def generate_order(lm):
"""Generate nested order JSON."""
lm += "{\n"
# Customer info
lm += ' "customer": {\n'
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "email": ' + gen(
"customer_email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += " },\n"
# Order details
lm += ' "order": {\n'
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
lm += " },\n"
# Status
lm += ' "status": ' + gen(
"status",
regex=r'"(pending|processing|shipped|delivered)"'
) + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_order(lm)
```
### JSON Array
```python
@guidance
def generate_user_list(lm, count=3):
"""Generate JSON array of users."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_user_list(lm, count=5)
```
### Dynamic JSON Schema
```python
import json
from guidance import models, gen, guidance
@guidance
def json_from_schema(lm, schema):
"""Generate JSON matching a schema."""
lm += "{\n"
fields = list(schema["properties"].items())
for i, (field_name, field_schema) in enumerate(fields):
lm += f' "{field_name}": '
# Handle different types
if field_schema["type"] == "string":
if "pattern" in field_schema:
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
else:
lm += gen(field_name, regex=r'"[^"]+"')
elif field_schema["type"] == "number":
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
elif field_schema["type"] == "integer":
lm += gen(field_name, regex=r"[0-9]+")
elif field_schema["type"] == "boolean":
lm += gen(field_name, regex=r"(true|false)")
if i < len(fields) - 1:
lm += ","
lm += "\n"
lm += "}"
return lm
# Define schema
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"score": {"type": "number"},
"active": {"type": "boolean"}
}
}
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_from_schema(lm, schema)
```
## Data Extraction
### Extract from Text
```python
from guidance import models, gen, guidance, system, user, assistant
@guidance
def extract_person_info(lm, text):
"""Extract structured info from text."""
lm += f"Text: {text}\n\n"
with assistant():
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
return lm
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You extract structured information from text."
with user():
lm = extract_person_info(lm, text)
print(f"Name: {lm['name']}")
print(f"Age: {lm['age']}")
print(f"Occupation: {lm['occupation']}")
print(f"Email: {lm['email']}")
```
### Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entity types."""
lm += f"Analyze: {text}\n\n"
# Person entities
lm += "People:\n"
for i in range(3): # Up to 3 people
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization entities
lm += "\nOrganizations:\n"
for i in range(2): # Up to 2 orgs
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
# Dates
lm += "\nDates:\n"
for i in range(2): # Up to 2 dates
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
# Locations
lm += "\nLocations:\n"
for i in range(2): # Up to 2 locations
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
return lm
text = """
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
to discuss the collaboration between Apple and Microsoft. The meeting continued
in Cupertino on 2024-09-20.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
```
### Batch Extraction
```python
@guidance
def batch_extract(lm, texts):
"""Extract from multiple texts."""
lm += "Batch Extraction Results:\n\n"
for i, text in enumerate(texts):
lm += f"=== Item {i+1} ===\n"
lm += f"Text: {text}\n"
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Sentiment: " + gen(
f"sentiment_{i}",
regex=r"(positive|negative|neutral)",
stop="\n"
) + "\n\n"
return lm
texts = [
"Alice is happy with the product",
"Bob is disappointed with the service",
"Carol has no strong feelings either way"
]
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = batch_extract(lm, texts)
```
## Classification Systems
### Sentiment Analysis
```python
from guidance import models, select, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is absolutely amazing! Best purchase ever."
lm += f"Text: {text}\n\n"
lm += "Sentiment: " + select(
["positive", "negative", "neutral"],
name="sentiment"
)
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
print(f"Reasoning: {lm['reasoning']}")
```
### Multi-Label Classification
```python
@guidance
def classify_article(lm, text):
"""Classify article with multiple labels."""
lm += f"Article: {text}\n\n"
# Primary category
lm += "Primary Category: " + select(
["Technology", "Business", "Science", "Politics", "Entertainment"],
name="primary_category"
) + "\n"
# Secondary categories (up to 3)
lm += "\nSecondary Categories:\n"
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
for i in range(3):
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
# Tags
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
# Target audience
lm += "Target Audience: " + select(
["General", "Expert", "Beginner"],
name="audience"
)
return lm
article = """
Apple announced new AI features in iOS 18, leveraging machine learning to improve
battery life and performance. The company's stock rose 5% following the announcement.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_article(lm, article)
```
### Intent Classification
```python
@guidance
def classify_intent(lm, message):
"""Classify user intent."""
lm += f"User Message: {message}\n\n"
# Intent
lm += "Intent: " + select(
["question", "complaint", "request", "feedback", "other"],
name="intent"
) + "\n"
# Urgency
lm += "Urgency: " + select(
["low", "medium", "high", "critical"],
name="urgency"
) + "\n"
# Department
lm += "Route To: " + select(
["support", "sales", "billing", "technical"],
name="department"
) + "\n"
# Sentiment
lm += "Sentiment: " + select(
["positive", "neutral", "negative"],
name="sentiment"
)
return lm
message = "My account was charged twice for the same order. Need help ASAP!"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_intent(lm, message)
print(f"Intent: {lm['intent']}")
print(f"Urgency: {lm['urgency']}")
print(f"Department: {lm['department']}")
```
## Agent Systems
### ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for round in range(max_rounds):
# Thought
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
# Action selection
lm += "Action: " + select(
list(tools.keys()) + ["answer"],
name="action"
)
if lm["action"] == "answer":
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
# Execute tool
if lm["action"] in tools:
try:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
except Exception as e:
lm += f"Observation: Error - {str(e)}\n\n"
return lm
# Define tools
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for '{query}': [Mock results]",
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
}
# Use agent
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
print(lm["answer"])
```
### Multi-Agent System
```python
@guidance
def coordinator_agent(lm, task):
"""Coordinator that delegates to specialists."""
lm += f"Task: {task}\n\n"
# Determine which specialist to use
lm += "Specialist: " + select(
["researcher", "writer", "coder", "analyst"],
name="specialist"
) + "\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def researcher_agent(lm, query):
"""Research specialist."""
lm += f"Research Query: {query}\n\n"
lm += "Findings:\n"
for i in range(3):
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def writer_agent(lm, topic):
"""Writing specialist."""
lm += f"Topic: {topic}\n\n"
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
lm += "Content:\n" + gen("content", max_tokens=500)
return lm
# Coordination workflow
task = "Write an article about AI safety"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = coordinator_agent(lm, task)
specialist = lm["specialist"]
if specialist == "researcher":
lm = researcher_agent(lm, task)
elif specialist == "writer":
lm = writer_agent(lm, task)
```
### Tool Use with Validation
```python
@guidance(stateless=False)
def validated_tool_agent(lm, question):
"""Agent with validated tool calls."""
tools = {
"add": lambda a, b: float(a) + float(b),
"multiply": lambda a, b: float(a) * float(b),
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
}
lm += f"Question: {question}\n\n"
for i in range(5):
# Select tool
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
if lm["tool"] == "done":
lm += "\nAnswer: " + gen("answer", max_tokens=100)
break
# Get validated numeric arguments
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
# Execute
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
lm += f"Result: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
```
## Multi-Step Workflows
### Chain of Thought
```python
@guidance
def chain_of_thought(lm, question):
"""Multi-step reasoning with CoT."""
lm += f"Question: {question}\n\n"
# Generate reasoning steps
lm += "Let me think step by step:\n\n"
for i in range(4):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
print(lm["answer"])
```
### Self-Consistency
```python
@guidance
def self_consistency(lm, question, num_samples=3):
"""Generate multiple reasoning paths and aggregate."""
lm += f"Question: {question}\n\n"
answers = []
for i in range(num_samples):
lm += f"=== Attempt {i+1} ===\n"
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
answers.append(lm[f"answer_{i}"])
# Aggregate (simple majority vote)
from collections import Counter
most_common = Counter(answers).most_common(1)[0][0]
lm += f"Final Answer (by majority): {most_common}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = self_consistency(lm, "What is 15% of 200?")
```
### Planning and Execution
```python
@guidance
def plan_and_execute(lm, goal):
"""Plan tasks then execute them."""
lm += f"Goal: {goal}\n\n"
# Planning phase
lm += "Plan:\n"
num_steps = 4
for i in range(num_steps):
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
# Execution phase
lm += "\nExecution:\n\n"
for i in range(num_steps):
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
# Summary
lm += "Summary: " + gen("summary", max_tokens=200)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
```
## Code Generation
### Python Function
```python
@guidance
def generate_python_function(lm, description):
"""Generate Python function from description."""
lm += f"Description: {description}\n\n"
# Function signature
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
# Function body
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_python_function(lm, "Check if a number is prime")
print(lm)
```
### SQL Query
```python
@guidance
def generate_sql(lm, description):
"""Generate SQL query from description."""
lm += f"Description: {description}\n\n"
lm += "SQL Query:\n"
# SELECT clause
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
# FROM clause
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
# WHERE clause (optional)
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
```
### API Endpoint
```python
@guidance
def generate_api_endpoint(lm, description):
"""Generate REST API endpoint."""
lm += f"Description: {description}\n\n"
# HTTP method
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
# Path
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
# Request body (if POST/PUT)
if lm["method"] in ["POST", "PUT"]:
lm += "\nRequest Body:\n"
lm += "{\n"
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
lm += "}\n"
# Response
lm += "\nResponse (200 OK):\n"
lm += "{\n"
lm += ' "status": "success",\n'
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
lm += "}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_api_endpoint(lm, "Create a new blog post")
```
## Production Tips
### Error Handling
```python
@guidance
def safe_extraction(lm, text):
"""Extract with fallback handling."""
try:
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
return lm
except Exception as e:
# Fallback to less strict extraction
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
return lm
```
### Caching
```python
from functools import lru_cache
@lru_cache(maxsize=100)
def cached_generation(text):
"""Cache LLM generations."""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += f"Analyze: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
return lm["sentiment"]
# First call: hits LLM
result1 = cached_generation("This is great!")
# Second call: returns cached result
result2 = cached_generation("This is great!") # Instant!
```
### Monitoring
```python
import time
@guidance
def monitored_generation(lm, text):
"""Track generation metrics."""
start_time = time.time()
lm += f"Text: {text}\n"
lm += "Analysis: " + gen("analysis", max_tokens=100)
elapsed = time.time() - start_time
# Log metrics
print(f"Generation time: {elapsed:.2f}s")
print(f"Output length: {len(lm['analysis'])} chars")
return lm
```
### Batch Processing
```python
def batch_process(texts, batch_size=10):
"""Process texts in batches."""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
for text in batch:
lm += f"Text: {text}\n"
lm += "Sentiment: " + select(
["positive", "negative", "neutral"],
name=f"sentiment_{i}"
) + "\n\n"
results.extend([lm[f"sentiment_{i}"] for i in range(len(batch))])
return results
```
## Resources
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Guidance Docs**: https://guidance.readthedocs.io
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions

307
skills/mlops/llava/SKILL.md Normal file
View File

@@ -0,0 +1,307 @@
---
name: llava
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers, torch, pillow]
metadata:
hermes:
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
---
# LLaVA - Large Language and Vision Assistant
Open-source vision-language model for conversational image understanding.
## When to use LLaVA
**Use when:**
- Building vision-language chatbots
- Visual question answering (VQA)
- Image description and captioning
- Multi-turn image conversations
- Visual instruction following
- Document understanding with images
**Metrics**:
- **23,000+ GitHub stars**
- GPT-4V level capabilities (targeted)
- Apache 2.0 License
- Multiple model sizes (7B-34B params)
**Use alternatives instead**:
- **GPT-4V**: Highest quality, API-based
- **CLIP**: Simple zero-shot classification
- **BLIP-2**: Better for captioning only
- **Flamingo**: Research, not open-source
## Quick start
### Installation
```bash
# Clone repository
git clone https://github.com/haotian-liu/LLaVA
cd LLaVA
# Install
pip install -e .
```
### Basic usage
```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch
# Load model
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
# Load image
image = Image.open("image.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
# Create conversation
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Generate response
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
max_new_tokens=512
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(response)
```
## Available models
| Model | Parameters | VRAM | Quality |
|-------|------------|------|---------|
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
```python
# Load different models
model_7b = "liuhaotian/llava-v1.5-7b"
model_13b = "liuhaotian/llava-v1.5-13b"
model_34b = "liuhaotian/llava-v1.6-34b"
# 4-bit quantization for lower VRAM
load_4bit = True # Reduces VRAM by ~4×
```
## CLI usage
```bash
# Single image query
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg \
--query "What is in this image?"
# Multi-turn conversation
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg
# Then type questions interactively
```
## Web UI (Gradio)
```bash
# Launch Gradio interface
python -m llava.serve.gradio_web_server \
--model-path liuhaotian/llava-v1.5-7b \
--load-4bit # Optional: reduce VRAM
# Access at http://localhost:7860
```
## Multi-turn conversations
```python
# Initialize conversation
conv = conv_templates["llava_v1"].copy()
# Turn 1
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
response1 = generate(conv, model, image) # "A dog playing in a park"
# Turn 2
conv.messages[-1][1] = response1 # Add previous response
conv.append_message(conv.roles[0], "What breed is the dog?")
conv.append_message(conv.roles[1], None)
response2 = generate(conv, model, image) # "Golden Retriever"
# Turn 3
conv.messages[-1][1] = response2
conv.append_message(conv.roles[0], "What time of day is it?")
conv.append_message(conv.roles[1], None)
response3 = generate(conv, model, image)
```
## Common tasks
### Image captioning
```python
question = "Describe this image in detail."
response = ask(model, image, question)
```
### Visual question answering
```python
question = "How many people are in the image?"
response = ask(model, image, question)
```
### Object detection (textual)
```python
question = "List all the objects you can see in this image."
response = ask(model, image, question)
```
### Scene understanding
```python
question = "What is happening in this scene?"
response = ask(model, image, question)
```
### Document understanding
```python
question = "What is the main topic of this document?"
response = ask(model, document_image, question)
```
## Training custom model
```bash
# Stage 1: Feature alignment (558K image-caption pairs)
bash scripts/v1_5/pretrain.sh
# Stage 2: Visual instruction tuning (150K instruction data)
bash scripts/v1_5/finetune.sh
```
## Quantization (reduce VRAM)
```python
# 4-bit quantization
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path="liuhaotian/llava-v1.5-13b",
model_base=None,
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
load_4bit=True # Reduces VRAM ~4×
)
# 8-bit quantization
load_8bit=True # Reduces VRAM ~2×
```
## Best practices
1. **Start with 7B model** - Good quality, manageable VRAM
2. **Use 4-bit quantization** - Reduces VRAM significantly
3. **GPU required** - CPU inference extremely slow
4. **Clear prompts** - Specific questions get better answers
5. **Multi-turn conversations** - Maintain conversation context
6. **Temperature 0.2-0.7** - Balance creativity/consistency
7. **max_new_tokens 512-1024** - For detailed responses
8. **Batch processing** - Process multiple images sequentially
## Performance
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|-------|-------------|--------------|------------------|
| 7B | ~14 GB | ~4 GB | ~20 |
| 13B | ~28 GB | ~8 GB | ~12 |
| 34B | ~70 GB | ~18 GB | ~5 |
*On A100 GPU*
## Benchmarks
LLaVA achieves competitive scores on:
- **VQAv2**: 78.5%
- **GQA**: 62.0%
- **MM-Vet**: 35.4%
- **MMBench**: 64.3%
## Limitations
1. **Hallucinations** - May describe things not in image
2. **Spatial reasoning** - Struggles with precise locations
3. **Small text** - Difficulty reading fine print
4. **Object counting** - Imprecise for many objects
5. **VRAM requirements** - Need powerful GPU
6. **Inference speed** - Slower than CLIP
## Integration with frameworks
### LangChain
```python
from langchain.llms.base import LLM
class LLaVALLM(LLM):
def _call(self, prompt, stop=None):
# Custom LLaVA inference
return response
llm = LLaVALLM()
```
### Gradio App
```python
import gradio as gr
def chat(image, text, history):
response = ask_llava(model, image, text)
return response
demo = gr.ChatInterface(
chat,
additional_inputs=[gr.Image(type="pil")],
title="LLaVA Chat"
)
demo.launch()
```
## Resources
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
- **Paper**: https://arxiv.org/abs/2304.08485
- **Demo**: https://llava.hliu.cc
- **Models**: https://huggingface.co/liuhaotian
- **License**: Apache 2.0

Some files were not shown because too many files have changed in this diff Show More