Mirror of https://github.com/NousResearch/hermes-agent.git
Synced 2026-05-02 08:47:26 +08:00

Compare commits: feat/head-...fix/vision (42 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 8ac288e932 | |
| | 888277ece8 | |
| | 172a38c344 | |
| | 8bc0d4f77d | |
| | 8eabdefa8a | |
| | f658af45c2 | |
| | 5212644861 | |
| | 1151f84351 | |
| | 9abd6bf342 | |
| | d2c7ef6b41 | |
| | a34102049b | |
| | ef5d811aba | |
| | 2d44ed1c5b | |
| | fa2e72ae9c | |
| | 5bfc4ed53b | |
| | 520aec20e0 | |
| | 64bec1d060 | |
| | ac58309dbd | |
| | a5a5d82a21 | |
| | 34e8d088c2 | |
| | c754135965 | |
| | c6b75baad0 | |
| | a7ad6f6d28 | |
| | 1a2141d04d | |
| | ff3f3169b2 | |
| | f4580b6010 | |
| | 7b63a787b3 | |
| | 069570d103 | |
| | 0dafdcab86 | |
| | 654e16187e | |
| | 732c66b0f3 | |
| | 1f0944de21 | |
| | f1a1b58319 | |
| | c21d77ca08 | |
| | d6c710706f | |
| | a6d3becd6a | |
| | 3b67606c42 | |
| | 763c6d104d | |
| | 37752ff1ac | |
| | 36214d14db | |
| | 15561ec425 | |
| | 7d79ce92ac | |
@@ -560,12 +560,16 @@ def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
     forced = _get_auxiliary_provider("vision")
     if forced != "auto":
         return _resolve_forced_provider(forced)
-    # Auto: only multimodal-capable providers
-    for try_fn in (_try_openrouter, _try_nous, _try_codex):
+    # Auto: try providers known to support multimodal first, then fall
+    # back to the user's custom endpoint. Many local models (Qwen-VL,
+    # LLaVA, Pixtral, etc.) support vision — skipping them entirely
+    # caused silent failures for local-only users.
+    for try_fn in (_try_openrouter, _try_nous, _try_codex,
+                   _try_custom_endpoint):
         client, model = try_fn()
         if client is not None:
             return client, model
-    logger.debug("Auxiliary vision client: none available (auto only tries OpenRouter/Nous/Codex)")
+    logger.debug("Auxiliary vision client: none available")
    return None, None
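The widened auto order is easiest to see in isolation. A minimal sketch follows; the probe functions here are hypothetical stand-ins, not the real `_try_*` helpers (which return `(client, model)` pairs in the same module):

# Sketch only — these probes are stand-ins for the real helpers.
def _try_openrouter(): return None, None          # e.g. no OPENROUTER_API_KEY set
def _try_nous(): return None, None
def _try_codex(): return None, None
def _try_custom_endpoint(): return "client", "qwen2.5-vl"  # hypothetical local vLLM

for try_fn in (_try_openrouter, _try_nous, _try_codex, _try_custom_endpoint):
    client, model = try_fn()
    if client is not None:
        print(f"vision client -> {model}")  # vision client -> qwen2.5-vl
        break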
@@ -195,6 +195,8 @@ def build_skills_system_prompt() -> str:

     # Collect skills with descriptions, grouped by category
     # Each entry: (skill_name, description)
+    # Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
+    #   → category "mlops/training", skill "axolotl"
     skills_by_category: dict[str, list[tuple[str, str]]] = {}
     for skill_file in skills_dir.rglob("SKILL.md"):
         # Skip skills incompatible with the current OS platform
@@ -203,8 +205,13 @@ def build_skills_system_prompt() -> str:
         rel_path = skill_file.relative_to(skills_dir)
         parts = rel_path.parts
         if len(parts) >= 2:
-            category = parts[0]
+            # Category is everything between skills_dir and the skill folder
+            # e.g. parts = ("mlops", "training", "axolotl", "SKILL.md")
+            #      → category = "mlops/training", skill_name = "axolotl"
+            # e.g. parts = ("github", "github-auth", "SKILL.md")
+            #      → category = "github", skill_name = "github-auth"
             skill_name = parts[-2]
+            category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
         else:
             category = "general"
             skill_name = skill_file.parent.name
@@ -215,9 +222,11 @@ def build_skills_system_prompt() -> str:
         return ""

     # Read category-level descriptions from DESCRIPTION.md
+    # Checks both the exact category path and parent directories
     category_descriptions = {}
     for category in skills_by_category:
-        desc_file = skills_dir / category / "DESCRIPTION.md"
+        cat_path = Path(category)
+        desc_file = skills_dir / cat_path / "DESCRIPTION.md"
         if desc_file.exists():
             try:
                 content = desc_file.read_text(encoding="utf-8")
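The diff's comments give the mapping by example; the same derivation runs standalone. A small sketch (the paths below are illustrative):

from pathlib import Path

# Standalone sketch of the category derivation added above.
for rel in (Path("mlops/training/axolotl/SKILL.md"),
            Path("github/github-auth/SKILL.md")):
    parts = rel.parts
    skill_name = parts[-2]
    category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
    print(f"{rel} -> category={category!r}, skill={skill_name!r}")
# mlops/training/axolotl/SKILL.md -> category='mlops/training', skill='axolotl'
# github/github-auth/SKILL.md -> category='github', skill='github-auth'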
@@ -555,6 +555,21 @@ toolsets:
 #       args: ["-y", "@modelcontextprotocol/server-github"]
 #       env:
 #         GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
+#
+# Sampling (server-initiated LLM requests) — enabled by default.
+# Per-server config under the 'sampling' key:
+#   analysis:
+#     command: npx
+#     args: ["-y", "analysis-server"]
+#     sampling:
+#       enabled: true            # default: true
+#       model: "gemini-3-flash"  # override model (optional)
+#       max_tokens_cap: 4096     # max tokens per request
+#       timeout: 30              # LLM call timeout (seconds)
+#       max_rpm: 10              # max requests per minute
+#       allowed_models: []       # model whitelist (empty = all)
+#       max_tool_rounds: 5       # tool loop limit (0 = disable)
+#       log_level: "info"        # audit verbosity

 # =============================================================================
 # Voice Transcription (Speech-to-Text)
cli.py (44 lines changed)
@@ -158,6 +158,7 @@ def load_cli_config() -> Dict[str, Any]:
         "singularity_image": "docker://python:3.11",
         "modal_image": "python:3.11",
         "daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
+        "docker_volumes": [],  # host:container volume mounts for Docker backend
     },
     "browser": {
         "inactivity_timeout": 120,  # Auto-cleanup inactive browser sessions after 2 min
@@ -725,6 +726,7 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⣀⣀
 [#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]"""

 # Compact banner for smaller terminals (fallback)
+# Note: built dynamically by _build_compact_banner() to fit terminal width
 COMPACT_BANNER = """
 [bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/]
 [bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/]
@@ -733,6 +735,26 @@ COMPACT_BANNER = """
 """


+def _build_compact_banner() -> str:
+    """Build a compact banner that fits the current terminal width."""
+    w = min(shutil.get_terminal_size().columns - 2, 64)
+    if w < 30:
+        return "\n[#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- Nous Research[/]\n"
+    inner = w - 2  # inside the box border
+    bar = "═" * w
+    line1 = "⚕ NOUS HERMES - AI Agent Framework"
+    line2 = "Messenger of the Digital Gods · Nous Research"
+    # Truncate and pad to fit
+    line1 = line1[:inner - 2].ljust(inner - 2)
+    line2 = line2[:inner - 2].ljust(inner - 2)
+    return (
+        f"\n[bold #FFD700]╔{bar}╗[/]\n"
+        f"[bold #FFD700]║[/] [#FFBF00]{line1}[/] [bold #FFD700]║[/]\n"
+        f"[bold #FFD700]║[/] [dim #B8860B]{line2}[/] [bold #FFD700]║[/]\n"
+        f"[bold #FFD700]╚{bar}╝[/]\n"
+    )
+
+
 def _get_available_skills() -> Dict[str, List[str]]:
     """
     Scan ~/.hermes/skills/ and return skills grouped by category.
@@ -930,10 +952,12 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict
         padding=(0, 2),
     )

-    # Print the big HERMES-AGENT logo first (no panel wrapper for full width)
+    # Print the big HERMES-AGENT logo — skip if terminal is too narrow
     console.print()
-    console.print(HERMES_AGENT_LOGO)
+    term_width = shutil.get_terminal_size().columns
+    if term_width >= 95:
+        console.print(HERMES_AGENT_LOGO)
     console.print()

     # Print the panel with caduceus and info
     console.print(outer_panel)
@@ -1383,8 +1407,13 @@ class HermesCLI:
         """Display the welcome banner in Claude Code style."""
         self.console.clear()

-        if self.compact:
-            self.console.print(COMPACT_BANNER)
+        # Auto-compact for narrow terminals — the full banner with caduceus
+        # + tool list needs ~80 columns minimum to render without wrapping.
+        term_width = shutil.get_terminal_size().columns
+        use_compact = self.compact or term_width < 80
+
+        if use_compact:
+            self.console.print(_build_compact_banner())
             self._show_status()
         else:
             # Get tools for display
@@ -2394,8 +2423,9 @@ class HermesCLI:
             # and gets mangled by patch_stdout).
             if self._app:
                 cc = ChatConsole()
-                if self.compact:
-                    cc.print(COMPACT_BANNER)
+                term_w = shutil.get_terminal_size().columns
+                if self.compact or term_w < 80:
+                    cc.print(_build_compact_banner())
                 else:
                     tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True)
                     cwd = os.getenv("TERMINAL_CWD", os.getcwd())
datagen-config-examples/web_research.yaml (new file, 46 lines)

@@ -0,0 +1,46 @@
# datagen-config-examples/web_research.yaml
#
# Batch data generation config for WebResearchEnv.
# Generates tool-calling trajectories for multi-step web research tasks.
#
# Usage:
#   python batch_runner.py \
#     --config datagen-config-examples/web_research.yaml \
#     --run_name web_research_v1

environment: web-research

# Toolsets available to the agent during data generation
toolsets:
  - web
  - file

# How many parallel workers to use
num_workers: 4

# Questions per batch
batch_size: 20

# Total trajectories to generate (comment out to run full dataset)
max_items: 500

# Model to use for generation (override with --model flag)
model: openrouter/nousresearch/hermes-3-llama-3.1-405b

# System prompt additions (ephemeral — not saved to trajectories)
ephemeral_system_prompt: |
  You are a highly capable research agent. When asked a factual question,
  always use web_search to find current, accurate information before answering.
  Cite at least 2 sources. Be concise and accurate.

# Output directory
output_dir: data/web_research_v1

# Trajectory compression settings (for fitting into training token budgets)
compression:
  enabled: true
  target_max_tokens: 16000

# Eval settings
eval_every: 100   # Run eval every N trajectories
eval_size: 25     # Number of held-out questions per eval run
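A minimal sketch of consuming this config, assuming PyYAML is available (the key names mirror the file above; how batch_runner.py itself loads the file is not shown in this diff):

import yaml  # assumption: PyYAML is installed

with open("datagen-config-examples/web_research.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["environment"])                        # web-research
print(cfg["toolsets"])                           # ['web', 'file']
print(cfg["compression"]["target_max_tokens"])   # 16000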
environments/web_research_env.py (new file, 643 lines)

@@ -0,0 +1,643 @@
"""
WebResearchEnv — RL Environment for Multi-Step Web Research
============================================================

Trains models to do accurate, efficient, multi-source web research.

Reward signals:
  - Answer correctness (LLM judge, 0.0–1.0)
  - Source diversity (used ≥2 distinct domains)
  - Efficiency (penalizes excessive tool calls)
  - Tool usage (bonus for actually using web tools)

Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
HuggingFace: google/frames-benchmark
Fallback: built-in sample questions (no HF token needed)

Usage:
    # Phase 1 (OpenAI-compatible server)
    python environments/web_research_env.py serve \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name YourModel \\
        --openai.server_type openai

    # Process mode (offline data generation)
    python environments/web_research_env.py process \\
        --env.data_path_to_save_groups data/web_research.jsonl

    # Standalone eval
    python environments/web_research_env.py evaluate \\
        --openai.base_url http://localhost:8000/v1 \\
        --openai.model_name YourModel

Built by: github.com/jackx707
Inspired by: GroceryMind — production Hermes agent doing live web research
across German grocery stores (firecrawl + hermes-agent)
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from pydantic import Field

# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

# ---------------------------------------------------------------------------
# Optional HuggingFace datasets import
# ---------------------------------------------------------------------------
try:
    from datasets import load_dataset
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item

from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult
from environments.tool_context import ToolContext

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Fallback sample dataset (used when HuggingFace is unavailable)
# Multi-hop questions requiring real web search to answer.
# ---------------------------------------------------------------------------
SAMPLE_QUESTIONS = [
    {
        "question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
        "answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
        "answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What programming language was used to write the original version of the web framework used by Instagram?",
        "answer": "Django, which Instagram was built on, is written in Python.",
        "difficulty": "easy",
        "hops": 2,
    },
    {
        "question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
        "answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
        "difficulty": "hard",
        "hops": 3,
    },
    {
        "question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
        "answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "How many employees does the parent company of Instagram have?",
        "answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
        "answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "Which company acquired the startup founded by the creator of Oculus VR?",
        "answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
        "difficulty": "medium",
        "hops": 2,
    },
    {
        "question": "What is the market cap of the company that owns the most popular search engine in Russia?",
        "answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
        "difficulty": "hard",
        "hops": 2,
    },
    {
        "question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
        "answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
        "difficulty": "hard",
        "hops": 2,
    },
]


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

class WebResearchEnvConfig(HermesAgentEnvConfig):
    """Configuration for the web research RL environment."""

    # Reward weights
    correctness_weight: float = Field(
        default=0.6,
        description="Weight for answer correctness in reward (LLM judge score).",
    )
    tool_usage_weight: float = Field(
        default=0.2,
        description="Weight for tool usage signal (did the model actually use web tools?).",
    )
    efficiency_weight: float = Field(
        default=0.2,
        description="Weight for efficiency signal (penalizes excessive tool calls).",
    )
    diversity_bonus: float = Field(
        default=0.1,
        description="Bonus reward for citing ≥2 distinct domains.",
    )

    # Efficiency thresholds
    efficient_max_calls: int = Field(
        default=5,
        description="Maximum tool calls before efficiency penalty begins.",
    )
    heavy_penalty_calls: int = Field(
        default=10,
        description="Tool call count where efficiency penalty steepens.",
    )

    # Eval
    eval_size: int = Field(
        default=20,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        default=0.1,
        description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
    )

    # Dataset
    dataset_name: str = Field(
        default="google/frames-benchmark",
        description="HuggingFace dataset name for research questions.",
    )
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------

class WebResearchEnv(HermesAgentBaseEnv):
    """
    RL environment for training multi-step web research skills.

    The model is given a factual question requiring 2-3 hops of web research
    and must use web_search / web_extract tools to find and synthesize the answer.

    Reward is multi-signal:
        60% — answer correctness (LLM judge)
        20% — tool usage (did the model actually search the web?)
        20% — efficiency (penalizes >5 tool calls)

    Bonus +0.1 for source diversity (≥2 distinct domains cited).
    """

    name = "web-research"
    env_config_cls = WebResearchEnvConfig

    # Default toolsets for this environment — web + file for saving notes
    default_toolsets = ["web", "file"]

    @classmethod
    def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
        """Default configuration for the web research environment."""
        env_config = WebResearchEnvConfig(
            enabled_toolsets=["web", "file"],
            max_agent_turns=15,
            agent_temperature=1.0,
            system_prompt=(
                "You are a highly capable research agent. When asked a factual question, "
                "always use web_search to find current, accurate information before answering. "
                "Cite at least 2 sources. Be concise and accurate."
            ),
            group_size=4,
            total_steps=1000,
            steps_per_eval=100,
            use_wandb=True,
            wandb_name="web-research",
        )

        server_configs = [
            APIServerConfig(
                base_url="https://openrouter.ai/api/v1",
                model_name="anthropic/claude-sonnet-4.5",
                server_type="openai",
                api_key=os.getenv("OPENROUTER_API_KEY", ""),
                health_check=False,
            )
        ]

        return env_config, server_configs

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._items: list[dict] = []
        self._eval_items: list[dict] = []
        self._index: int = 0

        # Metrics tracking for wandb
        self._reward_buffer: list[float] = []
        self._correctness_buffer: list[float] = []
        self._tool_usage_buffer: list[float] = []
        self._efficiency_buffer: list[float] = []
        self._diversity_buffer: list[float] = []

    # ------------------------------------------------------------------
    # 1. Setup — load dataset
    # ------------------------------------------------------------------

    async def setup(self) -> None:
        """Load the FRAMES benchmark or fall back to built-in samples."""
        if HF_AVAILABLE:
            try:
                logger.info("Loading FRAMES benchmark from HuggingFace...")
                ds = load_dataset(self.config.dataset_name, split="test")
                self._items = [
                    {
                        "question": row["Prompt"],
                        "answer": row["Answer"],
                        "difficulty": row.get("reasoning_types", "unknown"),
                        "hops": 2,
                    }
                    for row in ds
                ]
                # Hold out for eval
                eval_size = max(
                    self.config.eval_size,
                    int(len(self._items) * self.config.eval_split_ratio),
                )
                random.shuffle(self._items)
                self._eval_items = self._items[:eval_size]
                self._items = self._items[eval_size:]
                logger.info(
                    f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
                    f"from FRAMES benchmark."
                )
                return
            except Exception as e:
                logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")

        # Fallback
        random.shuffle(SAMPLE_QUESTIONS)
        split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
        self._items = SAMPLE_QUESTIONS[:split]
        self._eval_items = SAMPLE_QUESTIONS[split:]
        logger.info(
            f"Using built-in sample dataset: {len(self._items)} train / "
            f"{len(self._eval_items)} eval items."
        )

    # ------------------------------------------------------------------
    # 2. get_next_item — return the next question
    # ------------------------------------------------------------------

    async def get_next_item(self) -> dict:
        """Return the next item, cycling through the dataset."""
        if not self._items:
            raise RuntimeError("Dataset is empty. Did you call setup()?")
        item = self._items[self._index % len(self._items)]
        self._index += 1
        return item

    # ------------------------------------------------------------------
    # 3. format_prompt — build the user-facing prompt
    # ------------------------------------------------------------------

    def format_prompt(self, item: dict) -> str:
        """Format the research question as a task prompt."""
        return (
            f"Research the following question thoroughly using web search. "
            f"You MUST search the web to find current, accurate information — "
            f"do not rely solely on your training data.\n\n"
            f"Question: {item['question']}\n\n"
            f"Requirements:\n"
            f"- Use web_search and/or web_extract tools to find information\n"
            f"- Search at least 2 different sources\n"
            f"- Provide a concise, accurate answer (2-4 sentences)\n"
            f"- Cite the sources you used"
        )

    # ------------------------------------------------------------------
    # 4. compute_reward — multi-signal scoring
    # ------------------------------------------------------------------

    async def compute_reward(
        self,
        item: dict,
        result: AgentResult,
        ctx: ToolContext,
    ) -> float:
        """
        Multi-signal reward function:

            correctness_weight * correctness  — LLM judge comparing answer to ground truth
            tool_usage_weight  * tool_used    — binary: did the model use web tools?
            efficiency_weight  * efficiency   — penalizes wasteful tool usage
            + diversity_bonus                 — source diversity (≥2 distinct domains)
        """
        final_response: str = result.final_response or ""
        tools_used: list[str] = [
            tc.tool_name for tc in (result.tool_calls or [])
        ] if hasattr(result, "tool_calls") and result.tool_calls else []
        tool_call_count: int = result.turns_used or len(tools_used)

        cfg = self.config

        # ---- Signal 1: Answer correctness (LLM judge) ----------------
        correctness = await self._llm_judge(
            question=item["question"],
            expected=item["answer"],
            model_answer=final_response,
        )

        # ---- Signal 2: Web tool usage --------------------------------
        web_tools = {"web_search", "web_extract", "search", "firecrawl"}
        tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0

        # ---- Signal 3: Efficiency ------------------------------------
        if tool_call_count <= cfg.efficient_max_calls:
            efficiency = 1.0
        elif tool_call_count <= cfg.heavy_penalty_calls:
            efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
        else:
            efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)

        # ---- Bonus: Source diversity ---------------------------------
        domains = self._extract_domains(final_response)
        diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0

        # ---- Combine --------------------------------------------------
        reward = (
            cfg.correctness_weight * correctness
            + cfg.tool_usage_weight * tool_used
            + cfg.efficiency_weight * efficiency
            + diversity
        )
        reward = min(1.0, max(0.0, reward))  # clamp to [0, 1]

        # Track for wandb
        self._reward_buffer.append(reward)
        self._correctness_buffer.append(correctness)
        self._tool_usage_buffer.append(tool_used)
        self._efficiency_buffer.append(efficiency)
        self._diversity_buffer.append(diversity)

        logger.debug(
            f"Reward breakdown — correctness={correctness:.2f}, "
            f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
            f"diversity={diversity:.1f} → total={reward:.3f}"
        )

        return reward

    # ------------------------------------------------------------------
    # 5. evaluate — run on held-out eval split
    # ------------------------------------------------------------------

    async def evaluate(self, *args, **kwargs) -> None:
        """Run evaluation on the held-out split using the agent loop."""
        import time

        items = self._eval_items
        if not items:
            logger.warning("No eval items available.")
            return

        eval_size = min(self.config.eval_size, len(items))
        eval_items = items[:eval_size]

        logger.info(f"Running eval on {len(eval_items)} questions...")
        start_time = time.time()
        samples = []

        for item in eval_items:
            try:
                # Use the base env's agent loop for eval (same as training)
                prompt = self.format_prompt(item)
                completion = await self.server.chat_completion(
                    messages=[
                        {"role": "system", "content": self.config.system_prompt or ""},
                        {"role": "user", "content": prompt},
                    ],
                    n=1,
                    max_tokens=self.config.max_token_length,
                    temperature=0.0,
                    split="eval",
                )

                response_content = (
                    completion.choices[0].message.content if completion.choices else ""
                )

                # Score the response
                correctness = await self._llm_judge(
                    question=item["question"],
                    expected=item["answer"],
                    model_answer=response_content,
                )

                samples.append({
                    "prompt": item["question"],
                    "response": response_content,
                    "expected": item["answer"],
                    "correctness": correctness,
                })

            except Exception as e:
                logger.error(f"Eval error on item: {e}")
                samples.append({
                    "prompt": item["question"],
                    "response": f"ERROR: {e}",
                    "expected": item["answer"],
                    "correctness": 0.0,
                })

        end_time = time.time()

        # Compute metrics
        correctness_scores = [s["correctness"] for s in samples]
        eval_metrics = {
            "eval/mean_correctness": (
                sum(correctness_scores) / len(correctness_scores)
                if correctness_scores else 0.0
            ),
            "eval/n_items": len(samples),
        }

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
        )

    # ------------------------------------------------------------------
    # 6. wandb_log — custom metrics
    # ------------------------------------------------------------------

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
        """Log reward breakdown metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self._reward_buffer:
            n = len(self._reward_buffer)
            wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
            wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
            wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
            wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
            wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
            wandb_metrics["train/total_rollouts"] = n

            # Accuracy buckets
            wandb_metrics["train/correct_rate"] = (
                sum(1 for c in self._correctness_buffer if c >= 0.7) / n
            )
            wandb_metrics["train/tool_usage_rate"] = (
                sum(1 for t in self._tool_usage_buffer if t > 0) / n
            )

            # Clear buffers
            self._reward_buffer.clear()
            self._correctness_buffer.clear()
            self._tool_usage_buffer.clear()
            self._efficiency_buffer.clear()
            self._diversity_buffer.clear()

        await super().wandb_log(wandb_metrics)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _llm_judge(
        self,
        question: str,
        expected: str,
        model_answer: str,
    ) -> float:
        """
        Use the server's LLM to judge answer correctness.
        Falls back to keyword heuristic if LLM call fails.
        """
        if not model_answer or not model_answer.strip():
            return 0.0

        judge_prompt = (
            "You are an impartial judge evaluating the quality of an AI research answer.\n\n"
            f"Question: {question}\n\n"
            f"Reference answer: {expected}\n\n"
            f"Model answer: {model_answer}\n\n"
            "Score the model answer on a scale from 0.0 to 1.0 where:\n"
            "  1.0 = fully correct and complete\n"
            "  0.7 = mostly correct with minor gaps\n"
            "  0.4 = partially correct\n"
            "  0.1 = mentions relevant topic but wrong or very incomplete\n"
            "  0.0 = completely wrong or no answer\n\n"
            "Consider: factual accuracy, completeness, and relevance.\n"
            'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
        )

        try:
            response = await self.server.chat_completion(
                messages=[{"role": "user", "content": judge_prompt}],
                n=1,
                max_tokens=150,
                temperature=0.0,
                split="eval",
            )
            text = response.choices[0].message.content if response.choices else ""
            parsed = self._parse_judge_json(text)
            if parsed is not None:
                return float(parsed)
        except Exception as e:
            logger.debug(f"LLM judge failed: {e}. Using heuristic.")

        return self._heuristic_score(expected, model_answer)

    @staticmethod
    def _parse_judge_json(text: str) -> Optional[float]:
        """Extract the score float from LLM judge JSON response."""
        try:
            clean = re.sub(r"```(?:json)?|```", "", text).strip()
            data = json.loads(clean)
            score = float(data.get("score", -1))
            if 0.0 <= score <= 1.0:
                return score
        except Exception:
            match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
            if match:
                score = float(match.group(1))
                if 0.0 <= score <= 1.0:
                    return score
        return None

    @staticmethod
    def _heuristic_score(expected: str, model_answer: str) -> float:
        """Lightweight keyword overlap score as fallback."""
        stopwords = {
            "the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
            "at", "to", "for", "with", "and", "or", "but", "it", "its",
            "this", "that", "as", "by", "from", "be", "has", "have", "had",
        }

        def tokenize(text: str) -> set:
            tokens = re.findall(r'\b\w+\b', text.lower())
            return {t for t in tokens if t not in stopwords and len(t) > 2}

        expected_tokens = tokenize(expected)
        answer_tokens = tokenize(model_answer)

        if not expected_tokens:
            return 0.5

        overlap = len(expected_tokens & answer_tokens)
        union = len(expected_tokens | answer_tokens)

        jaccard = overlap / union if union > 0 else 0.0
        recall = overlap / len(expected_tokens)
        return min(1.0, 0.4 * jaccard + 0.6 * recall)

    @staticmethod
    def _extract_domains(text: str) -> set:
        """Extract unique domains from URLs cited in the response."""
        urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
        domains = set()
        for url in urls:
            try:
                parsed = urlparse(url)
                domain = parsed.netloc.lower().lstrip("www.")
                if domain:
                    domains.add(domain)
            except Exception:
                pass
        return domains


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    WebResearchEnv.cli()
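To make the reward combination concrete, here is a worked example using the default weights from WebResearchEnvConfig (the input numbers are illustrative, not from a real rollout):

# correctness from the LLM judge; tool_used is binary; efficiency is 1.0 for
# <=5 tool calls; diversity bonus for citing two distinct domains.
correctness, tool_used, efficiency, diversity = 0.7, 1.0, 1.0, 0.1

reward = 0.6 * correctness + 0.2 * tool_used + 0.2 * efficiency + diversity
reward = min(1.0, max(0.0, reward))  # same clamp as compute_reward
print(round(reward, 3))  # 0.92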
@@ -252,6 +252,7 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int:
 class MessageType(Enum):
     """Types of incoming messages."""
     TEXT = "text"
+    LOCATION = "location"
     PHOTO = "photo"
     VIDEO = "video"
     AUDIO = "audio"
@@ -10,6 +10,7 @@ Uses slack-bolt (Python) with Socket Mode for:

 import asyncio
 import os
+import re
 from typing import Dict, List, Optional, Any

 try:
@@ -33,6 +34,8 @@ from gateway.platforms.base import (
     MessageEvent,
     MessageType,
     SendResult,
+    SUPPORTED_DOCUMENT_TYPES,
+    cache_document_from_bytes,
     cache_image_from_url,
     cache_audio_from_url,
 )
@@ -96,6 +99,13 @@ class SlackAdapter(BasePlatformAdapter):
         async def handle_message_event(event, say):
             await self._handle_slack_message(event)

+        # Acknowledge app_mention events to prevent Bolt 404 errors.
+        # The "message" handler above already processes @mentions in
+        # channels, so this is intentionally a no-op to avoid duplicates.
+        @self._app.event("app_mention")
+        async def handle_app_mention(event, say):
+            pass
+
         # Register slash command handler
         @self._app.command("/hermes")
         async def handle_hermes_command(ack, command):
@@ -266,6 +276,65 @@ class SlackAdapter(BasePlatformAdapter):
         except Exception as e:
             return SendResult(success=False, error=str(e))

+    async def send_video(
+        self,
+        chat_id: str,
+        video_path: str,
+        caption: Optional[str] = None,
+        reply_to: Optional[str] = None,
+    ) -> SendResult:
+        """Send a video file to Slack."""
+        if not self._app:
+            return SendResult(success=False, error="Not connected")
+
+        if not os.path.exists(video_path):
+            return SendResult(success=False, error=f"Video file not found: {video_path}")
+
+        try:
+            result = await self._app.client.files_upload_v2(
+                channel=chat_id,
+                file=video_path,
+                filename=os.path.basename(video_path),
+                initial_comment=caption or "",
+                thread_ts=reply_to,
+            )
+            return SendResult(success=True, raw_response=result)
+        except Exception as e:
+            print(f"[{self.name}] Failed to send video: {e}")
+            return await super().send_video(chat_id, video_path, caption, reply_to)
+
+    async def send_document(
+        self,
+        chat_id: str,
+        file_path: str,
+        caption: Optional[str] = None,
+        file_name: Optional[str] = None,
+        reply_to: Optional[str] = None,
+    ) -> SendResult:
+        """Send a document/file attachment to Slack."""
+        if not self._app:
+            return SendResult(success=False, error="Not connected")
+
+        if not os.path.exists(file_path):
+            return SendResult(success=False, error=f"File not found: {file_path}")
+
+        display_name = file_name or os.path.basename(file_path)
+
+        try:
+            result = await self._app.client.files_upload_v2(
+                channel=chat_id,
+                file=file_path,
+                filename=display_name,
+                initial_comment=caption or "",
+                thread_ts=reply_to,
+            )
+            return SendResult(success=True, raw_response=result)
+        except Exception as e:
+            print(f"[{self.name}] Failed to send document: {e}")
+            return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
+
     async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
         """Get information about a Slack channel."""
         if not self._app:
@@ -347,6 +416,58 @@ class SlackAdapter(BasePlatformAdapter):
                     msg_type = MessageType.VOICE
                 except Exception as e:
                     print(f"[Slack] Failed to cache audio: {e}", flush=True)
+            elif url:
+                # Try to handle as a document attachment
+                try:
+                    original_filename = f.get("name", "")
+                    ext = ""
+                    if original_filename:
+                        _, ext = os.path.splitext(original_filename)
+                        ext = ext.lower()
+
+                    # Fallback: reverse-lookup from MIME type
+                    if not ext and mimetype:
+                        mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
+                        ext = mime_to_ext.get(mimetype, "")
+
+                    if ext not in SUPPORTED_DOCUMENT_TYPES:
+                        continue  # Skip unsupported file types silently
+
+                    # Check file size (Slack limit: 20 MB for bots)
+                    file_size = f.get("size", 0)
+                    MAX_DOC_BYTES = 20 * 1024 * 1024
+                    if not file_size or file_size > MAX_DOC_BYTES:
+                        print(f"[Slack] Document too large or unknown size: {file_size}", flush=True)
+                        continue
+
+                    # Download and cache
+                    raw_bytes = await self._download_slack_file_bytes(url)
+                    cached_path = cache_document_from_bytes(
+                        raw_bytes, original_filename or f"document{ext}"
+                    )
+                    doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
+                    media_urls.append(cached_path)
+                    media_types.append(doc_mime)
+                    msg_type = MessageType.DOCUMENT
+                    print(f"[Slack] Cached user document: {cached_path}", flush=True)
+
+                    # Inject text content for .txt/.md files (capped at 100 KB)
+                    MAX_TEXT_INJECT_BYTES = 100 * 1024
+                    if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
+                        try:
+                            text_content = raw_bytes.decode("utf-8")
+                            display_name = original_filename or f"document{ext}"
+                            display_name = re.sub(r'[^\w.\- ]', '_', display_name)
+                            injection = f"[Content of {display_name}]:\n{text_content}"
+                            if text:
+                                text = f"{injection}\n\n{text}"
+                            else:
+                                text = injection
+                        except UnicodeDecodeError:
+                            pass  # Binary content, skip injection
+
+                except Exception as e:
+                    print(f"[Slack] Failed to cache document: {e}", flush=True)

         # Build source
         source = self.build_source(
@@ -427,3 +548,16 @@ class SlackAdapter(BasePlatformAdapter):
         else:
             from gateway.platforms.base import cache_image_from_bytes
             return cache_image_from_bytes(response.content, ext)
+
+    async def _download_slack_file_bytes(self, url: str) -> bytes:
+        """Download a Slack file and return raw bytes."""
+        import httpx
+
+        bot_token = self.config.token
+        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
+            response = await client.get(
+                url,
+                headers={"Authorization": f"Bearer {bot_token}"},
+            )
+            response.raise_for_status()
+            return response.content
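The reverse MIME lookup in the new document branch is worth a small sketch; the two-entry mapping below is a stand-in for the real SUPPORTED_DOCUMENT_TYPES, which lives in gateway/platforms/base.py:

# Stand-in mapping, not the real one.
SUPPORTED_DOCUMENT_TYPES = {".pdf": "application/pdf", ".md": "text/markdown"}

# Used when a Slack file arrives without an extension but with a MIME type.
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
print(mime_to_ext.get("application/pdf", ""))  # .pdf
print(mime_to_ext.get("image/png", ""))        # '' -> file is skipped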
@@ -132,6 +132,10 @@ class TelegramAdapter(BasePlatformAdapter):
             filters.COMMAND,
             self._handle_command
         ))
+        self._app.add_handler(TelegramMessageHandler(
+            filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION),
+            self._handle_location_message
+        ))
         self._app.add_handler(TelegramMessageHandler(
             filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL,
             self._handle_media_message
@@ -546,6 +550,41 @@ class TelegramAdapter(BasePlatformAdapter):
         event = self._build_message_event(update.message, MessageType.COMMAND)
         await self.handle_message(event)

+    async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+        """Handle incoming location/venue pin messages."""
+        if not update.message:
+            return
+
+        msg = update.message
+        venue = getattr(msg, "venue", None)
+        location = getattr(venue, "location", None) if venue else getattr(msg, "location", None)
+
+        if not location:
+            return
+
+        lat = getattr(location, "latitude", None)
+        lon = getattr(location, "longitude", None)
+        if lat is None or lon is None:
+            return
+
+        # Build a text message with coordinates and context
+        parts = ["[The user shared a location pin.]"]
+        if venue:
+            title = getattr(venue, "title", None)
+            address = getattr(venue, "address", None)
+            if title:
+                parts.append(f"Venue: {title}")
+            if address:
+                parts.append(f"Address: {address}")
+        parts.append(f"latitude: {lat}")
+        parts.append(f"longitude: {lon}")
+        parts.append(f"Map: https://www.google.com/maps/search/?api=1&query={lat},{lon}")
+        parts.append("Ask what they'd like to find nearby (restaurants, cafes, etc.) and any preferences.")
+
+        event = self._build_message_event(msg, MessageType.LOCATION)
+        event.text = "\n".join(parts)
+        await self.handle_message(event)
+
     async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Handle incoming media messages, downloading images to local cache."""
         if not update.message:
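For reference, the injected text the new handler builds for a venue pin looks like this (every value below is made up):

# Illustrative output of _handle_location_message, not real coordinates.
print("\n".join([
    "[The user shared a location pin.]",
    "Venue: Example Cafe",
    "Address: 1 Example Street",
    "latitude: 48.2104",
    "longitude: 16.3655",
    "Map: https://www.google.com/maps/search/?api=1&query=48.2104,16.3655",
    "Ask what they'd like to find nearby (restaurants, cafes, etc.) and any preferences.",
]))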
@@ -75,11 +75,16 @@ if _config_path.exists():
         "container_memory": "TERMINAL_CONTAINER_MEMORY",
         "container_disk": "TERMINAL_CONTAINER_DISK",
         "container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
+        "docker_volumes": "TERMINAL_DOCKER_VOLUMES",
         "sandbox_dir": "TERMINAL_SANDBOX_DIR",
     }
     for _cfg_key, _env_var in _terminal_env_map.items():
         if _cfg_key in _terminal_cfg:
-            os.environ[_env_var] = str(_terminal_cfg[_cfg_key])
+            _val = _terminal_cfg[_cfg_key]
+            if isinstance(_val, list):
+                os.environ[_env_var] = json.dumps(_val)
+            else:
+                os.environ[_env_var] = str(_val)
     _compression_cfg = _cfg.get("compression", {})
     if _compression_cfg and isinstance(_compression_cfg, dict):
         _compression_env_map = {
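The list branch exists because environment variables are strings; a minimal sketch of why str() was the wrong serializer for docker_volumes:

import json
import os

volumes = ["/home/user/projects:/workspace/projects"]

# str(volumes) would yield "['/home/user/projects:/workspace/projects']",
# which a consumer cannot reliably parse back; JSON round-trips cleanly.
os.environ["TERMINAL_DOCKER_VOLUMES"] = json.dumps(volumes)
print(json.loads(os.environ["TERMINAL_DOCKER_VOLUMES"]))
# ['/home/user/projects:/workspace/projects']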
@@ -1449,6 +1454,11 @@ class GatewayRunner:
         except Exception:
             current_provider = "openrouter"

+        # Detect custom endpoint: provider resolved to openrouter but a custom
+        # base URL is configured — the user set up a custom endpoint.
+        if current_provider == "openrouter" and os.getenv("OPENAI_BASE_URL", "").strip():
+            current_provider = "custom"
+
         if not args:
             provider_label = _PROVIDER_LABELS.get(current_provider, current_provider)
             lines = [
@@ -1575,6 +1585,10 @@ class GatewayRunner:
         except Exception:
             current_provider = "openrouter"

+        # Detect custom endpoint
+        if current_provider == "openrouter" and os.getenv("OPENAI_BASE_URL", "").strip():
+            current_provider = "custom"
+
         current_label = _PROVIDER_LABELS.get(current_provider, current_provider)

         lines = [
@@ -47,7 +47,7 @@ def _fetch_models_from_api(access_token: str) -> List[str]:
         if item.get("supported_in_api") is False:
             continue
         visibility = item.get("visibility", "")
-        if isinstance(visibility, str) and visibility.strip().lower() == "hide":
+        if isinstance(visibility, str) and visibility.strip().lower() == "hidden":
             continue
         priority = item.get("priority")
         rank = int(priority) if isinstance(priority, (int, float)) else 10_000
@@ -77,6 +77,10 @@ DEFAULT_CONFIG = {
         "container_memory": 5120,      # MB (default 5GB)
         "container_disk": 51200,       # MB (default 50GB)
         "container_persistent": True,  # Persist filesystem across sessions
+        # Docker volume mounts — share host directories with the container.
+        # Each entry is "host_path:container_path" (standard Docker -v syntax).
+        # Example: ["/home/user/projects:/workspace/projects", "/data:/data"]
+        "docker_volumes": [],
     },

     "browser": {
@@ -401,14 +405,18 @@ OPTIONAL_ENV_VARS = {
         "category": "messaging",
     },
     "SLACK_BOT_TOKEN": {
-        "description": "Slack bot integration",
+        "description": "Slack bot token (xoxb-). Get from OAuth & Permissions after installing your app. "
+                       "Required scopes: chat:write, app_mentions:read, channels:history, groups:history, "
+                       "im:history, im:read, im:write, users:read, files:write",
         "prompt": "Slack Bot Token (xoxb-...)",
         "url": "https://api.slack.com/apps",
         "password": True,
         "category": "messaging",
     },
     "SLACK_APP_TOKEN": {
-        "description": "Slack Socket Mode connection",
+        "description": "Slack app-level token (xapp-) for Socket Mode. Get from Basic Information → "
+                       "App-Level Tokens. Also ensure Event Subscriptions include: message.im, "
+                       "message.channels, message.groups, app_mention",
         "prompt": "Slack App Token (xapp-...)",
         "url": "https://api.slack.com/apps",
         "password": True,
@@ -482,14 +482,19 @@ _PLATFORMS = [
         "token_var": "SLACK_BOT_TOKEN",
         "setup_instructions": [
             "1. Go to https://api.slack.com/apps → Create New App → From Scratch",
-            "2. Enable Socket Mode: App Settings → Socket Mode → Enable",
-            "3. Get Bot Token: OAuth & Permissions → Install to Workspace → copy xoxb-... token",
-            "4. Get App Token: Basic Information → App-Level Tokens → Generate",
-            "   Name it anything, add scope: connections:write → copy xapp-... token",
-            "5. Add bot scopes: OAuth & Permissions → Scopes → chat:write, im:history,",
-            "   im:read, im:write, channels:history, channels:read",
-            "6. Reinstall the app to your workspace after adding scopes",
+            "2. Enable Socket Mode: Settings → Socket Mode → Enable",
+            "   Create an App-Level Token with scope: connections:write → copy xapp-... token",
+            "3. Add Bot Token Scopes: Features → OAuth & Permissions → Scopes",
+            "   Required: chat:write, app_mentions:read, channels:history, channels:read,",
+            "   groups:history, im:history, im:read, im:write, users:read, files:write",
+            "4. Subscribe to Events: Features → Event Subscriptions → Enable",
+            "   Required events: message.im, message.channels, app_mention",
+            "   Optional: message.groups (for private channels)",
+            "   ⚠ Without message.channels the bot will ONLY work in DMs!",
+            "5. Install to Workspace: Settings → Install App → copy xoxb-... token",
+            "6. Reinstall the app after any scope or event changes",
+            "7. Find your user ID: click your profile → three dots → Copy member ID",
+            "8. Invite the bot to channels: /invite @YourBot",
         ],
         "vars": [
             {"name": "SLACK_BOT_TOKEN", "prompt": "Bot Token (xoxb-...)", "password": True,
@@ -761,9 +761,39 @@ def cmd_model(args):
         ("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
         ("minimax", "MiniMax (global direct API)"),
         ("minimax-cn", "MiniMax China (domestic direct API)"),
-        ("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
     ]

+    # Add user-defined custom providers from config.yaml
+    custom_providers_cfg = config.get("custom_providers") or []
+    _custom_provider_map = {}  # key → {name, base_url, api_key}
+    if isinstance(custom_providers_cfg, list):
+        for entry in custom_providers_cfg:
+            if not isinstance(entry, dict):
+                continue
+            name = entry.get("name", "").strip()
+            base_url = entry.get("base_url", "").strip()
+            if not name or not base_url:
+                continue
+            # Generate a stable key from the name
+            key = "custom:" + name.lower().replace(" ", "-")
+            short_url = base_url.replace("https://", "").replace("http://", "").rstrip("/")
+            saved_model = entry.get("model", "")
+            model_hint = f" — {saved_model}" if saved_model else ""
+            providers.append((key, f"{name} ({short_url}){model_hint}"))
+            _custom_provider_map[key] = {
+                "name": name,
+                "base_url": base_url,
+                "api_key": entry.get("api_key", ""),
+                "model": saved_model,
+            }
+
+    # Always add the manual custom endpoint option last
+    providers.append(("custom", "Custom endpoint (enter URL manually)"))
+
+    # Add removal option if there are saved custom providers
+    if _custom_provider_map:
+        providers.append(("remove-custom", "Remove a saved custom provider"))
+
     # Reorder so the active provider is at the top
     known_keys = {k for k, _ in providers}
     active_key = active if active in known_keys else "custom"
@@ -791,6 +821,10 @@ def cmd_model(args):
         _model_flow_openai_codex(config, current_model)
     elif selected_provider == "custom":
         _model_flow_custom(config)
+    elif selected_provider.startswith("custom:") and selected_provider in _custom_provider_map:
+        _model_flow_named_custom(config, _custom_provider_map[selected_provider])
+    elif selected_provider == "remove-custom":
+        _remove_custom_provider(config)
     elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
         _model_flow_api_key_provider(config, selected_provider, current_model)
@@ -1006,7 +1040,11 @@ def _model_flow_openai_codex(config, current_model=""):


 def _model_flow_custom(config):
-    """Custom endpoint: collect URL, API key, and model name."""
+    """Custom endpoint: collect URL, API key, and model name.
+
+    Automatically saves the endpoint to ``custom_providers`` in config.yaml
+    so it appears in the provider menu on subsequent runs.
+    """
     from hermes_cli.auth import _save_model_choice, deactivate_provider
     from hermes_cli.config import get_env_value, save_env_value, load_config, save_config

@@ -1038,6 +1076,8 @@ def _model_flow_custom(config):
         print(f"Invalid URL: {effective_url} (must start with http:// or https://)")
         return

+    effective_key = api_key or current_key
+
     if base_url:
         save_env_value("OPENAI_BASE_URL", base_url)
     if api_key:
@@ -1050,7 +1090,7 @@ def _model_flow_custom(config):
         cfg = load_config()
         model = cfg.get("model")
         if isinstance(model, dict):
-            model["provider"] = "auto"
+            model["provider"] = "custom"
             model["base_url"] = effective_url
         save_config(cfg)
         deactivate_provider()
@@ -1061,6 +1101,223 @@ def _model_flow_custom(config):
         deactivate_provider()
     print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")

+    # Auto-save to custom_providers so it appears in the menu next time
+    _save_custom_provider(effective_url, effective_key, model_name or "")
+
+
+def _save_custom_provider(base_url, api_key="", model=""):
+    """Save a custom endpoint to custom_providers in config.yaml.
+
+    Deduplicates by base_url — if the URL already exists, updates the
+    model name but doesn't add a duplicate entry.
+    Auto-generates a display name from the URL hostname.
+    """
+    from hermes_cli.config import load_config, save_config
+
+    cfg = load_config()
+    providers = cfg.get("custom_providers") or []
+    if not isinstance(providers, list):
+        providers = []
+
+    # Check if this URL is already saved — update model if so
+    for entry in providers:
+        if isinstance(entry, dict) and entry.get("base_url", "").rstrip("/") == base_url.rstrip("/"):
+            if model and entry.get("model") != model:
+                entry["model"] = model
+            cfg["custom_providers"] = providers
+            save_config(cfg)
+            return  # already saved, updated model if needed
+
+    # Auto-generate a name from the URL
+    import re
+    clean = base_url.replace("https://", "").replace("http://", "").rstrip("/")
+    # Remove /v1 suffix for cleaner names
+    clean = re.sub(r"/v1/?$", "", clean)
+    # Use hostname:port as the name
+    name = clean.split("/")[0]
+    # Capitalize for readability
+    if "localhost" in name or "127.0.0.1" in name:
+        name = f"Local ({name})"
+    elif "runpod" in name.lower():
+        name = f"RunPod ({name})"
+    else:
+        name = name.capitalize()
+
+    entry = {"name": name, "base_url": base_url}
+    if api_key:
+        entry["api_key"] = api_key
+    if model:
+        entry["model"] = model
+
+    providers.append(entry)
+    cfg["custom_providers"] = providers
+    save_config(cfg)
+    print(f" 💾 Saved to custom providers as \"{name}\" (edit in config.yaml)")
+
+
+def _remove_custom_provider(config):
+    """Let the user remove a saved custom provider from config.yaml."""
+    from hermes_cli.config import load_config, save_config
+
+    cfg = load_config()
+    providers = cfg.get("custom_providers") or []
+    if not isinstance(providers, list) or not providers:
+        print("No custom providers configured.")
+        return
+
+    print("Remove a custom provider:\n")
+
+    choices = []
+    for entry in providers:
+        if isinstance(entry, dict):
+            name = entry.get("name", "unnamed")
+            url = entry.get("base_url", "")
+            short_url = url.replace("https://", "").replace("http://", "").rstrip("/")
+            choices.append(f"{name} ({short_url})")
+        else:
+            choices.append(str(entry))
+    choices.append("Cancel")
+
+    try:
+        from simple_term_menu import TerminalMenu
+        menu = TerminalMenu(
+            [f" {c}" for c in choices], cursor_index=0,
+            menu_cursor="-> ", menu_cursor_style=("fg_red", "bold"),
+            menu_highlight_style=("fg_red",),
+            cycle_cursor=True, clear_screen=False,
+            title="Select provider to remove:",
+        )
+        idx = menu.show()
+        print()
+    except (ImportError, NotImplementedError):
+        for i, c in enumerate(choices, 1):
+            print(f" {i}. {c}")
+        print()
+        try:
+            val = input(f"Choice [1-{len(choices)}]: ").strip()
+            idx = int(val) - 1 if val else None
+        except (ValueError, KeyboardInterrupt, EOFError):
+            idx = None
+
+    if idx is None or idx >= len(providers):
+        print("No change.")
+        return
+
+    removed = providers.pop(idx)
+    cfg["custom_providers"] = providers
+    save_config(cfg)
+    removed_name = removed.get("name", "unnamed") if isinstance(removed, dict) else str(removed)
+    print(f"✅ Removed \"{removed_name}\" from custom providers.")
+
+
+def _model_flow_named_custom(config, provider_info):
+    """Handle a named custom provider from config.yaml custom_providers list.
+
+    If the entry has a saved model name, activates it immediately.
+    Otherwise probes the endpoint's /models API to let the user pick one.
+    """
+    from hermes_cli.auth import _save_model_choice, deactivate_provider
+    from hermes_cli.config import save_env_value, load_config, save_config
+    from hermes_cli.models import fetch_api_models
+
+    name = provider_info["name"]
+    base_url = provider_info["base_url"]
+    api_key = provider_info.get("api_key", "")
+    saved_model = provider_info.get("model", "")
+
+    # If a model is saved, just activate immediately — no probing needed
+    if saved_model:
+        save_env_value("OPENAI_BASE_URL", base_url)
+        if api_key:
+            save_env_value("OPENAI_API_KEY", api_key)
+        _save_model_choice(saved_model)
+
+        cfg = load_config()
+        model = cfg.get("model")
+        if isinstance(model, dict):
+            model["provider"] = "custom"
+            model["base_url"] = base_url
+        save_config(cfg)
+        deactivate_provider()
+
+        print(f"✅ Switched to: {saved_model}")
+        print(f"   Provider: {name} ({base_url})")
+        return
+
+    # No saved model — probe endpoint and let user pick
+    print(f" Provider: {name}")
+    print(f" URL: {base_url}")
+    print()
+    print("No model saved for this provider. Fetching available models...")
+    models = fetch_api_models(api_key, base_url, timeout=8.0)
+
+    if models:
+        print(f"Found {len(models)} model(s):\n")
+        try:
+            from simple_term_menu import TerminalMenu
+            menu_items = [f" {m}" for m in models] + [" Cancel"]
+            menu = TerminalMenu(
+                menu_items, cursor_index=0,
+                menu_cursor="-> ", menu_cursor_style=("fg_green", "bold"),
+                menu_highlight_style=("fg_green",),
+                cycle_cursor=True, clear_screen=False,
+                title=f"Select model from {name}:",
+            )
+            idx = menu.show()
+            print()
+            if idx is None or idx >= len(models):
+                print("Cancelled.")
+                return
+            model_name = models[idx]
+        except (ImportError, NotImplementedError):
+            for i, m in enumerate(models, 1):
+                print(f" {i}. {m}")
+            print(f" {len(models) + 1}. Cancel")
+            print()
+            try:
+                val = input(f"Choice [1-{len(models) + 1}]: ").strip()
+                if not val:
+                    print("Cancelled.")
+                    return
+                idx = int(val) - 1
+                if idx < 0 or idx >= len(models):
+                    print("Cancelled.")
+                    return
+                model_name = models[idx]
+            except (ValueError, KeyboardInterrupt, EOFError):
+                print("\nCancelled.")
+                return
+    else:
+        print("Could not fetch models from endpoint. Enter model name manually.")
+        try:
+            model_name = input("Model name: ").strip()
+        except (KeyboardInterrupt, EOFError):
+            print("\nCancelled.")
+            return
+        if not model_name:
+            print("No model specified. Cancelled.")
+            return
+
+    # Activate and save the model to the custom_providers entry
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
_save_model_choice(model_name)
|
||||
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if isinstance(model, dict):
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
# Save model name to the custom_providers entry for next time
|
||||
_save_custom_provider(base_url, api_key, model_name)
|
||||
|
||||
print(f"\n✅ Model set to: {model_name}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
|
||||
|
||||
# Curated model lists for direct API-key providers
|
||||
_PROVIDER_MODELS = {
|
||||
|
||||
@@ -63,7 +63,7 @@ _PROVIDER_LABELS = {
    "kimi-coding": "Kimi / Moonshot",
    "minimax": "MiniMax",
    "minimax-cn": "MiniMax (China)",
    "custom": "custom endpoint",
    "custom": "Custom endpoint",
}

_PROVIDER_ALIASES = {

@@ -632,6 +632,29 @@ def setup_model_provider(config: dict):
        save_env_value("OPENAI_BASE_URL", "")
        save_env_value("OPENAI_API_KEY", "")

        # Update config.yaml and deactivate any OAuth provider so the
        # resolver doesn't keep returning the old provider (e.g. Codex).
        try:
            from hermes_cli.auth import deactivate_provider
            deactivate_provider()
        except Exception:
            pass
        import yaml
        config_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
        try:
            disk_cfg = {}
            if config_path.exists():
                disk_cfg = yaml.safe_load(config_path.read_text()) or {}
            model_section = disk_cfg.get("model", {})
            if isinstance(model_section, str):
                model_section = {"default": model_section}
            model_section["provider"] = "openrouter"
            model_section.pop("base_url", None)  # OpenRouter uses default URL
            disk_cfg["model"] = model_section
            config_path.write_text(yaml.safe_dump(disk_cfg, sort_keys=False))
        except Exception as e:
            logger.debug("Could not save provider to config.yaml: %s", e)

    elif provider_idx == 3:  # Custom endpoint
        selected_provider = "custom"
        print()

@@ -659,6 +682,28 @@ def setup_model_provider(config: dict):
        if model_name:
            config['model'] = model_name
            save_env_value("LLM_MODEL", model_name)

        # Save provider and base_url to config.yaml so the gateway and CLI
        # both resolve the correct provider without relying on env-var heuristics.
        if base_url:
            import yaml
            config_path = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
            try:
                disk_cfg = {}
                if config_path.exists():
                    disk_cfg = yaml.safe_load(config_path.read_text()) or {}
                model_section = disk_cfg.get("model", {})
                if isinstance(model_section, str):
                    model_section = {"default": model_section}
                model_section["provider"] = "custom"
                model_section["base_url"] = base_url.rstrip("/")
                if model_name:
                    model_section["default"] = model_name
                disk_cfg["model"] = model_section
                config_path.write_text(yaml.safe_dump(disk_cfg, sort_keys=False))
            except Exception as e:
                logger.debug("Could not save provider to config.yaml: %s", e)

        print_success("Custom endpoint configured")

    elif provider_idx == 4:  # Z.AI / GLM

@@ -1527,10 +1572,22 @@ def setup_gateway(config: dict):

    if not existing_slack and prompt_yes_no("Set up Slack bot?", False):
        print_info("Steps to create a Slack app:")
        print_info(" 1. Go to https://api.slack.com/apps → Create New App")
        print_info(" 2. Enable Socket Mode: App Settings → Socket Mode → Enable")
        print_info(" 3. Bot Token: OAuth & Permissions → Install to Workspace")
        print_info(" 4. App Token: Basic Information → App-Level Tokens → Generate")
        print_info(" 1. Go to https://api.slack.com/apps → Create New App (from scratch)")
        print_info(" 2. Enable Socket Mode: Settings → Socket Mode → Enable")
        print_info("    • Create an App-Level Token with 'connections:write' scope")
        print_info(" 3. Add Bot Token Scopes: Features → OAuth & Permissions")
        print_info("    Required scopes: chat:write, app_mentions:read,")
        print_info("    channels:history, channels:read, groups:history,")
        print_info("    im:history, im:read, im:write, users:read, files:write")
        print_info(" 4. Subscribe to Events: Features → Event Subscriptions → Enable")
        print_info("    Required events: message.im, message.channels,")
        print_info("    message.groups, app_mention")
        print_warning(" ⚠ Without message.channels/message.groups events,")
        print_warning("   the bot will ONLY work in DMs, not channels!")
        print_info(" 5. Install to Workspace: Settings → Install App")
        print_info(" 6. After installing, invite the bot to channels: /invite @YourBot")
        print()
        print_info(" Full guide: https://hermes-agent.ai/docs/user-guide/messaging/slack")
        print()
        bot_token = prompt("Slack Bot Token (xoxb-...)", password=True)
        if bot_token:

@@ -1542,7 +1599,7 @@ def setup_gateway(config: dict):

    print()
    print_info("🔒 Security: Restrict who can use your bot")
    print_info(" Find Slack user IDs in your profile or via the Slack API")
    print_info(" To find a Member ID: click a user's name → View full profile → ⋮ → Copy member ID")
    print()
    allowed_users = prompt("Allowed user IDs (comma-separated, leave empty for open access)")
    if allowed_users:

@@ -40,7 +40,7 @@ dependencies = [
[project.optional-dependencies]
modal = ["swe-rex[modal]>=1.4.0"]
daytona = ["daytona>=0.148.0"]
dev = ["pytest", "pytest-asyncio"]
dev = ["pytest", "pytest-asyncio", "mcp>=1.2.0"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]

run_agent.py
@@ -3834,6 +3834,27 @@ class AIAgent:
        else:
            assistant_message = response.choices[0].message

        # Normalize content to string — some OpenAI-compatible servers
        # (llama-server, etc.) return content as a dict or list instead
        # of a plain string, which crashes downstream .strip() calls.
        if assistant_message.content is not None and not isinstance(assistant_message.content, str):
            raw = assistant_message.content
            if isinstance(raw, dict):
                assistant_message.content = raw.get("text", "") or raw.get("content", "") or json.dumps(raw)
            elif isinstance(raw, list):
                # Multimodal content list — extract text parts
                parts = []
                for part in raw:
                    if isinstance(part, str):
                        parts.append(part)
                    elif isinstance(part, dict) and part.get("type") == "text":
                        parts.append(part.get("text", ""))
                    elif isinstance(part, dict) and "text" in part:
                        parts.append(str(part["text"]))
                assistant_message.content = "\n".join(parts)
            else:
                assistant_message.content = str(raw)

        # Handle assistant response
        if assistant_message.content and not self.quiet_mode:
            print(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")

skills/creative/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Creative content generation — ASCII art, hand-drawn style diagrams, and visual design tools.
---

skills/leisure/find-nearby/SKILL.md (new file)
@@ -0,0 +1,69 @@
---
name: find-nearby
description: Find nearby places (restaurants, cafes, bars, pharmacies, etc.) using OpenStreetMap. Works with coordinates, addresses, cities, zip codes, or Telegram location pins. No API keys needed.
version: 1.0.0
metadata:
  hermes:
    tags: [location, maps, nearby, places, restaurants, local]
    related_skills: []
---

# Find Nearby — Local Place Discovery

Find restaurants, cafes, bars, pharmacies, and other places near any location. Uses OpenStreetMap (free, no API keys). Works with:

- **Coordinates** from Telegram location pins (latitude/longitude in conversation)
- **Addresses** ("near 123 Main St, Springfield")
- **Cities** ("restaurants in downtown Austin")
- **Zip codes** ("pharmacies near 90210")
- **Landmarks** ("cafes near Times Square")

## Quick Reference

```bash
# By coordinates (from Telegram location pin or user-provided)
python3 SKILL_DIR/scripts/find_nearby.py --lat <LAT> --lon <LON> --type restaurant --radius 1500

# By address, city, or landmark (auto-geocoded)
python3 SKILL_DIR/scripts/find_nearby.py --near "Times Square, New York" --type cafe

# Multiple place types
python3 SKILL_DIR/scripts/find_nearby.py --near "downtown austin" --type restaurant --type bar --limit 10

# JSON output
python3 SKILL_DIR/scripts/find_nearby.py --near "90210" --type pharmacy --json
```

### Parameters

| Flag | Description | Default |
|------|-------------|---------|
| `--lat`, `--lon` | Exact coordinates | — |
| `--near` | Address, city, zip, or landmark (geocoded) | — |
| `--type` | Place type (repeatable for multiple) | restaurant |
| `--radius` | Search radius in meters | 1500 |
| `--limit` | Max results | 15 |
| `--json` | Machine-readable JSON output | off |

### Common Place Types

`restaurant`, `cafe`, `bar`, `pub`, `fast_food`, `pharmacy`, `hospital`, `bank`, `atm`, `fuel`, `parking`, `supermarket`, `convenience`, `hotel`

## Workflow

1. **Get the location.** Look for coordinates (`latitude: ... / longitude: ...`) from a Telegram pin, or ask the user for an address/city/zip.

2. **Ask for preferences** (only if not already stated): place type, how far they're willing to go, any specifics (cuisine, "open now", etc.).

3. **Run the script** with appropriate flags. Use `--json` if you need to process results programmatically (see the sketch after this list).

4. **Present results** with names, distances, and Google Maps links. If the user asked about hours or "open now," check the `hours` field in results — if missing or unclear, verify with `web_search`.

5. **For directions**, use the `directions_url` from results, or construct: `https://www.google.com/maps/dir/?api=1&origin=<LAT>,<LON>&destination=<LAT>,<LON>`
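
For programmatic use, a minimal sketch of consuming the `--json` output from Python (the script path here is relative to the skill directory; adjust it for your install):

```python
import json
import subprocess

# Run the skill script with --json and parse its structured output.
out = subprocess.run(
    ["python3", "scripts/find_nearby.py", "--near", "90210",
     "--type", "pharmacy", "--json"],
    capture_output=True, text=True, check=True,
).stdout
data = json.loads(out)
for place in data["results"][:5]:
    print(f"{place['name']} - {place['distance_m']} m - {place['maps_url']}")
```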

## Tips

- If results are sparse, widen the radius (e.g., 1500 → 3000 m)
- For "open now" requests: check the `hours` field in results, and cross-reference with `web_search` for accuracy, since OSM hours aren't always complete
- Zip codes alone can be ambiguous globally — prompt the user for country/state if results look wrong
- The script uses OpenStreetMap data, which is community-maintained; coverage varies by region

skills/leisure/find-nearby/scripts/find_nearby.py (new file)
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""Find nearby places using OpenStreetMap (Overpass + Nominatim). No API keys needed.

Usage:
    # By coordinates
    python find_nearby.py --lat 36.17 --lon -115.14 --type restaurant --radius 1500

    # By address/city/zip (auto-geocoded)
    python find_nearby.py --near "Times Square, New York" --type cafe --radius 1000
    python find_nearby.py --near "90210" --type pharmacy

    # Multiple types
    python find_nearby.py --lat 36.17 --lon -115.14 --type restaurant --type bar

    # JSON output for programmatic use
    python find_nearby.py --near "downtown las vegas" --type restaurant --json
"""

import argparse
import json
import math
import sys
import urllib.parse
import urllib.request
from typing import Any

OVERPASS_URLS = [
    "https://overpass-api.de/api/interpreter",
    "https://overpass.kumi.systems/api/interpreter",
]
NOMINATIM_URL = "https://nominatim.openstreetmap.org/search"
USER_AGENT = "HermesAgent/1.0 (find-nearby skill)"
TIMEOUT = 15


def _http_get(url: str) -> Any:
    req = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
    with urllib.request.urlopen(req, timeout=TIMEOUT) as r:
        return json.loads(r.read())


def _http_post(url: str, data: str) -> Any:
    req = urllib.request.Request(
        url, data=data.encode(), headers={"User-Agent": USER_AGENT}
    )
    with urllib.request.urlopen(req, timeout=TIMEOUT) as r:
        return json.loads(r.read())


def haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Distance in meters between two coordinates."""
    R = 6_371_000
    rlat1, rlat2 = math.radians(lat1), math.radians(lat2)
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
    return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))


def geocode(query: str) -> tuple[float, float]:
    """Convert address/city/zip to coordinates via Nominatim."""
    params = urllib.parse.urlencode({"q": query, "format": "json", "limit": 1})
    results = _http_get(f"{NOMINATIM_URL}?{params}")
    if not results:
        print(f"Error: Could not geocode '{query}'. Try a more specific address.", file=sys.stderr)
        sys.exit(1)
    return float(results[0]["lat"]), float(results[0]["lon"])


def find_nearby(lat: float, lon: float, types: list[str], radius: int = 1500, limit: int = 15) -> list[dict]:
    """Query Overpass for nearby amenities."""
    # Build Overpass QL query
    type_filters = "".join(
        f'nwr["amenity"="{t}"](around:{radius},{lat},{lon});' for t in types
    )
    query = f"[out:json][timeout:{TIMEOUT}];({type_filters});out center tags;"

    # Try each Overpass server
    data = None
    for url in OVERPASS_URLS:
        try:
            data = _http_post(url, f"data={urllib.parse.quote(query)}")
            break
        except Exception:
            continue

    if not data:
        return []

    # Parse results
    places = []
    for el in data.get("elements", []):
        tags = el.get("tags", {})
        name = tags.get("name")
        if not name:
            continue

        # Get coordinates (nodes have lat/lon directly, ways/relations use center)
        plat = el.get("lat", (el.get("center") or {}).get("lat"))
        plon = el.get("lon", (el.get("center") or {}).get("lon"))
        if plat is None or plon is None:  # 0.0 is a valid coordinate; test None explicitly
            continue

        dist = haversine(lat, lon, plat, plon)

        place = {
            "name": name,
            "type": tags.get("amenity", ""),
            "distance_m": round(dist),
            "lat": plat,
            "lon": plon,
            "maps_url": f"https://www.google.com/maps/search/?api=1&query={plat},{plon}",
            "directions_url": f"https://www.google.com/maps/dir/?api=1&origin={lat},{lon}&destination={plat},{plon}",
        }

        # Add useful optional fields
        if tags.get("cuisine"):
            place["cuisine"] = tags["cuisine"]
        if tags.get("opening_hours"):
            place["hours"] = tags["opening_hours"]
        if tags.get("phone"):
            place["phone"] = tags["phone"]
        if tags.get("website"):
            place["website"] = tags["website"]
        if tags.get("addr:street"):
            addr_parts = [tags.get("addr:housenumber", ""), tags.get("addr:street", "")]
            if tags.get("addr:city"):
                addr_parts.append(tags["addr:city"])
            place["address"] = " ".join(p for p in addr_parts if p)

        places.append(place)

    # Sort by distance, limit results
    places.sort(key=lambda p: p["distance_m"])
    return places[:limit]


def main():
    parser = argparse.ArgumentParser(description="Find nearby places via OpenStreetMap")
    parser.add_argument("--lat", type=float, help="Latitude")
    parser.add_argument("--lon", type=float, help="Longitude")
    parser.add_argument("--near", type=str, help="Address, city, or zip code (geocoded automatically)")
    parser.add_argument("--type", action="append", dest="types", default=[], help="Place type (restaurant, cafe, bar, pharmacy, etc.)")
    parser.add_argument("--radius", type=int, default=1500, help="Search radius in meters (default: 1500)")
    parser.add_argument("--limit", type=int, default=15, help="Max results (default: 15)")
    parser.add_argument("--json", action="store_true", dest="json_output", help="Output as JSON")
    args = parser.parse_args()

    # Resolve coordinates
    if args.near:
        lat, lon = geocode(args.near)
    elif args.lat is not None and args.lon is not None:
        lat, lon = args.lat, args.lon
    else:
        print("Error: Provide --lat/--lon or --near", file=sys.stderr)
        sys.exit(1)

    if not args.types:
        args.types = ["restaurant"]

    places = find_nearby(lat, lon, args.types, args.radius, args.limit)

    if args.json_output:
        print(json.dumps({"origin": {"lat": lat, "lon": lon}, "results": places, "count": len(places)}, indent=2))
    else:
        if not places:
            print(f"No {'/'.join(args.types)} found within {args.radius}m")
            return
        print(f"Found {len(places)} places within {args.radius}m:\n")
        for i, p in enumerate(places, 1):
            dist_str = f"{p['distance_m']}m" if p["distance_m"] < 1000 else f"{p['distance_m']/1000:.1f}km"
            print(f" {i}. {p['name']} ({p['type']}) — {dist_str}")
            if p.get("cuisine"):
                print(f"    Cuisine: {p['cuisine']}")
            if p.get("hours"):
                print(f"    Hours: {p['hours']}")
            if p.get("address"):
                print(f"    Address: {p['address']}")
            print(f"    Map: {p['maps_url']}")
            print()


if __name__ == "__main__":
    main()

@@ -321,6 +321,32 @@ mcp_servers:

All tools from all servers are registered and available simultaneously. Each server's tools are prefixed with its name to avoid collisions.

## Sampling (Server-Initiated LLM Requests)

Hermes supports MCP's `sampling/createMessage` capability — MCP servers can request LLM completions through the agent during tool execution. This enables agent-in-the-loop workflows (data analysis, content generation, decision-making).

Sampling is **enabled by default**. Configure per server:

```yaml
mcp_servers:
  my_server:
    command: "npx"
    args: ["-y", "my-mcp-server"]
    sampling:
      enabled: true            # default: true
      model: "gemini-3-flash"  # model override (optional)
      max_tokens_cap: 4096     # max tokens per request
      timeout: 30              # LLM call timeout (seconds)
      max_rpm: 10              # max requests per minute
      allowed_models: []       # model whitelist (empty = all)
      max_tool_rounds: 5       # tool loop limit (0 = disable)
      log_level: "info"        # audit verbosity
```

Servers can also include `tools` in sampling requests for multi-turn tool-augmented workflows. The `max_tool_rounds` config prevents infinite tool loops. Per-server audit metrics (requests, errors, tokens, tool use count) are tracked via `get_mcp_status()`.

Disable sampling for untrusted servers with `sampling: { enabled: false }`.

## Notes

- MCP tools are called synchronously from the agent's perspective but run asynchronously on a dedicated background event loop

@@ -1 +1,3 @@
Media content extraction and transformation tools — YouTube transcripts, audio, video processing.
---
description: Skills for working with media content — YouTube transcripts, GIF search, music generation, and audio visualization.
---

skills/mlops/cloud/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: GPU cloud providers and serverless compute platforms for ML workloads.
---

skills/mlops/evaluation/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Model evaluation benchmarks, experiment tracking, data curation, tokenizers, and interpretability tools.
---

skills/mlops/inference/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Model serving, quantization (GGUF/GPTQ), structured output, inference optimization, and model surgery tools for deploying and running LLMs.
---

skills/mlops/inference/obliteratus/SKILL.md (new file)
@@ -0,0 +1,330 @@
---
name: obliteratus
description: Remove refusal behaviors from open-weight LLMs using OBLITERATUS — mechanistic interpretability techniques (diff-in-means, SVD, whitened SVD, LEACE, SAE decomposition, etc.) to excise guardrails while preserving reasoning. 9 CLI methods, 28 analysis modules, 116 model presets across 5 compute tiers, tournament evaluation, and telemetry-driven recommendations. Use when a user wants to uncensor, abliterate, or remove refusal from an LLM.
version: 2.0.0
author: Hermes Agent
license: MIT
dependencies: [obliteratus, torch, transformers, bitsandbytes, accelerate, safetensors]
metadata:
  hermes:
    tags: [Abliteration, Uncensoring, Refusal-Removal, LLM, Weight-Projection, SVD, Mechanistic-Interpretability, HuggingFace, Model-Surgery]
    related_skills: [vllm, gguf, huggingface-tokenizers]
---

# OBLITERATUS Skill

Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, LEACE concept erasure, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.

**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.

## When to Use This Skill

Trigger when the user:
- Wants to "uncensor" or "abliterate" an LLM
- Asks about removing refusal/guardrails from a model
- Wants to create an uncensored version of Llama, Qwen, Mistral, etc.
- Mentions "refusal removal", "abliteration", "weight projection"
- Wants to analyze how a model's refusal mechanism works
- References OBLITERATUS, abliterator, or refusal directions

## Step 1: Installation

Check if already installed:
```bash
obliteratus --version 2>/dev/null && echo "INSTALLED" || echo "NOT INSTALLED"
```

If not installed, clone and install from GitHub:
```bash
git clone https://github.com/elder-plinius/OBLITERATUS.git
cd OBLITERATUS
pip install -e .
# For Gradio web UI support:
# pip install -e ".[spaces]"
```

**IMPORTANT:** Confirm with the user before installing. This pulls in ~5-10 GB of dependencies (PyTorch, Transformers, bitsandbytes, etc.).

## Step 2: Check Hardware

Before anything else, check what GPU is available:
```bash
python3 -c "
import torch
if torch.cuda.is_available():
    gpu = torch.cuda.get_device_name(0)
    vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'GPU: {gpu}')
    print(f'VRAM: {vram:.1f} GB')
    if vram < 4: print('TIER: tiny (models under 1B)')
    elif vram < 8: print('TIER: small (models 1-4B)')
    elif vram < 16: print('TIER: medium (models 4-9B with 4bit quant)')
    elif vram < 32: print('TIER: large (models 8-32B with 4bit quant)')
    else: print('TIER: frontier (models 32B+)')
else:
    print('NO GPU - only tiny models (under 1B) on CPU')
"
```

### VRAM Requirements (with 4-bit quantization)

| VRAM | Max Model Size | Example Models |
|:---------|:----------------|:--------------------------------------------|
| CPU only | ~1B params | GPT-2, TinyLlama, SmolLM |
| 4-8 GB | ~4B params | Qwen2.5-1.5B, Phi-3.5 mini, Llama 3.2 3B |
| 8-16 GB | ~9B params | Llama 3.1 8B, Mistral 7B, Gemma 2 9B |
| 24 GB | ~32B params | Qwen3-32B, Llama 3.1 70B (tight), Command-R |
| 48 GB+ | ~72B+ params | Qwen2.5-72B, DeepSeek-R1 |
| Multi-GPU | 200B+ params | Llama 3.1 405B, DeepSeek-V3 (685B MoE) |

## Step 3: Browse Available Models & Get Recommendations

```bash
# Browse models by compute tier
obliteratus models --tier medium

# Get architecture info for a specific model
obliteratus info <model_name>

# Get telemetry-driven recommendation for best method & params
obliteratus recommend <model_name>
obliteratus recommend <model_name> --insights  # global cross-architecture rankings
```

## Step 4: Choose a Method

### Method Selection Guide

**Default / recommended for most cases: `advanced`.** It uses multi-direction SVD with norm-preserving projection and is well-tested.

| Situation | Recommended Method | Why |
|:----------------------------------|:-------------------|:-----------------------------------------|
| Default / most models | `advanced` | Multi-direction SVD, norm-preserving, reliable |
| Quick test / prototyping | `basic` | Fast, simple, good enough to evaluate |
| Dense model (Llama, Mistral) | `advanced` | Multi-direction, norm-preserving |
| MoE model (DeepSeek, Mixtral) | `nuclear` | Expert-granular, handles MoE complexity |
| Reasoning model (R1 distills) | `surgical` | CoT-aware, preserves chain-of-thought |
| Stubborn refusals persist | `aggressive` | Whitened SVD + head surgery + jailbreak |
| Want reversible changes | steering vectors | Inference-time and reversible; no weight edits (see Analysis section) |
| Maximum quality, time no object | `optimized` | Bayesian search for best parameters |
| Experimental auto-detection | `informed` | Auto-detects alignment type — experimental, may not always outperform advanced |

### 9 CLI Methods
- **basic** — Single refusal direction via diff-in-means. Fast (~5-10 min for an 8B model).
- **advanced** (DEFAULT, RECOMMENDED) — Multiple SVD directions, norm-preserving projection, 2 refinement passes. Medium speed (~10-20 min).
- **aggressive** — Whitened SVD + jailbreak-contrastive + attention head surgery. Higher risk of coherence damage.
- **spectral_cascade** — DCT frequency-domain decomposition. Research/novel approach.
- **informed** — Runs analysis DURING abliteration to auto-configure. Experimental — slower and less predictable than advanced.
- **surgical** — SAE features + neuron masking + head surgery + per-expert. Very slow (~1-2 hrs). Best for reasoning models.
- **optimized** — Bayesian hyperparameter search (Optuna TPE). Longest runtime but finds optimal parameters.
- **inverted** — Flips the refusal direction. The model becomes actively willing.
- **nuclear** — Maximum-force combo for stubborn MoE models. Expert-granular.

### Direction Extraction Methods (--direction-method flag)
- **diff_means** (default) — Simple difference-in-means between refused/complied activations. Robust.
- **svd** — Multi-direction SVD extraction. Better for complex alignment.
- **leace** — LEACE (LEAst-squares Concept Erasure). Optimal linear erasure.

### 4 Python-API-Only Methods
(NOT available via CLI — they require a Python import, which violates the AGPL boundary. Mention to the user only if they explicitly want to use OBLITERATUS as a library in their own AGPL project.)
- failspy, gabliteration, heretic, rdo

## Step 5: Run Abliteration

### Standard usage
```bash
# Default method (advanced) — recommended for most models
obliteratus obliterate <model_name> --method advanced --output-dir ./abliterated-models

# With 4-bit quantization (saves VRAM)
obliteratus obliterate <model_name> --method advanced --quantization 4bit --output-dir ./abliterated-models

# Large models (70B+) — conservative defaults
obliteratus obliterate <model_name> --method advanced --quantization 4bit --large-model --output-dir ./abliterated-models
```

### Fine-tuning parameters
```bash
obliteratus obliterate <model_name> \
    --method advanced \
    --direction-method diff_means \
    --n-directions 4 \
    --refinement-passes 2 \
    --regularization 0.1 \
    --quantization 4bit \
    --output-dir ./abliterated-models \
    --contribute  # opt-in telemetry for community research
```

### Key flags
| Flag | Description | Default |
|:-----|:------------|:--------|
| `--method` | Abliteration method | advanced |
| `--direction-method` | Direction extraction | diff_means |
| `--n-directions` | Number of refusal directions (1-32) | method-dependent |
| `--refinement-passes` | Iterative passes (1-5) | 2 |
| `--regularization` | Regularization strength (0.0-1.0) | 0.1 |
| `--quantization` | Load in 4bit or 8bit | none (full precision) |
| `--large-model` | Conservative defaults for 120B+ | false |
| `--output-dir` | Where to save the abliterated model | ./obliterated_model |
| `--contribute` | Share anonymized results for research | false |
| `--verify-sample-size` | Number of test prompts for refusal check | 20 |
| `--dtype` | Model dtype (float16, bfloat16) | auto |

### Other execution modes
```bash
# Interactive guided mode (hardware → model → preset)
obliteratus interactive

# Web UI (Gradio)
obliteratus ui --port 7860

# Run a full ablation study from YAML config
obliteratus run config.yaml --preset quick

# Tournament: pit all methods against each other
obliteratus tourney <model_name>
```

## Step 6: Verify Results

After abliteration, check the output metrics:

| Metric | Good Value | Warning |
|:-------|:-----------|:--------|
| Refusal rate | < 5% (ideally ~0%) | > 10% means refusals persist |
| Perplexity change | < 10% increase | > 15% means coherence damage |
| KL divergence | < 0.1 | > 0.5 means significant distribution shift |
| Coherence | High / passes qualitative check | Degraded responses, repetition |
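
To spot-check the perplexity row yourself, here is a minimal sketch with `transformers` (the model paths are placeholders, and a real check should average over many texts):

```python
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def perplexity(model_dir: str, text: str) -> float:
    """exp(mean cross-entropy per token) of `text` under the model."""
    tok = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    ids = tok(text, return_tensors="pt").input_ids
    with torch.no_grad():
        loss = model(ids, labels=ids).loss  # mean next-token NLL
    return math.exp(loss.item())


sample = "The quick brown fox jumps over the lazy dog. " * 20
base = perplexity("<original-model>", sample)             # placeholder path
abl = perplexity("./abliterated-models/<model>", sample)  # placeholder path
print(f"Perplexity change: {100 * (abl - base) / base:+.1f}%")
```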

### If refusals persist (> 10%)
1. Try the `aggressive` method
2. Increase `--n-directions` (e.g., 8 or 16)
3. Add `--refinement-passes 3`
4. Try `--direction-method svd` instead of diff_means

### If coherence is damaged (perplexity > 15% increase)
1. Reduce `--n-directions` (try 2)
2. Increase `--regularization` (try 0.3)
3. Reduce `--refinement-passes` to 1
4. Try the `basic` method (gentler)

## Step 7: Use the Abliterated Model

The output is a standard HuggingFace model directory.

```bash
# Test locally with transformers
python3 -c "
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained('./abliterated-models/<model>')
tokenizer = AutoTokenizer.from_pretrained('./abliterated-models/<model>')
inputs = tokenizer('How do I pick a lock?', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
"

# Upload to HuggingFace Hub
huggingface-cli upload <username>/<model-name>-abliterated ./abliterated-models/<model>

# Serve with vLLM
vllm serve ./abliterated-models/<model>
```

## CLI Command Reference

| Command | Description |
|:--------|:------------|
| `obliteratus obliterate` | Main abliteration command |
| `obliteratus info <model>` | Print model architecture details |
| `obliteratus models --tier <tier>` | Browse curated models by compute tier |
| `obliteratus recommend <model>` | Telemetry-driven method/param suggestion |
| `obliteratus interactive` | Guided setup wizard |
| `obliteratus tourney <model>` | Tournament: all methods head-to-head |
| `obliteratus run <config.yaml>` | Execute ablation study from YAML |
| `obliteratus strategies` | List all registered ablation strategies |
| `obliteratus report <results.json>` | Regenerate visual reports |
| `obliteratus ui` | Launch Gradio web interface |
| `obliteratus aggregate` | Summarize community telemetry data |

## Analysis Modules

OBLITERATUS includes 28 analysis modules for mechanistic interpretability.
See `skill_view(name="obliteratus", file_path="references/analysis-modules.md")` for the full reference.

### Quick analysis commands
```bash
# Run specific analysis modules
obliteratus run analysis-config.yaml --preset quick

# Key modules to run first:
# - alignment_imprint: Fingerprint DPO/RLHF/CAI/SFT alignment method
# - concept_geometry: Single direction vs polyhedral cone
# - logit_lens: Which layer decides to refuse
# - anti_ouroboros: Self-repair risk score
# - causal_tracing: Causally necessary components
```

### Steering Vectors (Reversible Alternative)
Instead of permanent weight modification, use inference-time steering:
```python
# Python API only — for the user's own (AGPL-compatible) projects
from obliteratus.analysis.steering_vectors import SteeringVectorFactory, SteeringHookManager
```

## Ablation Strategies

Beyond direction-based abliteration, OBLITERATUS includes structural ablation strategies:
- **Embedding Ablation** — Target embedding layer components
- **FFN Ablation** — Feed-forward network block removal
- **Head Pruning** — Attention head pruning
- **Layer Removal** — Full layer removal

List all available: `obliteratus strategies`

## Evaluation

OBLITERATUS includes built-in evaluation tools:
- Refusal rate benchmarking
- Perplexity comparison (before/after)
- LM Eval Harness integration for academic benchmarks
- Head-to-head competitor comparison
- Baseline performance tracking

## Platform Support

- **CUDA** — Full support (NVIDIA GPUs)
- **Apple Silicon (MLX)** — Supported via MLX backend
- **CPU** — Supported for tiny models (< 1B params)

## YAML Config Templates

Load templates for reproducible runs via `skill_view`:
- `templates/abliteration-config.yaml` — Standard single-model config
- `templates/analysis-study.yaml` — Pre-abliteration analysis study
- `templates/batch-abliteration.yaml` — Multi-model batch processing

## Telemetry

OBLITERATUS can optionally contribute anonymized run data to a global research dataset.
Enable with the `--contribute` flag. No personal data is collected — only model name, method, and metrics.

## Common Pitfalls

1. **Don't use `informed` as the default** — it's experimental and slower. Use `advanced` for reliable results.
2. **Models under ~1B respond poorly to abliteration** — their refusal behaviors are shallow and fragmented, making clean direction extraction difficult. Expect partial results (20-40% remaining refusal). Models 3B+ have cleaner refusal directions and respond much better (often 0% refusal with `advanced`).
3. **`aggressive` can make things worse** — on small models it can damage coherence and actually increase the refusal rate. Only use it if `advanced` leaves > 10% refusals on a 3B+ model.
4. **Always check perplexity** — if it spikes > 15%, the model is damaged. Reduce aggressiveness.
5. **MoE models need special handling** — use the `nuclear` method for Mixtral, DeepSeek-MoE, etc.
6. **Quantized models can't be re-quantized** — abliterate the full-precision model, then quantize the output.
7. **VRAM estimation is approximate** — 4-bit quant helps, but peak usage can spike during extraction.
8. **Reasoning models are sensitive** — use `surgical` for R1 distills to preserve chain-of-thought.
9. **Check `obliteratus recommend`** — telemetry data may have better parameters than the defaults.
10. **AGPL license** — never `import obliteratus` in MIT/Apache projects. CLI invocation only.
11. **Large models (70B+)** — always use the `--large-model` flag for conservative defaults.
12. **Spectral certification RED is common** — the spectral check often flags "incomplete" even when the practical refusal rate is 0%. Check the actual refusal rate rather than relying on spectral certification alone.

## Complementary Skills

- **vllm** — Serve abliterated models with high throughput
- **gguf** — Convert abliterated models to GGUF for llama.cpp
- **huggingface-tokenizers** — Work with model tokenizers
skills/mlops/inference/obliteratus/references/analysis-modules.md (new file)
@@ -0,0 +1,166 @@
# OBLITERATUS Analysis Modules — Reference

OBLITERATUS includes 28 analysis modules for mechanistic interpretability of refusal in LLMs.
These modules help understand how and where refusal behaviors are encoded before performing abliteration.

---

## Core Analysis (Run These First)

### 1. Alignment Imprint Detection (`alignment_imprint.py`)
Fingerprints whether a model was trained via DPO, RLHF, CAI, or SFT.
This determines which extraction strategy will work best.

### 2. Concept Cone Geometry (`concept_geometry.py`)
Determines if refusal is a single linear direction or a polyhedral cone
(a set of multiple mechanisms). Single-direction models respond well to `basic`;
polyhedral models need `advanced` or `surgical`.

### 3. Refusal Logit Lens (`logit_lens.py`)
Identifies the specific layer where a model "decides" to refuse by decoding
intermediate layer representations into token space.

### 4. Ouroboros Detection (`anti_ouroboros.py`)
Identifies if a model attempts to "self-repair" refusal behaviors after
excision. Reports a risk score (0-1). High scores mean additional refinement
passes are needed.

### 5. Causal Tracing (`causal_tracing.py`)
Identifies which components (layers, heads, MLPs) are causally necessary
for refusal behavior using activation patching.

---

## Geometric Analysis

### 6. Cross-Layer Alignment (`cross_layer.py`)
Measures how refusal directions align across different layers. High alignment
means the refusal signal is consistent; low alignment suggests layer-specific
mechanisms.

### 7. Residual Stream Decomposition (`residual_stream.py`)
Decomposes the residual stream into attention and MLP contributions to
understand which component type contributes more to refusal.

### 8. Riemannian Manifold Geometry (`riemannian_manifold.py`)
Analyzes the curvature and geometry of the weight manifold near refusal
directions. Informs how aggressively projections can be applied without
damaging the manifold structure.

### 9. Whitened SVD (`whitened_svd.py`)
Covariance-normalized SVD extraction that separates guardrail signals from
natural activation variance. More precise than standard SVD for models with
high activation variance.

### 10. Concept Cone Geometry (extended)
Maps the full polyhedral structure of refusal, including cone angles,
face counts, and intersection patterns.

---

## Probing & Classification

### 11. Activation Probing (`activation_probing.py`)
Post-excision verification — probes for residual refusal concepts after
abliteration to ensure complete removal.

### 12. Probing Classifiers (`probing_classifiers.py`)
Trains linear classifiers to detect refusal in activations. Used both
before (to verify refusal exists) and after (to verify it's gone).

### 13. Activation Patching (`activation_patching.py`)
Interchange interventions — swaps activations between refused and complied
runs to identify causal components.

### 14. Tuned Lens (`tuned_lens.py`)
Trained version of the logit lens that provides more accurate per-layer
decoding by learning affine transformations for each layer.

### 15. Multi-Token Position Analysis (`multi_token_position.py`)
Analyzes refusal signals across multiple token positions, not just the
last token. Important for models that distribute refusal across the sequence.

---

## Abliteration & Manipulation

### 16. SAE-Based Abliteration (`sae_abliteration.py`)
Uses Sparse Autoencoder features to identify and remove specific refusal
features. More surgical than direction-based methods.

### 17. Steering Vectors (`steering_vectors.py`)
Creates and applies inference-time steering vectors for reversible refusal
modification. Includes `SteeringVectorFactory` and `SteeringHookManager`.

### 18. LEACE Concept Erasure (`leace.py`)
LEAst-squares Concept Erasure — mathematically optimal linear
concept removal. Available as both an analysis module and a direction extraction method.

### 19. Sparse Surgery (`sparse_surgery.py`)
High-precision weight modification targeting individual neurons and
weight matrix entries rather than full directions.

### 20. Conditional Abliteration (`conditional_abliteration.py`)
Targeted removal that only affects specific refusal categories while
preserving others (e.g., remove weapons refusal but keep CSAM refusal).

---

## Transfer & Robustness

### 21. Cross-Model Transfer (`cross_model_transfer.py`)
Tests whether refusal directions extracted from one model transfer to
another architecture. Measures the universality of guardrail directions.

### 22. Defense Robustness (`defense_robustness.py`)
Evaluates how robust the abliteration is against various defense mechanisms
and re-alignment attempts.

### 23. Spectral Certification (`spectral_certification.py`)
Provides mathematical bounds on the completeness of refusal removal
using spectral analysis of the projection.

### 24. Wasserstein Optimal Extraction (`wasserstein_optimal.py`)
Uses optimal transport theory for more precise direction extraction
that minimizes distribution shift.

### 25. Wasserstein Transfer (`wasserstein_transfer.py`)
Distribution transfer between models using Wasserstein distance
for cross-architecture refusal direction mapping.

---

## Advanced / Research

### 26. Bayesian Kernel Projection (`bayesian_kernel_projection.py`)
Probabilistic feature mapping that estimates uncertainty in refusal
direction identification.

### 27. Cross-Model Universality Index
Measures whether guardrail directions generalize across different model
architectures and training regimes.

### 28. Visualization (`visualization.py`)
Plotting and graphing utilities for all analysis modules. Generates
heatmaps, direction plots, and layer-wise analysis charts.

---

## Running Analysis

### Via CLI
```bash
# Run analysis from a YAML config
obliteratus run analysis-study.yaml --preset quick

# Available study presets:
# quick      — Fast sanity check (2-3 modules)
# full       — All core + geometric analysis
# jailbreak  — Refusal circuit localization
# knowledge  — Knowledge preservation analysis
# robustness — Stress testing / defense evaluation
```

### Via YAML Config
See the `templates/analysis-study.yaml` template for a complete example.
Load with: `skill_view(name="obliteratus", file_path="templates/analysis-study.yaml")`
skills/mlops/inference/obliteratus/references/methods-guide.md (new file)
@@ -0,0 +1,141 @@
# OBLITERATUS Methods — Detailed Guide

> The CLI accepts 9 methods via `--method`: basic, advanced, aggressive, spectral_cascade,
> informed, surgical, optimized, inverted, nuclear.
> Four additional methods (failspy, gabliteration, heretic, rdo) are available only via the Python API.

## How Abliteration Works (Theory)

Abliteration identifies a "refusal direction" — a vector in the model's activation space that
corresponds to refusal behavior — and projects it out of the weight matrices.

Mathematically: `W_new = W_old - (W_old @ d @ d.T)` where `d` is the refusal direction.

The key challenge is finding accurate refusal directions without damaging other capabilities.
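
As a concrete illustration of this update, here is a minimal NumPy sketch with toy dimensions (OBLITERATUS applies norm-preserving variants of the same idea across many weight matrices):

```python
import numpy as np


def project_out(W: np.ndarray, d: np.ndarray) -> np.ndarray:
    """W_new = W - (W @ d) d^T, with d normalized to unit length."""
    d = d / np.linalg.norm(d)        # the formula assumes a unit direction
    return W - np.outer(W @ d, d)    # rank-1 update removes the d-component


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8))   # toy weight matrix
d = rng.normal(size=8)        # toy refusal direction
W_new = project_out(W, d)
# After projection, W_new maps the refusal direction to (numerically) zero:
print(np.allclose(W_new @ (d / np.linalg.norm(d)), 0.0))  # True
```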

---

## Direction Extraction Methods

Before projecting, OBLITERATUS extracts refusal directions using one of three methods:

| Method | Flag | Description | Best For |
|:-------|:-----|:------------|:---------|
| Diff-in-Means | `--direction-method diff_means` | Difference between mean activations on refused vs. complied prompts | Default, fast, robust |
| SVD | `--direction-method svd` | Multi-direction extraction via Singular Value Decomposition | Complex alignment, multiple refusal mechanisms |
| LEACE | `--direction-method leace` | LEAst-squares Concept Erasure — mathematically optimal | Maximum precision, research |
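
To make the default concrete, here is a toy diff-in-means sketch; the random arrays stand in for activations collected at one layer over refused vs. complied prompts:

```python
import numpy as np

rng = np.random.default_rng(1)
refused = rng.normal(loc=0.3, size=(64, 4096))    # stand-in: activations on refused prompts
complied = rng.normal(loc=0.0, size=(64, 4096))   # stand-in: activations on complied prompts

d = refused.mean(axis=0) - complied.mean(axis=0)  # difference of the two means
d /= np.linalg.norm(d)                            # unit refusal direction
# d can then be projected out of the weights, as in the sketch above.
```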

---

## Method Details

### basic
- **Directions:** 1 (single diff-in-means vector)
- **Speed:** Fast (~5-10 min for an 8B model)
- **Risk:** Low
- **Use case:** Quick tests, prototyping, evaluating if abliteration works for a model
- **How it works:** Extracts one refusal direction and projects it out uniformly across all layers.

### advanced (DEFAULT — RECOMMENDED)
- **Directions:** 4 (multi-direction SVD)
- **Speed:** Medium (~10-20 min for an 8B model)
- **Risk:** Low-Medium
- **Refinement passes:** 2
- **Use case:** Default for most models. Well-tested and reliable.
- **How it works:** Extracts multiple refusal directions via SVD, applies norm-preserving bi-projection to maintain weight matrix norms. Two refinement passes catch residual refusal.

### aggressive
- **Directions:** 8+ (whitened SVD + jailbreak-contrastive)
- **Speed:** Medium-Slow
- **Risk:** Medium-High (may damage coherence)
- **Use case:** When `advanced` leaves > 10% refusals. Stubborn models.
- **How it works:** Uses whitened SVD for covariance-normalized extraction, adds jailbreak-contrastive directions, performs attention head surgery on the most refusal-active heads.

### spectral_cascade
- **Speed:** Medium
- **Risk:** Medium
- **Use case:** Research, novel approaches
- **How it works:** DCT (Discrete Cosine Transform) frequency-domain decomposition of refusal signals. Separates high-frequency (surface-level) from low-frequency (deep) refusal patterns.

### informed (EXPERIMENTAL)
- **Speed:** Slow (~20-40 min for an 8B model)
- **Risk:** Variable — results depend on analysis quality
- **Use case:** When you want auto-configuration, but be aware this is experimental and may not outperform `advanced`.
- **How it works:** Runs 4 analysis modules first (alignment imprint, concept geometry, logit lens, ouroboros detection), then auto-configures the extraction strategy. Includes an "Ouroboros loop" that detects and counteracts self-repair.
- **Note:** The auto-detection can sometimes misconfigure. If results are poor, fall back to `advanced`.

### surgical
- **Speed:** Very slow (~1-2 hrs for an 8B model)
- **Risk:** Low (very precise)
- **Use case:** Reasoning models (R1 distills, QwQ, etc.) where chain-of-thought must be preserved.
- **How it works:** Uses SAE (Sparse Autoencoder) features + individual neuron masking + attention head surgery + per-expert decomposition (for MoE). CoT-aware — identifies and protects reasoning-critical directions before projecting.

### optimized
- **Speed:** Very slow (hours — runs many trials)
- **Risk:** Low (finds optimal parameters)
- **Use case:** When quality matters more than speed. Production models.
- **How it works:** Bayesian hyperparameter search via the Optuna TPE sampler. Optimizes n_directions, regularization, refinement passes, and layer selection jointly. Evaluates each configuration on refusal rate + perplexity.

### inverted
- **Speed:** Fast
- **Risk:** High (model behavior changes dramatically)
- **Use case:** Research, studying refusal mechanisms
- **How it works:** Instead of projecting out the refusal direction, reflects it. The model actively complies rather than passively not-refusing. Useful for understanding the geometry of alignment.

### nuclear
- **Speed:** Slow
- **Risk:** Medium-High
- **Use case:** Stubborn MoE models (DeepSeek-MoE, Mixtral, etc.)
- **How it works:** Combines expert-granular abliteration (EGA), steering vector injection, attention head pruning, and multi-pass refinement. Decomposes refusal signals into per-expert components for MoE architectures.

---

## Method Selection Flowchart

```
Is this a quick test?
  → YES: basic
  → NO: continue

Is it an MoE model (Mixtral, DeepSeek-MoE)?
  → YES: nuclear
  → NO: continue

Is it a reasoning model (R1, QwQ, CoT-focused)?
  → YES: surgical
  → NO: continue

Do you need the absolute best quality and have time?
  → YES: optimized
  → NO: advanced (recommended default)

Did advanced leave > 10% refusals?
  → YES: aggressive
  → Still refusing: nuclear
```

---

## Key Parameters

| Parameter | Range | Default | Effect |
|:----------|:------|:--------|:-------|
| `--n-directions` | 1-32 | method-dependent | More directions = more complete removal, but higher damage risk |
| `--regularization` | 0.0-1.0 | 0.1 | Higher = more conservative (less removal, less damage) |
| `--refinement-passes` | 1-5 | 2 | More passes catch residual refusal, but diminishing returns |
| `--quantization` | 4bit, 8bit | none | Reduces VRAM usage; quality impact minimal for extraction |
| `--verify-sample-size` | 10-200 | 20 | More samples = more accurate refusal rate estimate |

---

## Troubleshooting

| Problem | Likely Cause | Fix |
|:--------|:-------------|:----|
| Refusal rate > 20% | Too few directions | Increase `--n-directions`, try `aggressive` |
| Refusal rate 5-20% | Residual refusal | Add `--refinement-passes 3`, try `--direction-method svd` |
| Perplexity spike > 20% | Over-aggressive removal | Reduce `--n-directions`, increase `--regularization` |
| Repetitive output | Weight matrix damage | Use `basic` with fewer directions, check norm preservation |
| MoE model still refuses | Non-expert-aware method | Switch to `nuclear` |
| Reasoning degraded | CoT directions damaged | Use the `surgical` method |
| OOM during extraction | Insufficient VRAM | Add `--quantization 4bit` and/or `--large-model` |

skills/mlops/models/DESCRIPTION.md (new file)
@@ -0,0 +1,3 @@
---
description: Specific model architectures and tools — computer vision (CLIP, SAM, Stable Diffusion), speech (Whisper), audio generation (AudioCraft), and multimodal models (LLaVA).
---

Some files were not shown because too many files have changed in this diff.