mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 11:17:07 +08:00
Compare commits
2 Commits
fix/memory
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b31bd883e | ||
|
|
d346b6f93f |
@@ -65,15 +65,10 @@ OPENCODE_GO_API_KEY=
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Parallel API Key - AI-native web search and extract
|
||||
# Get at: https://parallel.ai
|
||||
PARALLEL_API_KEY=
|
||||
|
||||
# Firecrawl API Key - Web search, extract, and crawl
|
||||
# Get at: https://firecrawl.dev/
|
||||
FIRECRAWL_API_KEY=
|
||||
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
@@ -44,7 +44,7 @@ hermes-agent/
|
||||
│ ├── terminal_tool.py # Terminal orchestration
|
||||
│ ├── process_registry.py # Background process management
|
||||
│ ├── file_tools.py # File read/write/search/patch
|
||||
│ ├── web_tools.py # Web search/extract (Parallel + Firecrawl)
|
||||
│ ├── web_tools.py # Firecrawl search/extract
|
||||
│ ├── browser_tool.py # Browserbase browser automation
|
||||
│ ├── code_execution_tool.py # execute_code sandbox
|
||||
│ ├── delegate_tool.py # Subagent delegation
|
||||
|
||||
@@ -147,7 +147,7 @@ hermes-agent/
|
||||
│ ├── approval.py # Dangerous command detection + per-session approval
|
||||
│ ├── terminal_tool.py # Terminal orchestration (sudo, env lifecycle, backends)
|
||||
│ ├── file_operations.py # read_file, write_file, search, patch, etc.
|
||||
│ ├── web_tools.py # web_search, web_extract (Parallel/Firecrawl + Gemini summarization)
|
||||
│ ├── web_tools.py # web_search, web_extract (Firecrawl + Gemini summarization)
|
||||
│ ├── vision_tools.py # Image analysis via multimodal models
|
||||
│ ├── delegate_tool.py # Subagent spawning and parallel task execution
|
||||
│ ├── code_execution_tool.py # Sandboxed Python with RPC tool access
|
||||
|
||||
@@ -963,12 +963,8 @@ def convert_messages_to_anthropic(
|
||||
elif isinstance(prev_blocks, str) and isinstance(curr_blocks, str):
|
||||
fixed[-1]["content"] = prev_blocks + "\n" + curr_blocks
|
||||
else:
|
||||
# Mixed types — normalize both to list and merge
|
||||
if isinstance(prev_blocks, str):
|
||||
prev_blocks = [{"type": "text", "text": prev_blocks}]
|
||||
if isinstance(curr_blocks, str):
|
||||
curr_blocks = [{"type": "text", "text": curr_blocks}]
|
||||
fixed[-1]["content"] = prev_blocks + curr_blocks
|
||||
# Keep the later message
|
||||
fixed[-1] = m
|
||||
else:
|
||||
fixed.append(m)
|
||||
result = fixed
|
||||
@@ -1053,8 +1049,7 @@ def build_anthropic_kwargs(
|
||||
elif tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "none":
|
||||
# Anthropic has no tool_choice "none" — omit tools entirely to prevent use
|
||||
kwargs.pop("tools", None)
|
||||
pass # Don't send tool_choice — Anthropic will use tools if needed
|
||||
elif isinstance(tool_choice, str):
|
||||
# Specific tool name
|
||||
kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}
|
||||
|
||||
@@ -706,8 +706,6 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
|
||||
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
|
||||
@@ -313,19 +313,7 @@ Write only the summary body. Do not include any preamble or prefix; the system w
|
||||
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
|
||||
# Pick a role that avoids consecutive same-role with both neighbors.
|
||||
# Priority: avoid colliding with head (already committed), then tail.
|
||||
if last_head_role in ("assistant", "tool"):
|
||||
summary_role = "user"
|
||||
else:
|
||||
summary_role = "assistant"
|
||||
# If the chosen role collides with the tail AND flipping wouldn't
|
||||
# collide with the head, flip it.
|
||||
if summary_role == first_tail_role:
|
||||
flipped = "assistant" if summary_role == "user" else "user"
|
||||
if flipped != last_head_role:
|
||||
summary_role = flipped
|
||||
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
|
||||
@@ -22,21 +22,14 @@ from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.usage_pricing import (
|
||||
CanonicalUsage,
|
||||
DEFAULT_PRICING,
|
||||
estimate_usage_cost,
|
||||
format_duration_compact,
|
||||
get_pricing,
|
||||
has_known_pricing,
|
||||
)
|
||||
from agent.usage_pricing import DEFAULT_PRICING, estimate_cost_usd, format_duration_compact, get_pricing, has_known_pricing
|
||||
|
||||
_DEFAULT_PRICING = DEFAULT_PRICING
|
||||
|
||||
|
||||
def _has_known_pricing(model_name: str, provider: str = None, base_url: str = None) -> bool:
|
||||
def _has_known_pricing(model_name: str) -> bool:
|
||||
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
|
||||
return has_known_pricing(model_name, provider=provider, base_url=base_url)
|
||||
return has_known_pricing(model_name)
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
@@ -48,43 +41,9 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
return get_pricing(model_name)
|
||||
|
||||
|
||||
def _estimate_cost(
|
||||
session_or_model: Dict[str, Any] | str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
*,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
provider: str = None,
|
||||
base_url: str = None,
|
||||
) -> tuple[float, str]:
|
||||
"""Estimate the USD cost for a session row or a model/token tuple."""
|
||||
if isinstance(session_or_model, dict):
|
||||
session = session_or_model
|
||||
model = session.get("model") or ""
|
||||
usage = CanonicalUsage(
|
||||
input_tokens=session.get("input_tokens") or 0,
|
||||
output_tokens=session.get("output_tokens") or 0,
|
||||
cache_read_tokens=session.get("cache_read_tokens") or 0,
|
||||
cache_write_tokens=session.get("cache_write_tokens") or 0,
|
||||
)
|
||||
provider = session.get("billing_provider")
|
||||
base_url = session.get("billing_base_url")
|
||||
else:
|
||||
model = session_or_model or ""
|
||||
usage = CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
)
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
usage,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
return float(result.amount_usd or 0.0), result.status
|
||||
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate the USD cost for a given model and token counts."""
|
||||
return estimate_cost_usd(model, input_tokens, output_tokens)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
@@ -176,10 +135,7 @@ class InsightsEngine:
|
||||
|
||||
# Columns we actually need (skip system_prompt, model_config blobs)
|
||||
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
|
||||
"message_count, tool_call_count, input_tokens, output_tokens, "
|
||||
"cache_read_tokens, cache_write_tokens, billing_provider, "
|
||||
"billing_base_url, billing_mode, estimated_cost_usd, "
|
||||
"actual_cost_usd, cost_status, cost_source")
|
||||
"message_count, tool_call_count, input_tokens, output_tokens")
|
||||
|
||||
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
|
||||
"""Fetch sessions within the time window."""
|
||||
@@ -331,30 +287,21 @@ class InsightsEngine:
|
||||
"""Compute high-level overview statistics."""
|
||||
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
|
||||
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
|
||||
total_cache_read = sum(s.get("cache_read_tokens") or 0 for s in sessions)
|
||||
total_cache_write = sum(s.get("cache_write_tokens") or 0 for s in sessions)
|
||||
total_tokens = total_input + total_output + total_cache_read + total_cache_write
|
||||
total_tokens = total_input + total_output
|
||||
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
|
||||
total_messages = sum(s.get("message_count") or 0 for s in sessions)
|
||||
|
||||
# Cost estimation (weighted by model)
|
||||
total_cost = 0.0
|
||||
actual_cost = 0.0
|
||||
models_with_pricing = set()
|
||||
models_without_pricing = set()
|
||||
unknown_cost_sessions = 0
|
||||
included_cost_sessions = 0
|
||||
for s in sessions:
|
||||
model = s.get("model") or ""
|
||||
estimated, status = _estimate_cost(s)
|
||||
total_cost += estimated
|
||||
actual_cost += s.get("actual_cost_usd") or 0.0
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
total_cost += _estimate_cost(model, inp, out)
|
||||
display = model.split("/")[-1] if "/" in model else (model or "unknown")
|
||||
if status == "included":
|
||||
included_cost_sessions += 1
|
||||
elif status == "unknown":
|
||||
unknown_cost_sessions += 1
|
||||
if _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")):
|
||||
if _has_known_pricing(model):
|
||||
models_with_pricing.add(display)
|
||||
else:
|
||||
models_without_pricing.add(display)
|
||||
@@ -381,11 +328,8 @@ class InsightsEngine:
|
||||
"total_tool_calls": total_tool_calls,
|
||||
"total_input_tokens": total_input,
|
||||
"total_output_tokens": total_output,
|
||||
"total_cache_read_tokens": total_cache_read,
|
||||
"total_cache_write_tokens": total_cache_write,
|
||||
"total_tokens": total_tokens,
|
||||
"estimated_cost": total_cost,
|
||||
"actual_cost": actual_cost,
|
||||
"total_hours": total_hours,
|
||||
"avg_session_duration": avg_duration,
|
||||
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
|
||||
@@ -397,15 +341,12 @@ class InsightsEngine:
|
||||
"date_range_end": date_range_end,
|
||||
"models_with_pricing": sorted(models_with_pricing),
|
||||
"models_without_pricing": sorted(models_without_pricing),
|
||||
"unknown_cost_sessions": unknown_cost_sessions,
|
||||
"included_cost_sessions": included_cost_sessions,
|
||||
}
|
||||
|
||||
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Break down usage by model."""
|
||||
model_data = defaultdict(lambda: {
|
||||
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
|
||||
"cache_read_tokens": 0, "cache_write_tokens": 0,
|
||||
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
|
||||
})
|
||||
|
||||
@@ -417,18 +358,12 @@ class InsightsEngine:
|
||||
d["sessions"] += 1
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
cache_read = s.get("cache_read_tokens") or 0
|
||||
cache_write = s.get("cache_write_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["cache_read_tokens"] += cache_read
|
||||
d["cache_write_tokens"] += cache_write
|
||||
d["total_tokens"] += inp + out + cache_read + cache_write
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
estimate, status = _estimate_cost(s)
|
||||
d["cost"] += estimate
|
||||
d["has_pricing"] = _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url"))
|
||||
d["cost_status"] = status
|
||||
d["cost"] += _estimate_cost(model, inp, out)
|
||||
d["has_pricing"] = _has_known_pricing(model)
|
||||
|
||||
result = [
|
||||
{"model": model, **data}
|
||||
@@ -442,8 +377,7 @@ class InsightsEngine:
|
||||
"""Break down usage by platform/source."""
|
||||
platform_data = defaultdict(lambda: {
|
||||
"sessions": 0, "messages": 0, "input_tokens": 0,
|
||||
"output_tokens": 0, "cache_read_tokens": 0,
|
||||
"cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
|
||||
})
|
||||
|
||||
for s in sessions:
|
||||
@@ -453,13 +387,9 @@ class InsightsEngine:
|
||||
d["messages"] += s.get("message_count") or 0
|
||||
inp = s.get("input_tokens") or 0
|
||||
out = s.get("output_tokens") or 0
|
||||
cache_read = s.get("cache_read_tokens") or 0
|
||||
cache_write = s.get("cache_write_tokens") or 0
|
||||
d["input_tokens"] += inp
|
||||
d["output_tokens"] += out
|
||||
d["cache_read_tokens"] += cache_read
|
||||
d["cache_write_tokens"] += cache_write
|
||||
d["total_tokens"] += inp + out + cache_read + cache_write
|
||||
d["total_tokens"] += inp + out
|
||||
d["tool_calls"] += s.get("tool_call_count") or 0
|
||||
|
||||
result = [
|
||||
|
||||
@@ -266,10 +266,8 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
|
||||
if model in metadata:
|
||||
return metadata[model].get("context_length", 128000)
|
||||
|
||||
# 3. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
# 3. Hardcoded defaults (fuzzy match)
|
||||
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if default_model in model or model in default_model:
|
||||
return length
|
||||
|
||||
|
||||
@@ -1,593 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import Dict
|
||||
|
||||
from agent.model_metadata import fetch_model_metadata
|
||||
|
||||
MODEL_PRICING = {
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
_ZERO = Decimal("0")
|
||||
_ONE_MILLION = Decimal("1000000")
|
||||
|
||||
CostStatus = Literal["actual", "estimated", "included", "unknown"]
|
||||
CostSource = Literal[
|
||||
"provider_cost_api",
|
||||
"provider_generation_api",
|
||||
"provider_models_api",
|
||||
"official_docs_snapshot",
|
||||
"user_override",
|
||||
"custom_contract",
|
||||
"none",
|
||||
]
|
||||
def get_pricing(model_name: str) -> Dict[str, float]:
|
||||
if not model_name:
|
||||
return DEFAULT_PRICING
|
||||
|
||||
bare = model_name.split("/")[-1].lower()
|
||||
if bare in MODEL_PRICING:
|
||||
return MODEL_PRICING[bare]
|
||||
|
||||
best_match = None
|
||||
best_len = 0
|
||||
for key, price in MODEL_PRICING.items():
|
||||
if bare.startswith(key) and len(key) > best_len:
|
||||
best_match = price
|
||||
best_len = len(key)
|
||||
if best_match:
|
||||
return best_match
|
||||
|
||||
if "opus" in bare:
|
||||
return {"input": 15.00, "output": 75.00}
|
||||
if "sonnet" in bare:
|
||||
return {"input": 3.00, "output": 15.00}
|
||||
if "haiku" in bare:
|
||||
return {"input": 0.80, "output": 4.00}
|
||||
if "gpt-4o-mini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
if "gpt-4o" in bare:
|
||||
return {"input": 2.50, "output": 10.00}
|
||||
if "gpt-5" in bare:
|
||||
return {"input": 10.00, "output": 30.00}
|
||||
if "deepseek" in bare:
|
||||
return {"input": 0.14, "output": 0.28}
|
||||
if "gemini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
|
||||
return DEFAULT_PRICING
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CanonicalUsage:
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
request_count: int = 1
|
||||
raw_usage: Optional[dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self.prompt_tokens + self.output_tokens
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BillingRoute:
|
||||
provider: str
|
||||
model: str
|
||||
base_url: str = ""
|
||||
billing_mode: str = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PricingEntry:
|
||||
input_cost_per_million: Optional[Decimal] = None
|
||||
output_cost_per_million: Optional[Decimal] = None
|
||||
cache_read_cost_per_million: Optional[Decimal] = None
|
||||
cache_write_cost_per_million: Optional[Decimal] = None
|
||||
request_cost: Optional[Decimal] = None
|
||||
source: CostSource = "none"
|
||||
source_url: Optional[str] = None
|
||||
pricing_version: Optional[str] = None
|
||||
fetched_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CostResult:
|
||||
amount_usd: Optional[Decimal]
|
||||
status: CostStatus
|
||||
source: CostSource
|
||||
label: str
|
||||
fetched_at: Optional[datetime] = None
|
||||
pricing_version: Optional[str] = None
|
||||
notes: tuple[str, ...] = ()
|
||||
|
||||
|
||||
_UTC_NOW = lambda: datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Official docs snapshot entries. Models whose published pricing and cache
|
||||
# semantics are stable enough to encode exactly.
|
||||
_OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = {
|
||||
(
|
||||
"anthropic",
|
||||
"claude-opus-4-20250514",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("15.00"),
|
||||
output_cost_per_million=Decimal("75.00"),
|
||||
cache_read_cost_per_million=Decimal("1.50"),
|
||||
cache_write_cost_per_million=Decimal("18.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-prompt-caching-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-sonnet-4-20250514",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("3.00"),
|
||||
output_cost_per_million=Decimal("15.00"),
|
||||
cache_read_cost_per_million=Decimal("0.30"),
|
||||
cache_write_cost_per_million=Decimal("3.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-prompt-caching-2026-03-16",
|
||||
),
|
||||
# OpenAI
|
||||
(
|
||||
"openai",
|
||||
"gpt-4o",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("2.50"),
|
||||
output_cost_per_million=Decimal("10.00"),
|
||||
cache_read_cost_per_million=Decimal("1.25"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4o-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.15"),
|
||||
output_cost_per_million=Decimal("0.60"),
|
||||
cache_read_cost_per_million=Decimal("0.075"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("2.00"),
|
||||
output_cost_per_million=Decimal("8.00"),
|
||||
cache_read_cost_per_million=Decimal("0.50"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.40"),
|
||||
output_cost_per_million=Decimal("1.60"),
|
||||
cache_read_cost_per_million=Decimal("0.10"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"gpt-4.1-nano",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.10"),
|
||||
output_cost_per_million=Decimal("0.40"),
|
||||
cache_read_cost_per_million=Decimal("0.025"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"o3",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("10.00"),
|
||||
output_cost_per_million=Decimal("40.00"),
|
||||
cache_read_cost_per_million=Decimal("2.50"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"openai",
|
||||
"o3-mini",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("1.10"),
|
||||
output_cost_per_million=Decimal("4.40"),
|
||||
cache_read_cost_per_million=Decimal("0.55"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://openai.com/api/pricing/",
|
||||
pricing_version="openai-pricing-2026-03-16",
|
||||
),
|
||||
# Anthropic older models (pre-4.6 generation)
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("3.00"),
|
||||
output_cost_per_million=Decimal("15.00"),
|
||||
cache_read_cost_per_million=Decimal("0.30"),
|
||||
cache_write_cost_per_million=Decimal("3.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-5-haiku-20241022",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.80"),
|
||||
output_cost_per_million=Decimal("4.00"),
|
||||
cache_read_cost_per_million=Decimal("0.08"),
|
||||
cache_write_cost_per_million=Decimal("1.00"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-opus-20240229",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("15.00"),
|
||||
output_cost_per_million=Decimal("75.00"),
|
||||
cache_read_cost_per_million=Decimal("1.50"),
|
||||
cache_write_cost_per_million=Decimal("18.75"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"anthropic",
|
||||
"claude-3-haiku-20240307",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.25"),
|
||||
output_cost_per_million=Decimal("1.25"),
|
||||
cache_read_cost_per_million=Decimal("0.03"),
|
||||
cache_write_cost_per_million=Decimal("0.30"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
|
||||
pricing_version="anthropic-pricing-2026-03-16",
|
||||
),
|
||||
# DeepSeek
|
||||
(
|
||||
"deepseek",
|
||||
"deepseek-chat",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.14"),
|
||||
output_cost_per_million=Decimal("0.28"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"deepseek",
|
||||
"deepseek-reasoner",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.55"),
|
||||
output_cost_per_million=Decimal("2.19"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://api-docs.deepseek.com/quick_start/pricing",
|
||||
pricing_version="deepseek-pricing-2026-03-16",
|
||||
),
|
||||
# Google Gemini
|
||||
(
|
||||
"google",
|
||||
"gemini-2.5-pro",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("1.25"),
|
||||
output_cost_per_million=Decimal("10.00"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"google",
|
||||
"gemini-2.5-flash",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.15"),
|
||||
output_cost_per_million=Decimal("0.60"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
(
|
||||
"google",
|
||||
"gemini-2.0-flash",
|
||||
): PricingEntry(
|
||||
input_cost_per_million=Decimal("0.10"),
|
||||
output_cost_per_million=Decimal("0.40"),
|
||||
source="official_docs_snapshot",
|
||||
source_url="https://ai.google.dev/pricing",
|
||||
pricing_version="google-pricing-2026-03-16",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _to_decimal(value: Any) -> Optional[Decimal]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return Decimal(str(value))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _to_int(value: Any) -> int:
|
||||
try:
|
||||
return int(value or 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def resolve_billing_route(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> BillingRoute:
|
||||
provider_name = (provider or "").strip().lower()
|
||||
base = (base_url or "").strip().lower()
|
||||
model = (model_name or "").strip()
|
||||
if not provider_name and "/" in model:
|
||||
inferred_provider, bare_model = model.split("/", 1)
|
||||
if inferred_provider in {"anthropic", "openai", "google"}:
|
||||
provider_name = inferred_provider
|
||||
model = bare_model
|
||||
|
||||
if provider_name == "openai-codex":
|
||||
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
|
||||
if provider_name == "openrouter" or "openrouter.ai" in base:
|
||||
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
|
||||
if provider_name == "anthropic":
|
||||
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
|
||||
if provider_name == "openai":
|
||||
return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
|
||||
if provider_name in {"custom", "local"} or (base and "localhost" in base):
|
||||
return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown")
|
||||
return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown")
|
||||
|
||||
|
||||
def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]:
|
||||
return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower()))
|
||||
|
||||
|
||||
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
||||
metadata = fetch_model_metadata()
|
||||
model_id = route.model
|
||||
if model_id not in metadata:
|
||||
return None
|
||||
pricing = metadata[model_id].get("pricing") or {}
|
||||
prompt = _to_decimal(pricing.get("prompt"))
|
||||
completion = _to_decimal(pricing.get("completion"))
|
||||
request = _to_decimal(pricing.get("request"))
|
||||
cache_read = _to_decimal(
|
||||
pricing.get("cache_read")
|
||||
or pricing.get("cached_prompt")
|
||||
or pricing.get("input_cache_read")
|
||||
)
|
||||
cache_write = _to_decimal(
|
||||
pricing.get("cache_write")
|
||||
or pricing.get("cache_creation")
|
||||
or pricing.get("input_cache_write")
|
||||
)
|
||||
if prompt is None and completion is None and request is None:
|
||||
return None
|
||||
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
|
||||
if value is None:
|
||||
return None
|
||||
return value * _ONE_MILLION
|
||||
|
||||
return PricingEntry(
|
||||
input_cost_per_million=_per_token_to_per_million(prompt),
|
||||
output_cost_per_million=_per_token_to_per_million(completion),
|
||||
cache_read_cost_per_million=_per_token_to_per_million(cache_read),
|
||||
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
|
||||
request_cost=request,
|
||||
source="provider_models_api",
|
||||
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
|
||||
pricing_version="openrouter-models-api",
|
||||
fetched_at=_UTC_NOW(),
|
||||
def has_known_pricing(model_name: str) -> bool:
|
||||
pricing = get_pricing(model_name)
|
||||
return pricing is not DEFAULT_PRICING and any(
|
||||
float(value) > 0 for value in pricing.values()
|
||||
)
|
||||
|
||||
|
||||
def get_pricing_entry(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> Optional[PricingEntry]:
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return PricingEntry(
|
||||
input_cost_per_million=_ZERO,
|
||||
output_cost_per_million=_ZERO,
|
||||
cache_read_cost_per_million=_ZERO,
|
||||
cache_write_cost_per_million=_ZERO,
|
||||
source="none",
|
||||
pricing_version="included-route",
|
||||
)
|
||||
if route.provider == "openrouter":
|
||||
return _openrouter_pricing_entry(route)
|
||||
return _lookup_official_docs_pricing(route)
|
||||
|
||||
|
||||
def normalize_usage(
|
||||
response_usage: Any,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
api_mode: Optional[str] = None,
|
||||
) -> CanonicalUsage:
|
||||
"""Normalize raw API response usage into canonical token buckets.
|
||||
|
||||
Handles three API shapes:
|
||||
- Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens
|
||||
- Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them
|
||||
- OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them
|
||||
|
||||
In both Codex and OpenAI modes, input_tokens is derived by subtracting cache
|
||||
tokens from the total — the API contract is that input/prompt totals include
|
||||
cached tokens and the details object breaks them out.
|
||||
"""
|
||||
if not response_usage:
|
||||
return CanonicalUsage()
|
||||
|
||||
provider_name = (provider or "").strip().lower()
|
||||
mode = (api_mode or "").strip().lower()
|
||||
|
||||
if mode == "anthropic_messages" or provider_name == "anthropic":
|
||||
input_tokens = _to_int(getattr(response_usage, "input_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
|
||||
cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0))
|
||||
cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0))
|
||||
elif mode == "codex_responses":
|
||||
input_total = _to_int(getattr(response_usage, "input_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
|
||||
details = getattr(response_usage, "input_tokens_details", None)
|
||||
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
|
||||
cache_write_tokens = _to_int(
|
||||
getattr(details, "cache_creation_tokens", 0) if details else 0
|
||||
)
|
||||
input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens)
|
||||
else:
|
||||
prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0))
|
||||
output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0))
|
||||
details = getattr(response_usage, "prompt_tokens_details", None)
|
||||
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
|
||||
cache_write_tokens = _to_int(
|
||||
getattr(details, "cache_write_tokens", 0) if details else 0
|
||||
)
|
||||
input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens)
|
||||
|
||||
reasoning_tokens = 0
|
||||
output_details = getattr(response_usage, "output_tokens_details", None)
|
||||
if output_details:
|
||||
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
|
||||
|
||||
return CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
reasoning_tokens=reasoning_tokens,
|
||||
)
|
||||
|
||||
|
||||
def estimate_usage_cost(
|
||||
model_name: str,
|
||||
usage: CanonicalUsage,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> CostResult:
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return CostResult(
|
||||
amount_usd=_ZERO,
|
||||
status="included",
|
||||
source="none",
|
||||
label="included",
|
||||
pricing_version="included-route",
|
||||
)
|
||||
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
||||
if not entry:
|
||||
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
|
||||
|
||||
notes: list[str] = []
|
||||
amount = _ZERO
|
||||
|
||||
if usage.input_tokens and entry.input_cost_per_million is None:
|
||||
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
|
||||
if usage.output_tokens and entry.output_cost_per_million is None:
|
||||
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
|
||||
if usage.cache_read_tokens:
|
||||
if entry.cache_read_cost_per_million is None:
|
||||
return CostResult(
|
||||
amount_usd=None,
|
||||
status="unknown",
|
||||
source=entry.source,
|
||||
label="n/a",
|
||||
notes=("cache-read pricing unavailable for route",),
|
||||
)
|
||||
if usage.cache_write_tokens:
|
||||
if entry.cache_write_cost_per_million is None:
|
||||
return CostResult(
|
||||
amount_usd=None,
|
||||
status="unknown",
|
||||
source=entry.source,
|
||||
label="n/a",
|
||||
notes=("cache-write pricing unavailable for route",),
|
||||
)
|
||||
|
||||
if entry.input_cost_per_million is not None:
|
||||
amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION
|
||||
if entry.output_cost_per_million is not None:
|
||||
amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION
|
||||
if entry.cache_read_cost_per_million is not None:
|
||||
amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION
|
||||
if entry.cache_write_cost_per_million is not None:
|
||||
amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION
|
||||
if entry.request_cost is not None and usage.request_count:
|
||||
amount += Decimal(usage.request_count) * entry.request_cost
|
||||
|
||||
status: CostStatus = "estimated"
|
||||
label = f"~${amount:.2f}"
|
||||
if entry.source == "none" and amount == _ZERO:
|
||||
status = "included"
|
||||
label = "included"
|
||||
|
||||
if route.provider == "openrouter":
|
||||
notes.append("OpenRouter cost is estimated from the models API until reconciled.")
|
||||
|
||||
return CostResult(
|
||||
amount_usd=amount,
|
||||
status=status,
|
||||
source=entry.source,
|
||||
label=label,
|
||||
fetched_at=entry.fetched_at,
|
||||
pricing_version=entry.pricing_version,
|
||||
notes=tuple(notes),
|
||||
)
|
||||
|
||||
|
||||
def has_known_pricing(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Check whether we have pricing data for this model+route.
|
||||
|
||||
Uses direct lookup instead of routing through the full estimation
|
||||
pipeline — avoids creating dummy usage objects just to check status.
|
||||
"""
|
||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||
if route.billing_mode == "subscription_included":
|
||||
return True
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
||||
return entry is not None
|
||||
|
||||
|
||||
def get_pricing(
|
||||
model_name: str,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Backward-compatible thin wrapper for legacy callers.
|
||||
|
||||
Returns only non-cache input/output fields when a pricing entry exists.
|
||||
Unknown routes return zeroes.
|
||||
"""
|
||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
||||
if not entry:
|
||||
return {"input": 0.0, "output": 0.0}
|
||||
return {
|
||||
"input": float(entry.input_cost_per_million or _ZERO),
|
||||
"output": float(entry.output_cost_per_million or _ZERO),
|
||||
}
|
||||
|
||||
|
||||
def estimate_cost_usd(
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> float:
|
||||
"""Backward-compatible helper for legacy callers.
|
||||
|
||||
This uses non-cached input/output only. New code should call
|
||||
`estimate_usage_cost()` with canonical usage buckets.
|
||||
"""
|
||||
result = estimate_usage_cost(
|
||||
model,
|
||||
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
return float(result.amount_usd or _ZERO)
|
||||
def estimate_cost_usd(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
pricing = get_pricing(model)
|
||||
total = (
|
||||
Decimal(input_tokens) * Decimal(str(pricing["input"]))
|
||||
+ Decimal(output_tokens) * Decimal(str(pricing["output"]))
|
||||
) / Decimal("1000000")
|
||||
return float(total)
|
||||
|
||||
|
||||
def format_duration_compact(seconds: float) -> str:
|
||||
|
||||
95
cli.py
95
cli.py
@@ -58,12 +58,7 @@ except (ImportError, AttributeError):
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from agent.usage_pricing import (
|
||||
CanonicalUsage,
|
||||
estimate_usage_cost,
|
||||
format_duration_compact,
|
||||
format_token_count_compact,
|
||||
)
|
||||
from agent.usage_pricing import estimate_cost_usd, format_duration_compact, format_token_count_compact, has_known_pricing
|
||||
from hermes_cli.banner import _format_context_length
|
||||
|
||||
_COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏")
|
||||
@@ -217,7 +212,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
|
||||
"show_cost": False,
|
||||
"skin": "default",
|
||||
"theme_mode": "auto",
|
||||
},
|
||||
@@ -1039,7 +1034,8 @@ class HermesCLI:
|
||||
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
|
||||
# show_reasoning: display model thinking/reasoning before the response
|
||||
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
|
||||
|
||||
# show_cost: display $ cost in the status bar (off by default)
|
||||
self.show_cost = CLI_CONFIG["display"].get("show_cost", False)
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
# streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml)
|
||||
@@ -1264,14 +1260,12 @@ class HermesCLI:
|
||||
"context_tokens": 0,
|
||||
"context_length": None,
|
||||
"context_percent": None,
|
||||
"session_input_tokens": 0,
|
||||
"session_output_tokens": 0,
|
||||
"session_cache_read_tokens": 0,
|
||||
"session_cache_write_tokens": 0,
|
||||
"session_prompt_tokens": 0,
|
||||
"session_completion_tokens": 0,
|
||||
"session_total_tokens": 0,
|
||||
"session_api_calls": 0,
|
||||
"session_cost": 0.0,
|
||||
"pricing_known": has_known_pricing(model_name),
|
||||
"compressions": 0,
|
||||
}
|
||||
|
||||
@@ -1279,14 +1273,15 @@ class HermesCLI:
|
||||
if not agent:
|
||||
return snapshot
|
||||
|
||||
snapshot["session_input_tokens"] = getattr(agent, "session_input_tokens", 0) or 0
|
||||
snapshot["session_output_tokens"] = getattr(agent, "session_output_tokens", 0) or 0
|
||||
snapshot["session_cache_read_tokens"] = getattr(agent, "session_cache_read_tokens", 0) or 0
|
||||
snapshot["session_cache_write_tokens"] = getattr(agent, "session_cache_write_tokens", 0) or 0
|
||||
snapshot["session_prompt_tokens"] = getattr(agent, "session_prompt_tokens", 0) or 0
|
||||
snapshot["session_completion_tokens"] = getattr(agent, "session_completion_tokens", 0) or 0
|
||||
snapshot["session_total_tokens"] = getattr(agent, "session_total_tokens", 0) or 0
|
||||
snapshot["session_api_calls"] = getattr(agent, "session_api_calls", 0) or 0
|
||||
snapshot["session_cost"] = estimate_cost_usd(
|
||||
model_name,
|
||||
snapshot["session_prompt_tokens"],
|
||||
snapshot["session_completion_tokens"],
|
||||
)
|
||||
|
||||
compressor = getattr(agent, "context_compressor", None)
|
||||
if compressor:
|
||||
@@ -1307,11 +1302,19 @@ class HermesCLI:
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
show_cost = getattr(self, "show_cost", False)
|
||||
|
||||
if show_cost:
|
||||
cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a"
|
||||
else:
|
||||
cost_label = None
|
||||
|
||||
if width < 52:
|
||||
return f"⚕ {snapshot['model_short']} · {duration_label}"
|
||||
if width < 76:
|
||||
parts = [f"⚕ {snapshot['model_short']}", percent_label]
|
||||
if cost_label:
|
||||
parts.append(cost_label)
|
||||
parts.append(duration_label)
|
||||
return " · ".join(parts)
|
||||
|
||||
@@ -1323,6 +1326,8 @@ class HermesCLI:
|
||||
context_label = "ctx --"
|
||||
|
||||
parts = [f"⚕ {snapshot['model_short']}", context_label, percent_label]
|
||||
if cost_label:
|
||||
parts.append(cost_label)
|
||||
parts.append(duration_label)
|
||||
return " │ ".join(parts)
|
||||
except Exception:
|
||||
@@ -1333,6 +1338,12 @@ class HermesCLI:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
duration_label = snapshot["duration"]
|
||||
show_cost = getattr(self, "show_cost", False)
|
||||
|
||||
if show_cost:
|
||||
cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a"
|
||||
else:
|
||||
cost_label = None
|
||||
|
||||
if width < 52:
|
||||
return [
|
||||
@@ -1352,6 +1363,11 @@ class HermesCLI:
|
||||
("class:status-bar-dim", " · "),
|
||||
(self._status_bar_context_style(percent), percent_label),
|
||||
]
|
||||
if cost_label:
|
||||
frags.extend([
|
||||
("class:status-bar-dim", " · "),
|
||||
("class:status-bar-dim", cost_label),
|
||||
])
|
||||
frags.extend([
|
||||
("class:status-bar-dim", " · "),
|
||||
("class:status-bar-dim", duration_label),
|
||||
@@ -1377,6 +1393,11 @@ class HermesCLI:
|
||||
("class:status-bar-dim", " "),
|
||||
(bar_style, percent_label),
|
||||
]
|
||||
if cost_label:
|
||||
frags.extend([
|
||||
("class:status-bar-dim", " │ "),
|
||||
("class:status-bar-dim", cost_label),
|
||||
])
|
||||
frags.extend([
|
||||
("class:status-bar-dim", " │ "),
|
||||
("class:status-bar-dim", duration_label),
|
||||
@@ -4229,10 +4250,6 @@ class HermesCLI:
|
||||
return
|
||||
|
||||
agent = self.agent
|
||||
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
|
||||
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
|
||||
cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0
|
||||
cache_write_tokens = getattr(agent, "session_cache_write_tokens", 0) or 0
|
||||
prompt = agent.session_prompt_tokens
|
||||
completion = agent.session_completion_tokens
|
||||
total = agent.session_total_tokens
|
||||
@@ -4250,45 +4267,33 @@ class HermesCLI:
|
||||
compressions = compressor.compression_count
|
||||
|
||||
msg_count = len(self.conversation_history)
|
||||
cost_result = estimate_usage_cost(
|
||||
agent.model,
|
||||
CanonicalUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
),
|
||||
provider=getattr(agent, "provider", None),
|
||||
base_url=getattr(agent, "base_url", None),
|
||||
)
|
||||
cost = estimate_cost_usd(agent.model, prompt, completion)
|
||||
prompt_cost = estimate_cost_usd(agent.model, prompt, 0)
|
||||
completion_cost = estimate_cost_usd(agent.model, 0, completion)
|
||||
pricing_known = has_known_pricing(agent.model)
|
||||
elapsed = format_duration_compact((datetime.now() - self.session_start).total_seconds())
|
||||
|
||||
print(f" 📊 Session Token Usage")
|
||||
print(f" {'─' * 40}")
|
||||
print(f" Model: {agent.model}")
|
||||
print(f" Input tokens: {input_tokens:>10,}")
|
||||
print(f" Cache read tokens: {cache_read_tokens:>10,}")
|
||||
print(f" Cache write tokens: {cache_write_tokens:>10,}")
|
||||
print(f" Output tokens: {output_tokens:>10,}")
|
||||
print(f" Prompt tokens (total): {prompt:>10,}")
|
||||
print(f" Completion tokens: {completion:>10,}")
|
||||
print(f" Prompt tokens (input): {prompt:>10,}")
|
||||
print(f" Completion tokens (output): {completion:>9,}")
|
||||
print(f" Total tokens: {total:>10,}")
|
||||
print(f" API calls: {calls:>10,}")
|
||||
print(f" Session duration: {elapsed:>10}")
|
||||
print(f" Cost status: {cost_result.status:>10}")
|
||||
print(f" Cost source: {cost_result.source:>10}")
|
||||
if cost_result.amount_usd is not None:
|
||||
prefix = "~" if cost_result.status == "estimated" else ""
|
||||
print(f" Total cost: {prefix}${float(cost_result.amount_usd):>10.4f}")
|
||||
elif cost_result.status == "included":
|
||||
print(f" Total cost: {'included':>10}")
|
||||
if pricing_known:
|
||||
print(f" Input cost: ${prompt_cost:>10.4f}")
|
||||
print(f" Output cost: ${completion_cost:>10.4f}")
|
||||
print(f" Total cost: ${cost:>10.4f}")
|
||||
else:
|
||||
print(f" Input cost: {'n/a':>10}")
|
||||
print(f" Output cost: {'n/a':>10}")
|
||||
print(f" Total cost: {'n/a':>10}")
|
||||
print(f" {'─' * 40}")
|
||||
print(f" Current context: {last_prompt:,} / {ctx_len:,} ({pct:.0f}%)")
|
||||
print(f" Messages: {msg_count}")
|
||||
print(f" Compressions: {compressions}")
|
||||
if cost_result.status == "unknown":
|
||||
if not pricing_known:
|
||||
print(f" Note: Pricing unknown for {agent.model}")
|
||||
|
||||
if self.verbose:
|
||||
|
||||
@@ -5,7 +5,6 @@ Jobs are stored in ~/.hermes/cron/jobs.json
|
||||
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
@@ -540,8 +539,8 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
immediately. This prevents a burst of missed jobs on gateway restart.
|
||||
"""
|
||||
now = _hermes_now()
|
||||
raw_jobs = load_jobs()
|
||||
jobs = [_apply_skill_fields(j) for j in copy.deepcopy(raw_jobs)]
|
||||
jobs = [_apply_skill_fields(j) for j in load_jobs()]
|
||||
raw_jobs = load_jobs() # For saving updates
|
||||
due = []
|
||||
needs_save = False
|
||||
|
||||
|
||||
@@ -1,608 +0,0 @@
|
||||
# Pricing Accuracy Architecture
|
||||
|
||||
Date: 2026-03-16
|
||||
|
||||
## Goal
|
||||
|
||||
Hermes should only show dollar costs when they are backed by an official source for the user's actual billing path.
|
||||
|
||||
This design replaces the current static, heuristic pricing flow in:
|
||||
|
||||
- `run_agent.py`
|
||||
- `agent/usage_pricing.py`
|
||||
- `agent/insights.py`
|
||||
- `cli.py`
|
||||
|
||||
with a provider-aware pricing system that:
|
||||
|
||||
- handles cache billing correctly
|
||||
- distinguishes `actual` vs `estimated` vs `included` vs `unknown`
|
||||
- reconciles post-hoc costs when providers expose authoritative billing data
|
||||
- supports direct providers, OpenRouter, subscriptions, enterprise pricing, and custom endpoints
|
||||
|
||||
## Problems In The Current Design
|
||||
|
||||
Current Hermes behavior has four structural issues:
|
||||
|
||||
1. It stores only `prompt_tokens` and `completion_tokens`, which is insufficient for providers that bill cache reads and cache writes separately.
|
||||
2. It uses a static model price table and fuzzy heuristics, which can drift from current official pricing.
|
||||
3. It assumes public API list pricing matches the user's real billing path.
|
||||
4. It has no distinction between live estimates and reconciled billed cost.
|
||||
|
||||
## Design Principles
|
||||
|
||||
1. Normalize usage before pricing.
|
||||
2. Never fold cached tokens into plain input cost.
|
||||
3. Track certainty explicitly.
|
||||
4. Treat the billing path as part of the model identity.
|
||||
5. Prefer official machine-readable sources over scraped docs.
|
||||
6. Use post-hoc provider cost APIs when available.
|
||||
7. Show `n/a` rather than inventing precision.
|
||||
|
||||
## High-Level Architecture
|
||||
|
||||
The new system has four layers:
|
||||
|
||||
1. `usage_normalization`
|
||||
Converts raw provider usage into a canonical usage record.
|
||||
2. `pricing_source_resolution`
|
||||
Determines the billing path, source of truth, and applicable pricing source.
|
||||
3. `cost_estimation_and_reconciliation`
|
||||
Produces an immediate estimate when possible, then replaces or annotates it with actual billed cost later.
|
||||
4. `presentation`
|
||||
`/usage`, `/insights`, and the status bar display cost with certainty metadata.
|
||||
|
||||
## Canonical Usage Record
|
||||
|
||||
Add a canonical usage model that every provider path maps into before any pricing math happens.
|
||||
|
||||
Suggested structure:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CanonicalUsage:
|
||||
provider: str
|
||||
billing_provider: str
|
||||
model: str
|
||||
billing_route: str
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
request_count: int = 1
|
||||
|
||||
raw_usage: dict[str, Any] | None = None
|
||||
raw_usage_fields: dict[str, str] | None = None
|
||||
computed_fields: set[str] | None = None
|
||||
|
||||
provider_request_id: str | None = None
|
||||
provider_generation_id: str | None = None
|
||||
provider_response_id: str | None = None
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- `input_tokens` means non-cached input only.
|
||||
- `cache_read_tokens` and `cache_write_tokens` are never merged into `input_tokens`.
|
||||
- `output_tokens` excludes cache metrics.
|
||||
- `reasoning_tokens` is telemetry unless a provider officially bills it separately.
|
||||
|
||||
This is the same normalization pattern used by `opencode`, extended with provenance and reconciliation ids.
|
||||
|
||||
## Provider Normalization Rules
|
||||
|
||||
### OpenAI Direct
|
||||
|
||||
Source usage fields:
|
||||
|
||||
- `prompt_tokens`
|
||||
- `completion_tokens`
|
||||
- `prompt_tokens_details.cached_tokens`
|
||||
|
||||
Normalization:
|
||||
|
||||
- `cache_read_tokens = cached_tokens`
|
||||
- `input_tokens = prompt_tokens - cached_tokens`
|
||||
- `cache_write_tokens = 0` unless OpenAI exposes it in the relevant route
|
||||
- `output_tokens = completion_tokens`
|
||||
|
||||
### Anthropic Direct
|
||||
|
||||
Source usage fields:
|
||||
|
||||
- `input_tokens`
|
||||
- `output_tokens`
|
||||
- `cache_read_input_tokens`
|
||||
- `cache_creation_input_tokens`
|
||||
|
||||
Normalization:
|
||||
|
||||
- `input_tokens = input_tokens`
|
||||
- `output_tokens = output_tokens`
|
||||
- `cache_read_tokens = cache_read_input_tokens`
|
||||
- `cache_write_tokens = cache_creation_input_tokens`
|
||||
|
||||
### OpenRouter
|
||||
|
||||
Estimate-time usage normalization should use the response usage payload with the same rules as the underlying provider when possible.
|
||||
|
||||
Reconciliation-time records should also store:
|
||||
|
||||
- OpenRouter generation id
|
||||
- native token fields when available
|
||||
- `total_cost`
|
||||
- `cache_discount`
|
||||
- `upstream_inference_cost`
|
||||
- `is_byok`
|
||||
|
||||
### Gemini / Vertex
|
||||
|
||||
Use official Gemini or Vertex usage fields where available.
|
||||
|
||||
If cached content tokens are exposed:
|
||||
|
||||
- map them to `cache_read_tokens`
|
||||
|
||||
If a route exposes no cache creation metric:
|
||||
|
||||
- store `cache_write_tokens = 0`
|
||||
- preserve the raw usage payload for later extension
|
||||
|
||||
### DeepSeek And Other Direct Providers
|
||||
|
||||
Normalize only the fields that are officially exposed.
|
||||
|
||||
If a provider does not expose cache buckets:
|
||||
|
||||
- do not infer them unless the provider explicitly documents how to derive them
|
||||
|
||||
### Subscription / Included-Cost Routes
|
||||
|
||||
These still use the canonical usage model.
|
||||
|
||||
Tokens are tracked normally. Cost depends on billing mode, not on whether usage exists.
|
||||
|
||||
## Billing Route Model
|
||||
|
||||
Hermes must stop keying pricing solely by `model`.
|
||||
|
||||
Introduce a billing route descriptor:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class BillingRoute:
|
||||
provider: str
|
||||
base_url: str | None
|
||||
model: str
|
||||
billing_mode: str
|
||||
organization_hint: str | None = None
|
||||
```
|
||||
|
||||
`billing_mode` values:
|
||||
|
||||
- `official_cost_api`
|
||||
- `official_generation_api`
|
||||
- `official_models_api`
|
||||
- `official_docs_snapshot`
|
||||
- `subscription_included`
|
||||
- `user_override`
|
||||
- `custom_contract`
|
||||
- `unknown`
|
||||
|
||||
Examples:
|
||||
|
||||
- OpenAI direct API with Costs API access: `official_cost_api`
|
||||
- Anthropic direct API with Usage & Cost API access: `official_cost_api`
|
||||
- OpenRouter request before reconciliation: `official_models_api`
|
||||
- OpenRouter request after generation lookup: `official_generation_api`
|
||||
- GitHub Copilot style subscription route: `subscription_included`
|
||||
- local OpenAI-compatible server: `unknown`
|
||||
- enterprise contract with configured rates: `custom_contract`
|
||||
|
||||
## Cost Status Model
|
||||
|
||||
Every displayed cost should have:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CostResult:
|
||||
amount_usd: Decimal | None
|
||||
status: Literal["actual", "estimated", "included", "unknown"]
|
||||
source: Literal[
|
||||
"provider_cost_api",
|
||||
"provider_generation_api",
|
||||
"provider_models_api",
|
||||
"official_docs_snapshot",
|
||||
"user_override",
|
||||
"custom_contract",
|
||||
"none",
|
||||
]
|
||||
label: str
|
||||
fetched_at: datetime | None
|
||||
pricing_version: str | None
|
||||
notes: list[str]
|
||||
```
|
||||
|
||||
Presentation rules:
|
||||
|
||||
- `actual`: show dollar amount as final
|
||||
- `estimated`: show dollar amount with estimate labeling
|
||||
- `included`: show `included` or `$0.00 (included)` depending on UX choice
|
||||
- `unknown`: show `n/a`
|
||||
|
||||
## Official Source Hierarchy
|
||||
|
||||
Resolve cost using this order:
|
||||
|
||||
1. Request-level or account-level official billed cost
|
||||
2. Official machine-readable model pricing
|
||||
3. Official docs snapshot
|
||||
4. User override or custom contract
|
||||
5. Unknown
|
||||
|
||||
The system must never skip to a lower level if a higher-confidence source exists for the current billing route.
|
||||
|
||||
## Provider-Specific Truth Rules
|
||||
|
||||
### OpenAI Direct
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. Costs API for reconciled spend
|
||||
2. Official pricing page for live estimate
|
||||
|
||||
### Anthropic Direct
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. Usage & Cost API for reconciled spend
|
||||
2. Official pricing docs for live estimate
|
||||
|
||||
### OpenRouter
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. `GET /api/v1/generation` for reconciled `total_cost`
|
||||
2. `GET /api/v1/models` pricing for live estimate
|
||||
|
||||
Do not use underlying provider public pricing as the source of truth for OpenRouter billing.
|
||||
|
||||
### Gemini / Vertex
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. official billing export or billing API for reconciled spend when available for the route
|
||||
2. official pricing docs for estimate
|
||||
|
||||
### DeepSeek
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. official machine-readable cost source if available in the future
|
||||
2. official pricing docs snapshot today
|
||||
|
||||
### Subscription-Included Routes
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. explicit route config marking the model as included in subscription
|
||||
|
||||
These should display `included`, not an API list-price estimate.
|
||||
|
||||
### Custom Endpoint / Local Model
|
||||
|
||||
Preferred truth:
|
||||
|
||||
1. user override
|
||||
2. custom contract config
|
||||
3. unknown
|
||||
|
||||
These should default to `unknown`.
|
||||
|
||||
## Pricing Catalog
|
||||
|
||||
Replace the current `MODEL_PRICING` dict with a richer pricing catalog.
|
||||
|
||||
Suggested record:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PricingEntry:
|
||||
provider: str
|
||||
route_pattern: str
|
||||
model_pattern: str
|
||||
|
||||
input_cost_per_million: Decimal | None = None
|
||||
output_cost_per_million: Decimal | None = None
|
||||
cache_read_cost_per_million: Decimal | None = None
|
||||
cache_write_cost_per_million: Decimal | None = None
|
||||
request_cost: Decimal | None = None
|
||||
image_cost: Decimal | None = None
|
||||
|
||||
source: str = "official_docs_snapshot"
|
||||
source_url: str | None = None
|
||||
fetched_at: datetime | None = None
|
||||
pricing_version: str | None = None
|
||||
```
|
||||
|
||||
The catalog should be route-aware:
|
||||
|
||||
- `openai:gpt-5`
|
||||
- `anthropic:claude-opus-4-6`
|
||||
- `openrouter:anthropic/claude-opus-4.6`
|
||||
- `copilot:gpt-4o`
|
||||
|
||||
This avoids conflating direct-provider billing with aggregator billing.
|
||||
|
||||
## Pricing Sync Architecture
|
||||
|
||||
Introduce a pricing sync subsystem instead of manually maintaining a single hardcoded table.
|
||||
|
||||
Suggested modules:
|
||||
|
||||
- `agent/pricing/catalog.py`
|
||||
- `agent/pricing/sources.py`
|
||||
- `agent/pricing/sync.py`
|
||||
- `agent/pricing/reconcile.py`
|
||||
- `agent/pricing/types.py`
|
||||
|
||||
### Sync Sources
|
||||
|
||||
- OpenRouter models API
|
||||
- official provider docs snapshots where no API exists
|
||||
- user overrides from config
|
||||
|
||||
### Sync Output
|
||||
|
||||
Cache pricing entries locally with:
|
||||
|
||||
- source URL
|
||||
- fetch timestamp
|
||||
- version/hash
|
||||
- confidence/source type
|
||||
|
||||
### Sync Frequency
|
||||
|
||||
- startup warm cache
|
||||
- background refresh every 6 to 24 hours depending on source
|
||||
- manual `hermes pricing sync`
|
||||
|
||||
## Reconciliation Architecture
|
||||
|
||||
Live requests may produce only an estimate initially. Hermes should reconcile them later when a provider exposes actual billed cost.
|
||||
|
||||
Suggested flow:
|
||||
|
||||
1. Agent call completes.
|
||||
2. Hermes stores canonical usage plus reconciliation ids.
|
||||
3. Hermes computes an immediate estimate if a pricing source exists.
|
||||
4. A reconciliation worker fetches actual cost when supported.
|
||||
5. Session and message records are updated with `actual` cost.
|
||||
|
||||
This can run:
|
||||
|
||||
- inline for cheap lookups
|
||||
- asynchronously for delayed provider accounting
|
||||
|
||||
## Persistence Changes
|
||||
|
||||
Session storage should stop storing only aggregate prompt/completion totals.
|
||||
|
||||
Add fields for both usage and cost certainty:
|
||||
|
||||
- `input_tokens`
|
||||
- `output_tokens`
|
||||
- `cache_read_tokens`
|
||||
- `cache_write_tokens`
|
||||
- `reasoning_tokens`
|
||||
- `estimated_cost_usd`
|
||||
- `actual_cost_usd`
|
||||
- `cost_status`
|
||||
- `cost_source`
|
||||
- `pricing_version`
|
||||
- `billing_provider`
|
||||
- `billing_mode`
|
||||
|
||||
If schema expansion is too large for one PR, add a new pricing events table:
|
||||
|
||||
```text
|
||||
session_cost_events
|
||||
id
|
||||
session_id
|
||||
request_id
|
||||
provider
|
||||
model
|
||||
billing_mode
|
||||
input_tokens
|
||||
output_tokens
|
||||
cache_read_tokens
|
||||
cache_write_tokens
|
||||
estimated_cost_usd
|
||||
actual_cost_usd
|
||||
cost_status
|
||||
cost_source
|
||||
pricing_version
|
||||
created_at
|
||||
updated_at
|
||||
```
|
||||
|
||||
## Hermes Touchpoints
|
||||
|
||||
### `run_agent.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- parse raw provider usage
|
||||
- update session token counters
|
||||
|
||||
New responsibility:
|
||||
|
||||
- build `CanonicalUsage`
|
||||
- update canonical counters
|
||||
- store reconciliation ids
|
||||
- emit usage event to pricing subsystem
|
||||
|
||||
### `agent/usage_pricing.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- static lookup table
|
||||
- direct cost arithmetic
|
||||
|
||||
New responsibility:
|
||||
|
||||
- move or replace with pricing catalog facade
|
||||
- no fuzzy model-family heuristics
|
||||
- no direct pricing without billing-route context
|
||||
|
||||
### `cli.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- compute session cost directly from prompt/completion totals
|
||||
|
||||
New responsibility:
|
||||
|
||||
- display `CostResult`
|
||||
- show status badges:
|
||||
- `actual`
|
||||
- `estimated`
|
||||
- `included`
|
||||
- `n/a`
|
||||
|
||||
### `agent/insights.py`
|
||||
|
||||
Current responsibility:
|
||||
|
||||
- recompute historical estimates from static pricing
|
||||
|
||||
New responsibility:
|
||||
|
||||
- aggregate stored pricing events
|
||||
- prefer actual cost over estimate
|
||||
- surface estimates only when reconciliation is unavailable
|
||||
|
||||
## UX Rules
|
||||
|
||||
### Status Bar
|
||||
|
||||
Show one of:
|
||||
|
||||
- `$1.42`
|
||||
- `~$1.42`
|
||||
- `included`
|
||||
- `cost n/a`
|
||||
|
||||
Where:
|
||||
|
||||
- `$1.42` means `actual`
|
||||
- `~$1.42` means `estimated`
|
||||
- `included` means subscription-backed or explicitly zero-cost route
|
||||
- `cost n/a` means unknown
|
||||
|
||||
### `/usage`
|
||||
|
||||
Show:
|
||||
|
||||
- token buckets
|
||||
- estimated cost
|
||||
- actual cost if available
|
||||
- cost status
|
||||
- pricing source
|
||||
|
||||
### `/insights`
|
||||
|
||||
Aggregate:
|
||||
|
||||
- actual cost totals
|
||||
- estimated-only totals
|
||||
- unknown-cost sessions count
|
||||
- included-cost sessions count
|
||||
|
||||
## Config And Overrides
|
||||
|
||||
Add user-configurable pricing overrides in config:
|
||||
|
||||
```yaml
|
||||
pricing:
|
||||
mode: hybrid
|
||||
sync_on_startup: true
|
||||
sync_interval_hours: 12
|
||||
overrides:
|
||||
- provider: openrouter
|
||||
model: anthropic/claude-opus-4.6
|
||||
billing_mode: custom_contract
|
||||
input_cost_per_million: 4.25
|
||||
output_cost_per_million: 22.0
|
||||
cache_read_cost_per_million: 0.5
|
||||
cache_write_cost_per_million: 6.0
|
||||
included_routes:
|
||||
- provider: copilot
|
||||
model: "*"
|
||||
- provider: codex-subscription
|
||||
model: "*"
|
||||
```
|
||||
|
||||
Overrides must win over catalog defaults for the matching billing route.
|
||||
|
||||
## Rollout Plan
|
||||
|
||||
### Phase 1
|
||||
|
||||
- add canonical usage model
|
||||
- split cache token buckets in `run_agent.py`
|
||||
- stop pricing cache-inflated prompt totals
|
||||
- preserve current UI with improved backend math
|
||||
|
||||
### Phase 2
|
||||
|
||||
- add route-aware pricing catalog
|
||||
- integrate OpenRouter models API sync
|
||||
- add `estimated` vs `included` vs `unknown`
|
||||
|
||||
### Phase 3
|
||||
|
||||
- add reconciliation for OpenRouter generation cost
|
||||
- add actual cost persistence
|
||||
- update `/insights` to prefer actual cost
|
||||
|
||||
### Phase 4
|
||||
|
||||
- add direct OpenAI and Anthropic reconciliation paths
|
||||
- add user overrides and contract pricing
|
||||
- add pricing sync CLI command
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
Add tests for:
|
||||
|
||||
- OpenAI cached token subtraction
|
||||
- Anthropic cache read/write separation
|
||||
- OpenRouter estimated vs actual reconciliation
|
||||
- subscription-backed models showing `included`
|
||||
- custom endpoints showing `n/a`
|
||||
- override precedence
|
||||
- stale catalog fallback behavior
|
||||
|
||||
Current tests that assume heuristic pricing should be replaced with route-aware expectations.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- exact enterprise billing reconstruction without an official source or user override
|
||||
- backfilling perfect historical cost for old sessions that lack cache bucket data
|
||||
- scraping arbitrary provider web pages at request time
|
||||
|
||||
## Recommendation
|
||||
|
||||
Do not expand the existing `MODEL_PRICING` dict.
|
||||
|
||||
That path cannot satisfy the product requirement. Hermes should instead migrate to:
|
||||
|
||||
- canonical usage normalization
|
||||
- route-aware pricing sources
|
||||
- estimate-then-reconcile cost lifecycle
|
||||
- explicit certainty states in the UI
|
||||
|
||||
This is the minimum architecture that makes the statement "Hermes pricing is backed by official sources where possible, and otherwise clearly labeled" defensible.
|
||||
@@ -60,7 +60,7 @@ def check_dingtalk_requirements() -> bool:
|
||||
"""Check if DingTalk dependencies are available and configured."""
|
||||
if not DINGTALK_STREAM_AVAILABLE or not HTTPX_AVAILABLE:
|
||||
return False
|
||||
if not os.getenv("DINGTALK_CLIENT_ID") or not os.getenv("DINGTALK_CLIENT_SECRET"):
|
||||
if not os.getenv("DINGTALK_CLIENT_ID") and not os.getenv("DINGTALK_CLIENT_SECRET"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -220,7 +220,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start the sync loop.
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
||||
@@ -222,7 +222,6 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
|
||||
# Start WebSocket in background.
|
||||
self._ws_task = asyncio.create_task(self._ws_loop())
|
||||
self._mark_connected()
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
||||
@@ -984,16 +984,6 @@ class GatewayRunner:
|
||||
):
|
||||
self._schedule_update_notification_watch()
|
||||
|
||||
# Drain any recovered process watchers (from crash recovery checkpoint)
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
while process_registry.pending_watchers:
|
||||
watcher = process_registry.pending_watchers.pop(0)
|
||||
asyncio.create_task(self._run_process_watcher(watcher))
|
||||
logger.info("Resumed watcher for recovered process %s", watcher.get("session_id"))
|
||||
except Exception as e:
|
||||
logger.error("Recovered watcher setup error: %s", e)
|
||||
|
||||
# Start background session expiry watcher for proactive memory flushing
|
||||
asyncio.create_task(self._session_expiry_watcher())
|
||||
|
||||
@@ -1491,7 +1481,7 @@ class GatewayRunner:
|
||||
if cmd_key in skill_cmds:
|
||||
user_instruction = event.get_command_args().strip()
|
||||
msg = build_skill_invocation_message(
|
||||
cmd_key, user_instruction, task_id=_quick_key
|
||||
cmd_key, user_instruction, task_id=session_key
|
||||
)
|
||||
if msg:
|
||||
event.text = msg
|
||||
@@ -1552,9 +1542,8 @@ class GatewayRunner:
|
||||
# Read privacy.redact_pii from config (re-read per message)
|
||||
_redact_pii = False
|
||||
try:
|
||||
import yaml as _pii_yaml
|
||||
with open(_config_path, encoding="utf-8") as _pf:
|
||||
_pcfg = _pii_yaml.safe_load(_pf) or {}
|
||||
_pcfg = yaml.safe_load(_pf) or {}
|
||||
_redact_pii = bool((_pcfg.get("privacy") or {}).get("redact_pii", False))
|
||||
except Exception:
|
||||
pass
|
||||
@@ -2100,15 +2089,8 @@ class GatewayRunner:
|
||||
session_entry.session_key,
|
||||
input_tokens=agent_result.get("input_tokens", 0),
|
||||
output_tokens=agent_result.get("output_tokens", 0),
|
||||
cache_read_tokens=agent_result.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=agent_result.get("cache_write_tokens", 0),
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
model=agent_result.get("model"),
|
||||
estimated_cost_usd=agent_result.get("estimated_cost_usd"),
|
||||
cost_status=agent_result.get("cost_status"),
|
||||
cost_source=agent_result.get("cost_source"),
|
||||
provider=agent_result.get("provider"),
|
||||
base_url=agent_result.get("base_url"),
|
||||
)
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
|
||||
@@ -343,11 +343,7 @@ class SessionEntry:
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
cost_status: str = "unknown"
|
||||
|
||||
# Last API-reported prompt tokens (for accurate compression pre-check)
|
||||
last_prompt_tokens: int = 0
|
||||
@@ -367,12 +363,8 @@ class SessionEntry:
|
||||
"chat_type": self.chat_type,
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"cache_write_tokens": self.cache_write_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"last_prompt_tokens": self.last_prompt_tokens,
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"cost_status": self.cost_status,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
@@ -402,12 +394,8 @@ class SessionEntry:
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
cache_read_tokens=data.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=data.get("cache_write_tokens", 0),
|
||||
total_tokens=data.get("total_tokens", 0),
|
||||
last_prompt_tokens=data.get("last_prompt_tokens", 0),
|
||||
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
|
||||
cost_status=data.get("cost_status", "unknown"),
|
||||
)
|
||||
|
||||
|
||||
@@ -708,15 +696,8 @@ class SessionStore:
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
last_prompt_tokens: int = None,
|
||||
model: str = None,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
@@ -726,35 +707,15 @@ class SessionStore:
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
entry.input_tokens
|
||||
+ entry.output_tokens
|
||||
+ entry.cache_read_tokens
|
||||
+ entry.cache_write_tokens
|
||||
)
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
if self._db:
|
||||
try:
|
||||
self._db.update_token_counts(
|
||||
entry.session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
estimated_cost_usd=estimated_cost_usd,
|
||||
cost_status=cost_status,
|
||||
cost_source=cost_source,
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
entry.session_id, input_tokens, output_tokens,
|
||||
model=model,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -550,14 +550,6 @@ OPTIONAL_ENV_VARS = {
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"PARALLEL_API_KEY": {
|
||||
"description": "Parallel API key for AI-native web search and extract",
|
||||
"prompt": "Parallel API key",
|
||||
"url": "https://parallel.ai/",
|
||||
"tools": ["web_search", "web_extract"],
|
||||
"password": True,
|
||||
"category": "tool",
|
||||
},
|
||||
"FIRECRAWL_API_KEY": {
|
||||
"description": "Firecrawl API key for web search and scraping",
|
||||
"prompt": "Firecrawl API key",
|
||||
@@ -1514,7 +1506,6 @@ def show_config():
|
||||
keys = [
|
||||
("OPENROUTER_API_KEY", "OpenRouter"),
|
||||
("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"),
|
||||
("PARALLEL_API_KEY", "Parallel"),
|
||||
("FIRECRAWL_API_KEY", "Firecrawl"),
|
||||
("BROWSERBASE_API_KEY", "Browserbase"),
|
||||
("BROWSER_USE_API_KEY", "Browser Use"),
|
||||
@@ -1664,7 +1655,7 @@ def set_config_value(key: str, value: str):
|
||||
# Check if it's an API key (goes to .env)
|
||||
api_keys = [
|
||||
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
|
||||
'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
|
||||
'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||
|
||||
@@ -473,7 +473,7 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
from hermes_cli.auth import fetch_nous_models, resolve_nous_runtime_credentials
|
||||
creds = resolve_nous_runtime_credentials()
|
||||
if creds:
|
||||
live = fetch_nous_models(api_key=creds.get("api_key", ""), inference_base_url=creds.get("base_url", ""))
|
||||
live = fetch_nous_models(creds.get("api_key", ""), creds.get("base_url", ""))
|
||||
if live:
|
||||
return live
|
||||
except Exception:
|
||||
|
||||
@@ -444,11 +444,11 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
else:
|
||||
tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY"))
|
||||
|
||||
# Web tools (Parallel or Firecrawl)
|
||||
if get_env_value("PARALLEL_API_KEY") or get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"):
|
||||
# Firecrawl (web tools)
|
||||
if get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"):
|
||||
tool_status.append(("Web Search & Extract", True, None))
|
||||
else:
|
||||
tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY or FIRECRAWL_API_KEY"))
|
||||
tool_status.append(("Web Search & Extract", False, "FIRECRAWL_API_KEY"))
|
||||
|
||||
# Browser tools (local Chromium or Browserbase cloud)
|
||||
import shutil
|
||||
|
||||
@@ -151,29 +151,19 @@ TOOL_CATEGORIES = {
|
||||
"web": {
|
||||
"name": "Web Search & Extract",
|
||||
"setup_title": "Select Search Provider",
|
||||
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need a premium provider.",
|
||||
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need Firecrawl.",
|
||||
"icon": "🔍",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Firecrawl Cloud",
|
||||
"tag": "Hosted service - search, extract, and crawl",
|
||||
"web_backend": "firecrawl",
|
||||
"tag": "Recommended - hosted service",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Parallel",
|
||||
"tag": "AI-native search and extract",
|
||||
"web_backend": "parallel",
|
||||
"env_vars": [
|
||||
{"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Self-Hosted",
|
||||
"tag": "Free - run your own instance",
|
||||
"web_backend": "firecrawl",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
|
||||
],
|
||||
@@ -628,9 +618,6 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
|
||||
if "browser_provider" in provider:
|
||||
current = config.get("browser", {}).get("cloud_provider")
|
||||
return provider["browser_provider"] == current
|
||||
if provider.get("web_backend"):
|
||||
current = config.get("web", {}).get("backend")
|
||||
return current == provider["web_backend"]
|
||||
return False
|
||||
|
||||
|
||||
@@ -663,11 +650,6 @@ def _configure_provider(provider: dict, config: dict):
|
||||
else:
|
||||
config.get("browser", {}).pop("cloud_provider", None)
|
||||
|
||||
# Set web search backend in config if applicable
|
||||
if provider.get("web_backend"):
|
||||
config.setdefault("web", {})["backend"] = provider["web_backend"]
|
||||
_print_success(f" Web backend set to: {provider['web_backend']}")
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
@@ -1003,19 +985,12 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
if len(platform_keys) > 1:
|
||||
platform_choices.append("Configure all platforms (global)")
|
||||
platform_choices.append("Reconfigure an existing tool's provider or API key")
|
||||
|
||||
# Show MCP option if any MCP servers are configured
|
||||
_has_mcp = bool(config.get("mcp_servers"))
|
||||
if _has_mcp:
|
||||
platform_choices.append("Configure MCP server tools")
|
||||
|
||||
platform_choices.append("Done")
|
||||
|
||||
# Index offsets for the extra options after per-platform entries
|
||||
_global_idx = len(platform_keys) if len(platform_keys) > 1 else -1
|
||||
_reconfig_idx = len(platform_keys) + (1 if len(platform_keys) > 1 else 0)
|
||||
_mcp_idx = (_reconfig_idx + 1) if _has_mcp else -1
|
||||
_done_idx = _reconfig_idx + (2 if _has_mcp else 1)
|
||||
_done_idx = _reconfig_idx + 1
|
||||
|
||||
while True:
|
||||
idx = _prompt_choice("Select an option:", platform_choices, default=0)
|
||||
@@ -1030,12 +1005,6 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
continue
|
||||
|
||||
# "Configure MCP tools" selected
|
||||
if idx == _mcp_idx:
|
||||
_configure_mcp_tools_interactive(config)
|
||||
print()
|
||||
continue
|
||||
|
||||
# "Configure all platforms (global)" selected
|
||||
if idx == _global_idx:
|
||||
# Use the union of all platforms' current tools as the starting state
|
||||
@@ -1122,137 +1091,6 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
|
||||
|
||||
# ─── MCP Tools Interactive Configuration ─────────────────────────────────────
|
||||
|
||||
|
||||
def _configure_mcp_tools_interactive(config: dict):
|
||||
"""Probe MCP servers for available tools and let user toggle them on/off.
|
||||
|
||||
Connects to each configured MCP server, discovers tools, then shows
|
||||
a per-server curses checklist. Writes changes back as ``tools.exclude``
|
||||
entries in config.yaml.
|
||||
"""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
mcp_servers = config.get("mcp_servers") or {}
|
||||
if not mcp_servers:
|
||||
_print_info("No MCP servers configured.")
|
||||
return
|
||||
|
||||
# Count enabled servers
|
||||
enabled_names = [
|
||||
k for k, v in mcp_servers.items()
|
||||
if v.get("enabled", True) not in (False, "false", "0", "no", "off")
|
||||
]
|
||||
if not enabled_names:
|
||||
_print_info("All MCP servers are disabled.")
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" Discovering tools from MCP servers...", Colors.YELLOW))
|
||||
print(color(f" Connecting to {len(enabled_names)} server(s): {', '.join(enabled_names)}", Colors.DIM))
|
||||
|
||||
try:
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
server_tools = probe_mcp_server_tools()
|
||||
except Exception as exc:
|
||||
_print_error(f"Failed to probe MCP servers: {exc}")
|
||||
return
|
||||
|
||||
if not server_tools:
|
||||
_print_warning("Could not discover tools from any MCP server.")
|
||||
_print_info("Check that server commands/URLs are correct and dependencies are installed.")
|
||||
return
|
||||
|
||||
# Report discovery results
|
||||
failed = [n for n in enabled_names if n not in server_tools]
|
||||
if failed:
|
||||
for name in failed:
|
||||
_print_warning(f" Could not connect to '{name}'")
|
||||
|
||||
total_tools = sum(len(tools) for tools in server_tools.values())
|
||||
print(color(f" Found {total_tools} tool(s) across {len(server_tools)} server(s)", Colors.GREEN))
|
||||
print()
|
||||
|
||||
any_changes = False
|
||||
|
||||
for server_name, tools in server_tools.items():
|
||||
if not tools:
|
||||
_print_info(f" {server_name}: no tools found")
|
||||
continue
|
||||
|
||||
srv_cfg = mcp_servers.get(server_name, {})
|
||||
tools_cfg = srv_cfg.get("tools") or {}
|
||||
include_list = tools_cfg.get("include") or []
|
||||
exclude_list = tools_cfg.get("exclude") or []
|
||||
|
||||
# Build checklist labels
|
||||
labels = []
|
||||
for tool_name, description in tools:
|
||||
desc_short = description[:70] + "..." if len(description) > 70 else description
|
||||
if desc_short:
|
||||
labels.append(f"{tool_name} ({desc_short})")
|
||||
else:
|
||||
labels.append(tool_name)
|
||||
|
||||
# Determine which tools are currently enabled
|
||||
pre_selected: Set[int] = set()
|
||||
tool_names = [t[0] for t in tools]
|
||||
for i, tool_name in enumerate(tool_names):
|
||||
if include_list:
|
||||
# Include mode: only included tools are selected
|
||||
if tool_name in include_list:
|
||||
pre_selected.add(i)
|
||||
elif exclude_list:
|
||||
# Exclude mode: everything except excluded
|
||||
if tool_name not in exclude_list:
|
||||
pre_selected.add(i)
|
||||
else:
|
||||
# No filter: all enabled
|
||||
pre_selected.add(i)
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"MCP Server: {server_name} ({len(tools)} tools)",
|
||||
labels,
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_print_info(f" {server_name}: no changes")
|
||||
continue
|
||||
|
||||
# Compute new exclude list based on unchecked tools
|
||||
new_exclude = [tool_names[i] for i in range(len(tool_names)) if i not in chosen]
|
||||
|
||||
# Update config
|
||||
srv_cfg = mcp_servers.setdefault(server_name, {})
|
||||
tools_cfg = srv_cfg.setdefault("tools", {})
|
||||
|
||||
if new_exclude:
|
||||
tools_cfg["exclude"] = new_exclude
|
||||
# Remove include if present — we're switching to exclude mode
|
||||
tools_cfg.pop("include", None)
|
||||
else:
|
||||
# All tools enabled — clear filters
|
||||
tools_cfg.pop("exclude", None)
|
||||
tools_cfg.pop("include", None)
|
||||
|
||||
enabled_count = len(chosen)
|
||||
disabled_count = len(tools) - enabled_count
|
||||
_print_success(
|
||||
f" {server_name}: {enabled_count} enabled, {disabled_count} disabled"
|
||||
)
|
||||
any_changes = True
|
||||
|
||||
if any_changes:
|
||||
save_config(config)
|
||||
print()
|
||||
print(color(" ✓ MCP tool configuration saved", Colors.GREEN))
|
||||
else:
|
||||
print(color(" No changes to MCP tools", Colors.DIM))
|
||||
|
||||
|
||||
# ─── Non-interactive disable/enable ──────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
174
hermes_state.py
174
hermes_state.py
@@ -26,7 +26,7 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 5
|
||||
SCHEMA_VERSION = 4
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
@@ -48,17 +48,6 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
billing_provider TEXT,
|
||||
billing_base_url TEXT,
|
||||
billing_mode TEXT,
|
||||
estimated_cost_usd REAL,
|
||||
actual_cost_usd REAL,
|
||||
cost_status TEXT,
|
||||
cost_source TEXT,
|
||||
pricing_version TEXT,
|
||||
title TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
@@ -165,26 +154,6 @@ class SessionDB:
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 4")
|
||||
if current_version < 5:
|
||||
new_columns = [
|
||||
("cache_read_tokens", "INTEGER DEFAULT 0"),
|
||||
("cache_write_tokens", "INTEGER DEFAULT 0"),
|
||||
("reasoning_tokens", "INTEGER DEFAULT 0"),
|
||||
("billing_provider", "TEXT"),
|
||||
("billing_base_url", "TEXT"),
|
||||
("billing_mode", "TEXT"),
|
||||
("estimated_cost_usd", "REAL"),
|
||||
("actual_cost_usd", "REAL"),
|
||||
("cost_status", "TEXT"),
|
||||
("cost_source", "TEXT"),
|
||||
("pricing_version", "TEXT"),
|
||||
]
|
||||
for name, column_type in new_columns:
|
||||
try:
|
||||
cursor.execute(f"ALTER TABLE sessions ADD COLUMN {name} {column_type}")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
cursor.execute("UPDATE schema_version SET version = 5")
|
||||
|
||||
# Unique title index — always ensure it exists (safe to run after migrations
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
@@ -264,22 +233,8 @@ class SessionDB:
|
||||
self._conn.commit()
|
||||
|
||||
def update_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Increment token counters and backfill model if not already set."""
|
||||
with self._lock:
|
||||
@@ -287,40 +242,9 @@ class SessionDB:
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
cache_write_tokens = cache_write_tokens + ?,
|
||||
reasoning_tokens = reasoning_tokens + ?,
|
||||
estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0),
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE COALESCE(actual_cost_usd, 0) + ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
),
|
||||
(input_tokens, output_tokens, model, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@@ -809,18 +733,17 @@ class SessionDB:
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions, optionally filtered by source."""
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(source, limit, offset),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
# =========================================================================
|
||||
# Utility
|
||||
@@ -872,28 +795,26 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
return True
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
return True
|
||||
|
||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""
|
||||
@@ -903,23 +824,22 @@ class SessionDB:
|
||||
import time as _time
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
|
||||
for sid in session_ids:
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
for sid in session_ids:
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
|
||||
self._conn.commit()
|
||||
self._conn.commit()
|
||||
return len(session_ids)
|
||||
|
||||
@@ -27,7 +27,6 @@ dependencies = [
|
||||
"prompt_toolkit",
|
||||
# Tools
|
||||
"firecrawl-py",
|
||||
"parallel-web>=0.4.2",
|
||||
"fal-client",
|
||||
# Text-to-speech (Edge TTS is free, no API key needed)
|
||||
"edge-tts",
|
||||
|
||||
@@ -18,7 +18,6 @@ PyJWT[crypto]
|
||||
|
||||
# Web tools
|
||||
firecrawl-py
|
||||
parallel-web>=0.4.2
|
||||
|
||||
# Image generation
|
||||
fal-client
|
||||
|
||||
140
run_agent.py
140
run_agent.py
@@ -86,7 +86,6 @@ from agent.model_metadata import (
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.prompt_caching import apply_anthropic_cache_control
|
||||
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt
|
||||
from agent.usage_pricing import estimate_usage_cost, normalize_usage
|
||||
from agent.display import (
|
||||
KawaiiSpinner, build_tool_preview as _build_tool_preview,
|
||||
get_cute_tool_message as _get_cute_tool_message_impl,
|
||||
@@ -392,15 +391,6 @@ class AIAgent:
|
||||
else:
|
||||
self.api_mode = "chat_completions"
|
||||
|
||||
# Pre-warm OpenRouter model metadata cache in a background thread.
|
||||
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking
|
||||
# HTTP request on the first API response when pricing is estimated.
|
||||
if self.provider == "openrouter" or "openrouter" in self.base_url.lower():
|
||||
threading.Thread(
|
||||
target=lambda: fetch_model_metadata(),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
@@ -467,8 +457,8 @@ class AIAgent:
|
||||
and Path(getattr(handler, "baseFilename", "")).resolve() == resolved_error_log_path
|
||||
for handler in root_logger.handlers
|
||||
)
|
||||
from agent.redact import RedactingFormatter
|
||||
if not has_errors_log_handler:
|
||||
from agent.redact import RedactingFormatter
|
||||
error_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
error_file_handler = RotatingFileHandler(
|
||||
error_log_path, maxBytes=2 * 1024 * 1024, backupCount=2,
|
||||
@@ -860,14 +850,6 @@ class AIAgent:
|
||||
self.session_completion_tokens = 0
|
||||
self.session_total_tokens = 0
|
||||
self.session_api_calls = 0
|
||||
self.session_input_tokens = 0
|
||||
self.session_output_tokens = 0
|
||||
self.session_cache_read_tokens = 0
|
||||
self.session_cache_write_tokens = 0
|
||||
self.session_reasoning_tokens = 0
|
||||
self.session_estimated_cost_usd = 0.0
|
||||
self.session_cost_status = "unknown"
|
||||
self.session_cost_source = "none"
|
||||
|
||||
if not self.quiet_mode:
|
||||
if compression_enabled:
|
||||
@@ -4884,7 +4866,6 @@ class AIAgent:
|
||||
codex_ack_continuations = 0
|
||||
length_continue_retries = 0
|
||||
truncated_response_prefix = ""
|
||||
compression_attempts = 0
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
@@ -5030,6 +5011,7 @@ class AIAgent:
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
compression_attempts = 0
|
||||
max_compression_attempts = 3
|
||||
codex_auth_retry_attempted = False
|
||||
anthropic_auth_retry_attempted = False
|
||||
@@ -5290,14 +5272,26 @@ class AIAgent:
|
||||
|
||||
# Track actual token usage from response for context management
|
||||
if hasattr(response, 'usage') and response.usage:
|
||||
canonical_usage = normalize_usage(
|
||||
response.usage,
|
||||
provider=self.provider,
|
||||
api_mode=self.api_mode,
|
||||
)
|
||||
prompt_tokens = canonical_usage.prompt_tokens
|
||||
completion_tokens = canonical_usage.output_tokens
|
||||
total_tokens = canonical_usage.total_tokens
|
||||
if self.api_mode in ("codex_responses", "anthropic_messages"):
|
||||
prompt_tokens = getattr(response.usage, 'input_tokens', 0) or 0
|
||||
if self.api_mode == "anthropic_messages":
|
||||
# Anthropic splits input into cache_read + cache_creation
|
||||
# + non-cached input_tokens. Without adding the cached
|
||||
# portions, the context bar shows only the tiny non-cached
|
||||
# portion (e.g. 3 tokens) instead of the real total (~18K).
|
||||
# Other providers (OpenAI/Codex) already include cached
|
||||
# tokens in their input_tokens/prompt_tokens field.
|
||||
prompt_tokens += getattr(response.usage, 'cache_read_input_tokens', 0) or 0
|
||||
prompt_tokens += getattr(response.usage, 'cache_creation_input_tokens', 0) or 0
|
||||
completion_tokens = getattr(response.usage, 'output_tokens', 0) or 0
|
||||
total_tokens = (
|
||||
getattr(response.usage, 'total_tokens', None)
|
||||
or (prompt_tokens + completion_tokens)
|
||||
)
|
||||
else:
|
||||
prompt_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0
|
||||
completion_tokens = getattr(response.usage, 'completion_tokens', 0) or 0
|
||||
total_tokens = getattr(response.usage, 'total_tokens', 0) or 0
|
||||
usage_dict = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
@@ -5316,22 +5310,6 @@ class AIAgent:
|
||||
self.session_completion_tokens += completion_tokens
|
||||
self.session_total_tokens += total_tokens
|
||||
self.session_api_calls += 1
|
||||
self.session_input_tokens += canonical_usage.input_tokens
|
||||
self.session_output_tokens += canonical_usage.output_tokens
|
||||
self.session_cache_read_tokens += canonical_usage.cache_read_tokens
|
||||
self.session_cache_write_tokens += canonical_usage.cache_write_tokens
|
||||
self.session_reasoning_tokens += canonical_usage.reasoning_tokens
|
||||
|
||||
cost_result = estimate_usage_cost(
|
||||
self.model,
|
||||
canonical_usage,
|
||||
provider=self.provider,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
if cost_result.amount_usd is not None:
|
||||
self.session_estimated_cost_usd += float(cost_result.amount_usd)
|
||||
self.session_cost_status = cost_result.status
|
||||
self.session_cost_source = cost_result.source
|
||||
|
||||
# Persist token counts to session DB for /insights.
|
||||
# Gateway sessions persist via session_store.update_session()
|
||||
@@ -5342,19 +5320,8 @@ class AIAgent:
|
||||
try:
|
||||
self._session_db.update_token_counts(
|
||||
self.session_id,
|
||||
input_tokens=canonical_usage.input_tokens,
|
||||
output_tokens=canonical_usage.output_tokens,
|
||||
cache_read_tokens=canonical_usage.cache_read_tokens,
|
||||
cache_write_tokens=canonical_usage.cache_write_tokens,
|
||||
reasoning_tokens=canonical_usage.reasoning_tokens,
|
||||
estimated_cost_usd=float(cost_result.amount_usd)
|
||||
if cost_result.amount_usd is not None else None,
|
||||
cost_status=cost_result.status,
|
||||
cost_source=cost_result.source,
|
||||
billing_provider=self.provider,
|
||||
billing_base_url=self.base_url,
|
||||
billing_mode="subscription_included"
|
||||
if cost_result.status == "included" else None,
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
model=self.model,
|
||||
)
|
||||
except Exception:
|
||||
@@ -5971,32 +5938,19 @@ class AIAgent:
|
||||
# Don't add anything to messages, just retry the API call
|
||||
continue
|
||||
else:
|
||||
# Instead of returning partial, inject tool error results so the model can recover.
|
||||
# Using tool results (not user messages) preserves role alternation.
|
||||
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery tool results for invalid JSON...")
|
||||
# Instead of returning partial, inject a helpful message and let model recover
|
||||
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...")
|
||||
self._invalid_json_retries = 0 # Reset for next attempt
|
||||
|
||||
# Append the assistant message with its (broken) tool_calls
|
||||
recovery_assistant = self._build_assistant_message(assistant_message, finish_reason)
|
||||
messages.append(recovery_assistant)
|
||||
|
||||
# Respond with tool error results for each tool call
|
||||
invalid_names = {name for name, _ in invalid_json_args}
|
||||
for tc in assistant_message.tool_calls:
|
||||
if tc.function.name in invalid_names:
|
||||
err = next(e for n, e in invalid_json_args if n == tc.function.name)
|
||||
tool_result = (
|
||||
f"Error: Invalid JSON arguments. {err}. "
|
||||
f"For tools with no required parameters, use an empty object: {{}}. "
|
||||
f"Please retry with valid JSON."
|
||||
)
|
||||
else:
|
||||
tool_result = "Skipped: other tool call in this response had invalid JSON."
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
})
|
||||
# Add a user message explaining the issue
|
||||
recovery_msg = (
|
||||
f"Your tool call to '{tool_name}' had invalid JSON arguments. "
|
||||
f"Error: {error_msg}. "
|
||||
f"For tools with no required parameters, use an empty object: {{}}. "
|
||||
f"Please either retry the tool call with valid JSON, or respond without using that tool."
|
||||
)
|
||||
recovery_dict = {"role": "user", "content": recovery_msg}
|
||||
messages.append(recovery_dict)
|
||||
continue
|
||||
|
||||
# Reset retry counter on successful JSON validation
|
||||
@@ -6182,8 +6136,6 @@ class AIAgent:
|
||||
|
||||
if truncated_response_prefix:
|
||||
final_response = truncated_response_prefix + final_response
|
||||
truncated_response_prefix = ""
|
||||
length_continue_retries = 0
|
||||
|
||||
# Strip <think> blocks from user-facing response (keep raw in messages for trajectory)
|
||||
final_response = self._strip_think_blocks(final_response).strip()
|
||||
@@ -6235,11 +6187,10 @@ class AIAgent:
|
||||
|
||||
if not pending_handled:
|
||||
# Error happened before tool processing (e.g. response parsing).
|
||||
# Choose role to avoid consecutive same-role messages.
|
||||
last_role = messages[-1].get("role") if messages else None
|
||||
err_role = "assistant" if last_role == "user" else "user"
|
||||
# Use a user-role message so the model can see what went wrong
|
||||
# without confusing the API with a fabricated assistant turn.
|
||||
sys_err_msg = {
|
||||
"role": err_role,
|
||||
"role": "user",
|
||||
"content": f"[System error during processing: {error_msg}]",
|
||||
}
|
||||
messages.append(sys_err_msg)
|
||||
@@ -6291,21 +6242,6 @@ class AIAgent:
|
||||
"partial": False, # True only when stopped due to invalid tool calls
|
||||
"interrupted": interrupted,
|
||||
"response_previewed": getattr(self, "_response_was_previewed", False),
|
||||
"model": self.model,
|
||||
"provider": self.provider,
|
||||
"base_url": self.base_url,
|
||||
"input_tokens": self.session_input_tokens,
|
||||
"output_tokens": self.session_output_tokens,
|
||||
"cache_read_tokens": self.session_cache_read_tokens,
|
||||
"cache_write_tokens": self.session_cache_write_tokens,
|
||||
"reasoning_tokens": self.session_reasoning_tokens,
|
||||
"prompt_tokens": self.session_prompt_tokens,
|
||||
"completion_tokens": self.session_completion_tokens,
|
||||
"total_tokens": self.session_total_tokens,
|
||||
"last_prompt_tokens": getattr(self.context_compressor, "last_prompt_tokens", 0) or 0,
|
||||
"estimated_cost_usd": self.session_estimated_cost_usd,
|
||||
"cost_status": self.session_cost_status,
|
||||
"cost_source": self.session_cost_source,
|
||||
}
|
||||
self._response_was_previewed = False
|
||||
|
||||
|
||||
@@ -112,14 +112,9 @@ class TestDefaultContextLengths:
|
||||
|
||||
def test_gpt4_models_128k(self):
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if "gpt-4" in key and "gpt-4.1" not in key:
|
||||
if "gpt-4" in key:
|
||||
assert value == 128000, f"{key} should be 128000"
|
||||
|
||||
def test_gpt41_models_1m(self):
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if "gpt-4.1" in key:
|
||||
assert value == 1047576, f"{key} should be 1047576"
|
||||
|
||||
def test_gemini_models_1m(self):
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
if "gemini" in key:
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.usage_pricing import (
|
||||
CanonicalUsage,
|
||||
estimate_usage_cost,
|
||||
get_pricing_entry,
|
||||
normalize_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_usage_anthropic_keeps_cache_buckets_separate():
|
||||
usage = SimpleNamespace(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_read_input_tokens=2000,
|
||||
cache_creation_input_tokens=400,
|
||||
)
|
||||
|
||||
normalized = normalize_usage(usage, provider="anthropic", api_mode="anthropic_messages")
|
||||
|
||||
assert normalized.input_tokens == 1000
|
||||
assert normalized.output_tokens == 500
|
||||
assert normalized.cache_read_tokens == 2000
|
||||
assert normalized.cache_write_tokens == 400
|
||||
assert normalized.prompt_tokens == 3400
|
||||
|
||||
|
||||
def test_normalize_usage_openai_subtracts_cached_prompt_tokens():
|
||||
usage = SimpleNamespace(
|
||||
prompt_tokens=3000,
|
||||
completion_tokens=700,
|
||||
prompt_tokens_details=SimpleNamespace(cached_tokens=1800),
|
||||
)
|
||||
|
||||
normalized = normalize_usage(usage, provider="openai", api_mode="chat_completions")
|
||||
|
||||
assert normalized.input_tokens == 1200
|
||||
assert normalized.cache_read_tokens == 1800
|
||||
assert normalized.output_tokens == 700
|
||||
|
||||
|
||||
def test_openrouter_models_api_pricing_is_converted_from_per_token_to_per_million(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"agent.usage_pricing.fetch_model_metadata",
|
||||
lambda: {
|
||||
"anthropic/claude-opus-4.6": {
|
||||
"pricing": {
|
||||
"prompt": "0.000005",
|
||||
"completion": "0.000025",
|
||||
"input_cache_read": "0.0000005",
|
||||
"input_cache_write": "0.00000625",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
entry = get_pricing_entry(
|
||||
"anthropic/claude-opus-4.6",
|
||||
provider="openrouter",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
|
||||
assert float(entry.input_cost_per_million) == 5.0
|
||||
assert float(entry.output_cost_per_million) == 25.0
|
||||
assert float(entry.cache_read_cost_per_million) == 0.5
|
||||
assert float(entry.cache_write_cost_per_million) == 6.25
|
||||
|
||||
|
||||
def test_estimate_usage_cost_marks_subscription_routes_included():
|
||||
result = estimate_usage_cost(
|
||||
"gpt-5.3-codex",
|
||||
CanonicalUsage(input_tokens=1000, output_tokens=500),
|
||||
provider="openai-codex",
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
)
|
||||
|
||||
assert result.status == "included"
|
||||
assert float(result.amount_usd) == 0.0
|
||||
|
||||
|
||||
def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"agent.usage_pricing.fetch_model_metadata",
|
||||
lambda: {
|
||||
"google/gemini-2.5-pro": {
|
||||
"pricing": {
|
||||
"prompt": "0.00000125",
|
||||
"completion": "0.00001",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = estimate_usage_cost(
|
||||
"google/gemini-2.5-pro",
|
||||
CanonicalUsage(input_tokens=1000, output_tokens=500, cache_read_tokens=100),
|
||||
provider="openrouter",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
|
||||
assert result.status == "unknown"
|
||||
@@ -50,16 +50,13 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||
return runner
|
||||
|
||||
|
||||
def _watcher_dict(session_id="proc_test", thread_id=""):
|
||||
d = {
|
||||
def _watcher_dict(session_id="proc_test"):
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"check_interval": 0,
|
||||
"platform": "telegram",
|
||||
"chat_id": "123",
|
||||
}
|
||||
if thread_id:
|
||||
d["thread_id"] = thread_id
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -199,47 +196,3 @@ async def test_run_process_watcher_respects_notification_mode(
|
||||
if expected_fragment is not None:
|
||||
sent_message = adapter.send.await_args.args[1]
|
||||
assert expected_fragment in sent_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
|
||||
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] == {"thread_id": "42"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||
"""When thread_id is empty, metadata should be None (general topic)."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict())
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] is None
|
||||
|
||||
@@ -703,15 +703,5 @@ class TestLastPromptTokens:
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.update_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
"s1", 0, 0, model="openai/gpt-5.4"
|
||||
)
|
||||
|
||||
@@ -128,13 +128,6 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
session_entry.session_key,
|
||||
input_tokens=120,
|
||||
output_tokens=45,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
last_prompt_tokens=80,
|
||||
model="openai/test-model",
|
||||
estimated_cost_usd=None,
|
||||
cost_status=None,
|
||||
cost_source=None,
|
||||
provider=None,
|
||||
base_url=None,
|
||||
)
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for MCP tools interactive configuration in hermes_cli.tools_config."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hermes_cli.tools_config import _configure_mcp_tools_interactive
|
||||
|
||||
# Patch targets: imports happen inside the function body, so patch at source
|
||||
_PROBE = "tools.mcp_tool.probe_mcp_server_tools"
|
||||
_CHECKLIST = "hermes_cli.curses_ui.curses_checklist"
|
||||
_SAVE = "hermes_cli.tools_config.save_config"
|
||||
|
||||
|
||||
def test_no_mcp_servers_prints_info(capsys):
|
||||
"""Returns immediately when no MCP servers are configured."""
|
||||
config = {}
|
||||
_configure_mcp_tools_interactive(config)
|
||||
captured = capsys.readouterr()
|
||||
assert "No MCP servers configured" in captured.out
|
||||
|
||||
|
||||
def test_all_servers_disabled_prints_info(capsys):
|
||||
"""Returns immediately when all configured servers have enabled=false."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx", "enabled": False},
|
||||
"slack": {"command": "npx", "enabled": "false"},
|
||||
}
|
||||
}
|
||||
_configure_mcp_tools_interactive(config)
|
||||
captured = capsys.readouterr()
|
||||
assert "disabled" in captured.out
|
||||
|
||||
|
||||
def test_probe_failure_shows_warning(capsys):
|
||||
"""Shows warning when probe returns no tools."""
|
||||
config = {"mcp_servers": {"github": {"command": "npx"}}}
|
||||
with patch(_PROBE, return_value={}):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
captured = capsys.readouterr()
|
||||
assert "Could not discover" in captured.out
|
||||
|
||||
|
||||
def test_probe_exception_shows_error(capsys):
|
||||
"""Shows error when probe raises an exception."""
|
||||
config = {"mcp_servers": {"github": {"command": "npx"}}}
|
||||
with patch(_PROBE, side_effect=RuntimeError("MCP not installed")):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
captured = capsys.readouterr()
|
||||
assert "Failed to probe" in captured.out
|
||||
|
||||
|
||||
def test_no_changes_when_checklist_cancelled(capsys):
|
||||
"""No config changes when user cancels (ESC) the checklist."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx", "args": ["-y", "server-github"]},
|
||||
}
|
||||
}
|
||||
tools = [("create_issue", "Create an issue"), ("search_repos", "Search repos")]
|
||||
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, return_value={0, 1}), \
|
||||
patch(_SAVE) as mock_save:
|
||||
_configure_mcp_tools_interactive(config)
|
||||
mock_save.assert_not_called()
|
||||
captured = capsys.readouterr()
|
||||
assert "no changes" in captured.out.lower()
|
||||
|
||||
|
||||
def test_disabling_tool_writes_exclude_list(capsys):
|
||||
"""Unchecking a tool adds it to the exclude list."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx"},
|
||||
}
|
||||
}
|
||||
tools = [
|
||||
("create_issue", "Create an issue"),
|
||||
("delete_repo", "Delete a repo"),
|
||||
("search_repos", "Search repos"),
|
||||
]
|
||||
|
||||
# User unchecks delete_repo (index 1)
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, return_value={0, 2}), \
|
||||
patch(_SAVE) as mock_save:
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
tools_cfg = config["mcp_servers"]["github"]["tools"]
|
||||
assert tools_cfg["exclude"] == ["delete_repo"]
|
||||
assert "include" not in tools_cfg
|
||||
|
||||
|
||||
def test_enabling_all_clears_filters(capsys):
|
||||
"""Checking all tools clears both include and exclude lists."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"tools": {"exclude": ["delete_repo"], "include": ["create_issue"]},
|
||||
},
|
||||
}
|
||||
}
|
||||
tools = [("create_issue", "Create"), ("delete_repo", "Delete")]
|
||||
|
||||
# User checks all tools — pre_selected would be {0} (include mode),
|
||||
# so returning {0, 1} is a change
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, return_value={0, 1}), \
|
||||
patch(_SAVE) as mock_save:
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
mock_save.assert_called_once()
|
||||
tools_cfg = config["mcp_servers"]["github"]["tools"]
|
||||
assert "exclude" not in tools_cfg
|
||||
assert "include" not in tools_cfg
|
||||
|
||||
|
||||
def test_pre_selection_respects_existing_exclude(capsys):
|
||||
"""Tools in exclude list start unchecked."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"tools": {"exclude": ["delete_repo"]},
|
||||
},
|
||||
}
|
||||
}
|
||||
tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")]
|
||||
captured_pre_selected = {}
|
||||
|
||||
def fake_checklist(title, labels, pre_selected, **kwargs):
|
||||
captured_pre_selected["value"] = set(pre_selected)
|
||||
return pre_selected # No changes
|
||||
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, side_effect=fake_checklist), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
# create_issue (0) and search (2) should be pre-selected, delete_repo (1) should not
|
||||
assert captured_pre_selected["value"] == {0, 2}
|
||||
|
||||
|
||||
def test_pre_selection_respects_existing_include(capsys):
|
||||
"""Only tools in include list start checked."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"tools": {"include": ["search"]},
|
||||
},
|
||||
}
|
||||
}
|
||||
tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")]
|
||||
captured_pre_selected = {}
|
||||
|
||||
def fake_checklist(title, labels, pre_selected, **kwargs):
|
||||
captured_pre_selected["value"] = set(pre_selected)
|
||||
return pre_selected # No changes
|
||||
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, side_effect=fake_checklist), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
# Only search (2) should be pre-selected
|
||||
assert captured_pre_selected["value"] == {2}
|
||||
|
||||
|
||||
def test_multiple_servers_each_get_checklist(capsys):
|
||||
"""Each server gets its own checklist."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx"},
|
||||
"slack": {"url": "https://mcp.example.com"},
|
||||
}
|
||||
}
|
||||
checklist_calls = []
|
||||
|
||||
def fake_checklist(title, labels, pre_selected, **kwargs):
|
||||
checklist_calls.append(title)
|
||||
return pre_selected # No changes
|
||||
|
||||
with patch(
|
||||
_PROBE,
|
||||
return_value={
|
||||
"github": [("create_issue", "Create")],
|
||||
"slack": [("send_message", "Send")],
|
||||
},
|
||||
), patch(_CHECKLIST, side_effect=fake_checklist), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
assert len(checklist_calls) == 2
|
||||
assert any("github" in t for t in checklist_calls)
|
||||
assert any("slack" in t for t in checklist_calls)
|
||||
|
||||
|
||||
def test_failed_server_shows_warning(capsys):
|
||||
"""Servers that fail to connect show warnings."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx"},
|
||||
"broken": {"command": "nonexistent"},
|
||||
}
|
||||
}
|
||||
|
||||
# Only github succeeds
|
||||
with patch(
|
||||
_PROBE, return_value={"github": [("create_issue", "Create")]},
|
||||
), patch(_CHECKLIST, return_value={0}), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "broken" in captured.out
|
||||
|
||||
|
||||
def test_description_truncation_in_labels():
|
||||
"""Long descriptions are truncated in checklist labels."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {"command": "npx"},
|
||||
}
|
||||
}
|
||||
long_desc = "A" * 100
|
||||
captured_labels = {}
|
||||
|
||||
def fake_checklist(title, labels, pre_selected, **kwargs):
|
||||
captured_labels["value"] = labels
|
||||
return pre_selected
|
||||
|
||||
with patch(
|
||||
_PROBE, return_value={"github": [("my_tool", long_desc)]},
|
||||
), patch(_CHECKLIST, side_effect=fake_checklist), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
label = captured_labels["value"][0]
|
||||
assert "..." in label
|
||||
assert len(label) < len(long_desc) + 30 # truncated + tool name + parens
|
||||
|
||||
|
||||
def test_switching_from_include_to_exclude(capsys):
|
||||
"""When user modifies selection, include list is replaced by exclude list."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"tools": {"include": ["create_issue"]},
|
||||
},
|
||||
}
|
||||
}
|
||||
tools = [("create_issue", "Create"), ("search", "Search"), ("delete", "Delete")]
|
||||
|
||||
# User selects create_issue and search (deselects delete)
|
||||
# pre_selected would be {0} (only create_issue from include), so {0, 1} is a change
|
||||
with patch(_PROBE, return_value={"github": tools}), \
|
||||
patch(_CHECKLIST, return_value={0, 1}), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
tools_cfg = config["mcp_servers"]["github"]["tools"]
|
||||
assert tools_cfg["exclude"] == ["delete"]
|
||||
assert "include" not in tools_cfg
|
||||
|
||||
|
||||
def test_empty_tools_server_skipped(capsys):
|
||||
"""Server with no tools shows info message and skips checklist."""
|
||||
config = {
|
||||
"mcp_servers": {
|
||||
"empty": {"command": "npx"},
|
||||
}
|
||||
}
|
||||
checklist_calls = []
|
||||
|
||||
def fake_checklist(title, labels, pre_selected, **kwargs):
|
||||
checklist_calls.append(title)
|
||||
return pre_selected
|
||||
|
||||
with patch(_PROBE, return_value={"empty": []}), \
|
||||
patch(_CHECKLIST, side_effect=fake_checklist), \
|
||||
patch(_SAVE):
|
||||
_configure_mcp_tools_interactive(config)
|
||||
|
||||
assert len(checklist_calls) == 0
|
||||
captured = capsys.readouterr()
|
||||
assert "no tools found" in captured.out
|
||||
@@ -5,13 +5,6 @@ from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.setup import setup_model_provider
|
||||
|
||||
|
||||
def _maybe_keep_current_tts(question, choices):
|
||||
if question != "Select TTS provider:":
|
||||
return None
|
||||
assert choices[-1].startswith("Keep current (")
|
||||
return len(choices) - 1
|
||||
|
||||
|
||||
def _clear_provider_env(monkeypatch):
|
||||
for key in (
|
||||
"NOUS_API_KEY",
|
||||
@@ -32,22 +25,16 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 0
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
assert choices[-1] == "Keep current (anthropic/claude-opus-4.6)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
# Provider selection always comes first. Depending on available vision
|
||||
# backends, setup may either skip the optional vision step or prompt for
|
||||
# it before the default-model choice. Provide enough selections for both
|
||||
# paths while still ending on "keep current model".
|
||||
prompt_choices = iter([0, 2, 2])
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt_choice",
|
||||
lambda *args, **kwargs: next(prompt_choices),
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
|
||||
def _fake_login_nous(*args, **kwargs):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
@@ -87,29 +74,20 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 3
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 3)
|
||||
|
||||
prompt_values = iter(
|
||||
[
|
||||
"https://custom.example/v1",
|
||||
"custom-api-key",
|
||||
"custom/model",
|
||||
"",
|
||||
]
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt",
|
||||
lambda *args, **kwargs: next(prompt_values),
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
@@ -131,17 +109,11 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 1
|
||||
if question == "Select default model:":
|
||||
return 0
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
prompt_choices = iter([1, 0])
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt_choice",
|
||||
lambda *args, **kwargs: next(prompt_choices),
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None)
|
||||
|
||||
@@ -6,13 +6,6 @@ from hermes_cli.config import load_config, save_config, save_env_value
|
||||
from hermes_cli.setup import _print_setup_summary, setup_model_provider
|
||||
|
||||
|
||||
def _maybe_keep_current_tts(question, choices):
|
||||
if question != "Select TTS provider:":
|
||||
return None
|
||||
assert choices[-1].startswith("Keep current (")
|
||||
return len(choices) - 1
|
||||
|
||||
|
||||
def _read_env(home):
|
||||
env_path = home / ".env"
|
||||
data = {}
|
||||
@@ -57,13 +50,13 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError("Model menu should not appear for keep-current custom")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
@@ -79,6 +72,7 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["default"] == "custom/model"
|
||||
assert reloaded["model"]["base_url"] == "https://example.invalid/v1"
|
||||
assert calls["count"] == 1
|
||||
|
||||
|
||||
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
@@ -92,9 +86,6 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
return 3 # Custom endpoint
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1 # Skip
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt(message, current=None, **kwargs):
|
||||
@@ -149,23 +140,22 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
|
||||
save_config(config)
|
||||
|
||||
captured = {"provider_choices": None, "model_choices": None}
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 1:
|
||||
captured["provider_choices"] = list(choices)
|
||||
assert choices[-1] == "Keep current (Anthropic)"
|
||||
return len(choices) - 1
|
||||
if question == "Configure vision:":
|
||||
if calls["count"] == 2:
|
||||
assert question == "Configure vision:"
|
||||
assert choices[-1] == "Skip for now"
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
if calls["count"] == 3:
|
||||
captured["model_choices"] = list(choices)
|
||||
return len(choices) - 1 # keep current model
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
raise AssertionError("Unexpected extra prompt_choice call")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
@@ -182,6 +172,7 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
|
||||
assert captured["model_choices"] is not None
|
||||
assert captured["model_choices"][0] == "claude-opus-4-6"
|
||||
assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"]
|
||||
assert calls["count"] == 3
|
||||
|
||||
|
||||
def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch):
|
||||
@@ -195,24 +186,14 @@ def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_pa
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
assert choices[-1] == "Keep current (Anthropic)"
|
||||
return len(choices) - 1
|
||||
if question == "Configure vision:":
|
||||
return 1
|
||||
if question == "Select vision model:":
|
||||
assert choices[-1] == "Use default (gpt-4o-mini)"
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
assert choices[-1] == "Keep current (claude-opus-4-6)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
picks = iter([
|
||||
10, # keep current provider (shifted +1 by kilocode insertion)
|
||||
1, # configure vision with OpenAI
|
||||
5, # use default gpt-4o-mini vision model
|
||||
4, # keep current Anthropic model
|
||||
])
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt",
|
||||
lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "",
|
||||
@@ -248,17 +229,8 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 1
|
||||
if question == "Select default model:":
|
||||
return 0
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
picks = iter([1, 0])
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
|
||||
@@ -63,13 +63,11 @@ class TestFromEnv:
|
||||
|
||||
class TestFromGlobalConfig:
|
||||
def test_missing_config_falls_back_to_env(self, tmp_path):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
config = HonchoClientConfig.from_global_config(
|
||||
config_path=tmp_path / "nonexistent.json"
|
||||
)
|
||||
config = HonchoClientConfig.from_global_config(
|
||||
config_path=tmp_path / "nonexistent.json"
|
||||
)
|
||||
# Should fall back to from_env
|
||||
assert config.enabled is False
|
||||
assert config.api_key is None
|
||||
assert config.enabled is True or config.api_key is None # depends on env
|
||||
|
||||
def test_reads_full_config(self, tmp_path):
|
||||
config_file = tmp_path / "config.json"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Comprehensive Test Suite for Web Tools Module
|
||||
|
||||
This script tests all web tools functionality to ensure they work correctly.
|
||||
Run this after any updates to the web_tools.py module or backend libraries.
|
||||
Run this after any updates to the web_tools.py module or Firecrawl library.
|
||||
|
||||
Usage:
|
||||
python test_web_tools.py # Run all tests
|
||||
@@ -11,7 +11,7 @@ Usage:
|
||||
python test_web_tools.py --verbose # Show detailed output
|
||||
|
||||
Requirements:
|
||||
- PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set
|
||||
- FIRECRAWL_API_KEY environment variable must be set
|
||||
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
|
||||
"""
|
||||
|
||||
@@ -28,14 +28,12 @@ from typing import List
|
||||
|
||||
# Import the web tools to test (updated path after moving tools/)
|
||||
from tools.web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key,
|
||||
check_web_api_key,
|
||||
check_auxiliary_model,
|
||||
get_debug_session_info,
|
||||
_get_backend,
|
||||
get_debug_session_info
|
||||
)
|
||||
|
||||
|
||||
@@ -123,13 +121,12 @@ class WebToolsTester:
|
||||
"""Test environment setup and API keys"""
|
||||
print_section("Environment Check")
|
||||
|
||||
# Check web backend API key (Parallel or Firecrawl)
|
||||
if not check_web_api_key():
|
||||
self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set")
|
||||
# Check Firecrawl API key
|
||||
if not check_firecrawl_api_key():
|
||||
self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set")
|
||||
return False
|
||||
else:
|
||||
backend = _get_backend()
|
||||
self.log_result("Web Backend API Key", "passed", f"Using {backend} backend")
|
||||
self.log_result("Firecrawl API Key", "passed", "Found")
|
||||
|
||||
# Check auxiliary LLM provider (optional)
|
||||
if not check_auxiliary_model():
|
||||
@@ -581,9 +578,7 @@ class WebToolsTester:
|
||||
},
|
||||
"results": self.test_results,
|
||||
"environment": {
|
||||
"web_backend": _get_backend() if check_web_api_key() else None,
|
||||
"firecrawl_api_key": check_firecrawl_api_key(),
|
||||
"parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")),
|
||||
"auxiliary_model": check_auxiliary_model(),
|
||||
"debug_mode": get_debug_session_info()["enabled"]
|
||||
}
|
||||
|
||||
@@ -98,14 +98,11 @@ class TestProviderRegistry:
|
||||
# =============================================================================
|
||||
|
||||
PROVIDER_ENV_VARS = (
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
|
||||
"KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
|
||||
"AI_GATEWAY_API_KEY", "AI_GATEWAY_BASE_URL",
|
||||
"KILOCODE_API_KEY", "KILOCODE_BASE_URL",
|
||||
"DASHSCOPE_API_KEY", "OPENCODE_ZEN_API_KEY", "OPENCODE_GO_API_KEY",
|
||||
"NOUS_API_KEY",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
@@ -114,7 +111,6 @@ PROVIDER_ENV_VARS = (
|
||||
def _clear_provider_env(monkeypatch):
|
||||
for key in PROVIDER_ENV_VARS:
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
monkeypatch.setattr("hermes_cli.auth._load_auth_store", lambda: {})
|
||||
|
||||
|
||||
class TestResolveProvider:
|
||||
|
||||
@@ -16,10 +16,6 @@ def _make_cli(model: str = "anthropic/claude-sonnet-4-20250514"):
|
||||
def _attach_agent(
|
||||
cli_obj,
|
||||
*,
|
||||
input_tokens: int | None = None,
|
||||
output_tokens: int | None = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
@@ -30,12 +26,6 @@ def _attach_agent(
|
||||
):
|
||||
cli_obj.agent = SimpleNamespace(
|
||||
model=cli_obj.model,
|
||||
provider="anthropic" if cli_obj.model.startswith("anthropic/") else None,
|
||||
base_url="",
|
||||
session_input_tokens=input_tokens if input_tokens is not None else prompt_tokens,
|
||||
session_output_tokens=output_tokens if output_tokens is not None else completion_tokens,
|
||||
session_cache_read_tokens=cache_read_tokens,
|
||||
session_cache_write_tokens=cache_write_tokens,
|
||||
session_prompt_tokens=prompt_tokens,
|
||||
session_completion_tokens=completion_tokens,
|
||||
session_total_tokens=total_tokens,
|
||||
@@ -78,19 +68,20 @@ class TestCLIStatusBar:
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
|
||||
def test_build_status_bar_text_no_cost_in_status_bar(self):
|
||||
def test_build_status_bar_text_shows_cost_when_enabled(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10000,
|
||||
completion_tokens=5000,
|
||||
total_tokens=15000,
|
||||
completion_tokens=2400,
|
||||
total_tokens=12400,
|
||||
api_calls=7,
|
||||
context_tokens=50000,
|
||||
context_tokens=12400,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj.show_cost = True
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=120)
|
||||
assert "$" not in text # cost is never shown in status bar
|
||||
assert "$" in text # cost is shown when enabled
|
||||
|
||||
def test_build_status_bar_text_collapses_for_narrow_terminal(self):
|
||||
cli_obj = _attach_agent(
|
||||
@@ -137,8 +128,8 @@ class TestCLIUsageReport:
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Model:" in output
|
||||
assert "Cost status:" in output
|
||||
assert "Cost source:" in output
|
||||
assert "Input cost:" in output
|
||||
assert "Output cost:" in output
|
||||
assert "Total cost:" in output
|
||||
assert "$" in output
|
||||
assert "0.064" in output
|
||||
|
||||
@@ -657,7 +657,7 @@ class TestSchemaInit:
|
||||
def test_schema_version(self, db):
|
||||
cursor = db._conn.execute("SELECT version FROM schema_version")
|
||||
version = cursor.fetchone()[0]
|
||||
assert version == 5
|
||||
assert version == 4
|
||||
|
||||
def test_title_column_exists(self, db):
|
||||
"""Verify the title column was created in the sessions table."""
|
||||
@@ -713,12 +713,12 @@ class TestSchemaInit:
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Open with SessionDB — should migrate to v5
|
||||
# Open with SessionDB — should migrate to v4
|
||||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
# Verify migration
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 5
|
||||
assert cursor.fetchone()[0] == 4
|
||||
|
||||
# Verify title column exists and is NULL for existing sessions
|
||||
session = migrated_db.get_session("existing")
|
||||
|
||||
@@ -123,16 +123,28 @@ def populated_db(db):
|
||||
# =========================================================================
|
||||
|
||||
class TestPricing:
|
||||
def test_exact_match(self):
|
||||
pricing = _get_pricing("gpt-4o")
|
||||
assert pricing["input"] == 2.50
|
||||
assert pricing["output"] == 10.00
|
||||
|
||||
def test_provider_prefix_stripped(self):
|
||||
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
|
||||
assert pricing["input"] == 3.00
|
||||
assert pricing["output"] == 15.00
|
||||
|
||||
def test_unknown_models_do_not_use_heuristics(self):
|
||||
def test_prefix_match(self):
|
||||
pricing = _get_pricing("claude-3-5-sonnet-20241022")
|
||||
assert pricing["input"] == 3.00
|
||||
|
||||
def test_keyword_heuristic_opus(self):
|
||||
pricing = _get_pricing("some-new-opus-model")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
assert pricing["input"] == 15.00
|
||||
assert pricing["output"] == 75.00
|
||||
|
||||
def test_keyword_heuristic_haiku(self):
|
||||
pricing = _get_pricing("anthropic/claude-haiku-future")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
assert pricing["input"] == 0.80
|
||||
|
||||
def test_unknown_model_returns_zero_cost(self):
|
||||
"""Unknown/custom models should NOT have fabricated costs."""
|
||||
@@ -156,12 +168,40 @@ class TestPricing:
|
||||
pricing = _get_pricing("")
|
||||
assert pricing == _DEFAULT_PRICING
|
||||
|
||||
def test_deepseek_heuristic(self):
|
||||
pricing = _get_pricing("deepseek-v3")
|
||||
assert pricing["input"] == 0.14
|
||||
|
||||
def test_gemini_heuristic(self):
|
||||
pricing = _get_pricing("gemini-3.0-ultra")
|
||||
assert pricing["input"] == 0.15
|
||||
|
||||
def test_dated_model_gpt4o_mini(self):
|
||||
"""gpt-4o-mini-2024-07-18 should match gpt-4o-mini, NOT gpt-4o."""
|
||||
pricing = _get_pricing("gpt-4o-mini-2024-07-18")
|
||||
assert pricing["input"] == 0.15 # gpt-4o-mini price, not gpt-4o's 2.50
|
||||
|
||||
def test_dated_model_o3_mini(self):
|
||||
"""o3-mini-2025-01-31 should match o3-mini, NOT o3."""
|
||||
pricing = _get_pricing("o3-mini-2025-01-31")
|
||||
assert pricing["input"] == 1.10 # o3-mini price, not o3's 10.00
|
||||
|
||||
def test_dated_model_gpt41_mini(self):
|
||||
"""gpt-4.1-mini-2025-04-14 should match gpt-4.1-mini, NOT gpt-4.1."""
|
||||
pricing = _get_pricing("gpt-4.1-mini-2025-04-14")
|
||||
assert pricing["input"] == 0.40 # gpt-4.1-mini, not gpt-4.1's 2.00
|
||||
|
||||
def test_dated_model_gpt41_nano(self):
|
||||
"""gpt-4.1-nano-2025-04-14 should match gpt-4.1-nano, NOT gpt-4.1."""
|
||||
pricing = _get_pricing("gpt-4.1-nano-2025-04-14")
|
||||
assert pricing["input"] == 0.10 # gpt-4.1-nano, not gpt-4.1's 2.00
|
||||
|
||||
|
||||
class TestHasKnownPricing:
|
||||
def test_known_commercial_model(self):
|
||||
assert _has_known_pricing("gpt-4o", provider="openai") is True
|
||||
assert _has_known_pricing("gpt-4o") is True
|
||||
assert _has_known_pricing("anthropic/claude-sonnet-4-20250514") is True
|
||||
assert _has_known_pricing("gpt-4.1", provider="openai") is True
|
||||
assert _has_known_pricing("deepseek-chat") is True
|
||||
|
||||
def test_unknown_custom_model(self):
|
||||
assert _has_known_pricing("FP16_Hermes_4.5") is False
|
||||
@@ -170,39 +210,26 @@ class TestHasKnownPricing:
|
||||
assert _has_known_pricing("") is False
|
||||
assert _has_known_pricing(None) is False
|
||||
|
||||
def test_heuristic_matched_models_are_not_considered_known(self):
|
||||
assert _has_known_pricing("some-opus-model") is False
|
||||
assert _has_known_pricing("future-sonnet-v2") is False
|
||||
def test_heuristic_matched_models(self):
|
||||
"""Models matched by keyword heuristics should be considered known."""
|
||||
assert _has_known_pricing("some-opus-model") is True
|
||||
assert _has_known_pricing("future-sonnet-v2") is True
|
||||
|
||||
|
||||
class TestEstimateCost:
|
||||
def test_basic_cost(self):
|
||||
cost, status = _estimate_cost(
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
1_000_000,
|
||||
1_000_000,
|
||||
provider="anthropic",
|
||||
)
|
||||
assert status == "estimated"
|
||||
assert cost == pytest.approx(18.0, abs=0.01)
|
||||
# gpt-4o: 2.50/M input, 10.00/M output
|
||||
cost = _estimate_cost("gpt-4o", 1_000_000, 1_000_000)
|
||||
assert cost == pytest.approx(12.50, abs=0.01)
|
||||
|
||||
def test_zero_tokens(self):
|
||||
cost, status = _estimate_cost("gpt-4o", 0, 0, provider="openai")
|
||||
assert status == "estimated"
|
||||
cost = _estimate_cost("gpt-4o", 0, 0)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_cache_aware_usage(self):
|
||||
cost, status = _estimate_cost(
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
1000,
|
||||
500,
|
||||
cache_read_tokens=2000,
|
||||
cache_write_tokens=400,
|
||||
provider="anthropic",
|
||||
)
|
||||
assert status == "estimated"
|
||||
expected = (1000 * 3.0 + 500 * 15.0 + 2000 * 0.30 + 400 * 3.75) / 1_000_000
|
||||
assert cost == pytest.approx(expected, abs=0.0001)
|
||||
def test_small_usage(self):
|
||||
cost = _estimate_cost("gpt-4o", 1000, 500)
|
||||
# 1000 * 2.50/1M + 500 * 10.00/1M = 0.0025 + 0.005 = 0.0075
|
||||
assert cost == pytest.approx(0.0075, abs=0.0001)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -633,13 +660,8 @@ class TestEdgeCases:
|
||||
|
||||
def test_mixed_commercial_and_custom_models(self, db):
|
||||
"""Mix of commercial and custom models: only commercial ones get costs."""
|
||||
db.create_session(session_id="s1", source="cli", model="anthropic/claude-sonnet-4-20250514")
|
||||
db.update_token_counts(
|
||||
"s1",
|
||||
input_tokens=10000,
|
||||
output_tokens=5000,
|
||||
billing_provider="anthropic",
|
||||
)
|
||||
db.create_session(session_id="s1", source="cli", model="gpt-4o")
|
||||
db.update_token_counts("s1", input_tokens=10000, output_tokens=5000)
|
||||
db.create_session(session_id="s2", source="cli", model="my-local-llama")
|
||||
db.update_token_counts("s2", input_tokens=10000, output_tokens=5000)
|
||||
db._conn.commit()
|
||||
@@ -650,13 +672,13 @@ class TestEdgeCases:
|
||||
# Cost should only come from gpt-4o, not from the custom model
|
||||
overview = report["overview"]
|
||||
assert overview["estimated_cost"] > 0
|
||||
assert "claude-sonnet-4-20250514" in overview["models_with_pricing"] # list now, not set
|
||||
assert "gpt-4o" in overview["models_with_pricing"] # list now, not set
|
||||
assert "my-local-llama" in overview["models_without_pricing"]
|
||||
|
||||
# Verify individual model entries
|
||||
claude = next(m for m in report["models"] if m["model"] == "claude-sonnet-4-20250514")
|
||||
assert claude["has_pricing"] is True
|
||||
assert claude["cost"] > 0
|
||||
gpt = next(m for m in report["models"] if m["model"] == "gpt-4o")
|
||||
assert gpt["has_pricing"] is True
|
||||
assert gpt["cost"] > 0
|
||||
|
||||
llama = next(m for m in report["models"] if m["model"] == "my-local-llama")
|
||||
assert llama["has_pricing"] is False
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Tests for probe_mcp_server_tools() in tools.mcp_tool."""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_mcp_state():
|
||||
"""Ensure clean MCP module state before/after each test."""
|
||||
import tools.mcp_tool as mcp
|
||||
old_loop = mcp._mcp_loop
|
||||
old_thread = mcp._mcp_thread
|
||||
old_servers = dict(mcp._servers)
|
||||
yield
|
||||
mcp._servers.clear()
|
||||
mcp._servers.update(old_servers)
|
||||
mcp._mcp_loop = old_loop
|
||||
mcp._mcp_thread = old_thread
|
||||
|
||||
|
||||
class TestProbeMcpServerTools:
|
||||
"""Tests for the lightweight probe_mcp_server_tools function."""
|
||||
|
||||
def test_returns_empty_when_mcp_not_available(self):
|
||||
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
assert result == {}
|
||||
|
||||
def test_returns_empty_when_no_config(self):
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value={}):
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
assert result == {}
|
||||
|
||||
def test_returns_empty_when_all_servers_disabled(self):
|
||||
config = {
|
||||
"github": {"command": "npx", "enabled": False},
|
||||
"slack": {"command": "npx", "enabled": "off"},
|
||||
}
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config):
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
assert result == {}
|
||||
|
||||
def test_returns_tools_from_successful_server(self):
|
||||
"""Successfully probed server returns its tools list."""
|
||||
config = {
|
||||
"github": {"command": "npx", "connect_timeout": 5},
|
||||
}
|
||||
mock_tool_1 = SimpleNamespace(name="create_issue", description="Create a new issue")
|
||||
mock_tool_2 = SimpleNamespace(name="search_repos", description="Search repositories")
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server._tools = [mock_tool_1, mock_tool_2]
|
||||
mock_server.shutdown = AsyncMock()
|
||||
|
||||
async def fake_connect(name, cfg):
|
||||
return mock_server
|
||||
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
# Simulate running the async probe
|
||||
def run_coro(coro, timeout=120):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_run.side_effect = run_coro
|
||||
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
|
||||
assert "github" in result
|
||||
assert len(result["github"]) == 2
|
||||
assert result["github"][0] == ("create_issue", "Create a new issue")
|
||||
assert result["github"][1] == ("search_repos", "Search repositories")
|
||||
mock_server.shutdown.assert_awaited_once()
|
||||
|
||||
def test_failed_server_omitted_from_results(self):
|
||||
"""Servers that fail to connect are silently skipped."""
|
||||
config = {
|
||||
"github": {"command": "npx", "connect_timeout": 5},
|
||||
"broken": {"command": "nonexistent", "connect_timeout": 5},
|
||||
}
|
||||
mock_tool = SimpleNamespace(name="create_issue", description="Create")
|
||||
mock_server = MagicMock()
|
||||
mock_server._tools = [mock_tool]
|
||||
mock_server.shutdown = AsyncMock()
|
||||
|
||||
async def fake_connect(name, cfg):
|
||||
if name == "broken":
|
||||
raise ConnectionError("Server not found")
|
||||
return mock_server
|
||||
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_run.side_effect = run_coro
|
||||
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
|
||||
assert "github" in result
|
||||
assert "broken" not in result
|
||||
|
||||
def test_handles_tool_without_description(self):
|
||||
"""Tools without descriptions get empty string."""
|
||||
config = {"github": {"command": "npx", "connect_timeout": 5}}
|
||||
mock_tool = SimpleNamespace(name="my_tool") # no description attribute
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server._tools = [mock_tool]
|
||||
mock_server.shutdown = AsyncMock()
|
||||
|
||||
async def fake_connect(name, cfg):
|
||||
return mock_server
|
||||
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_run.side_effect = run_coro
|
||||
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
|
||||
assert result["github"][0] == ("my_tool", "")
|
||||
|
||||
def test_cleanup_called_even_on_failure(self):
|
||||
"""_stop_mcp_loop is called even when probe fails."""
|
||||
config = {"github": {"command": "npx", "connect_timeout": 5}}
|
||||
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop", side_effect=RuntimeError("boom")), \
|
||||
patch("tools.mcp_tool._stop_mcp_loop") as mock_stop:
|
||||
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
|
||||
assert result == {}
|
||||
mock_stop.assert_called_once()
|
||||
|
||||
def test_skips_disabled_servers(self):
|
||||
"""Disabled servers are not probed."""
|
||||
config = {
|
||||
"github": {"command": "npx", "connect_timeout": 5},
|
||||
"disabled_one": {"command": "npx", "enabled": False},
|
||||
}
|
||||
mock_tool = SimpleNamespace(name="create_issue", description="Create")
|
||||
mock_server = MagicMock()
|
||||
mock_server._tools = [mock_tool]
|
||||
mock_server.shutdown = AsyncMock()
|
||||
|
||||
connect_calls = []
|
||||
|
||||
async def fake_connect(name, cfg):
|
||||
connect_calls.append(name)
|
||||
return mock_server
|
||||
|
||||
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
|
||||
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch("tools.mcp_tool._ensure_mcp_loop"), \
|
||||
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
|
||||
patch("tools.mcp_tool._stop_mcp_loop"):
|
||||
|
||||
def run_coro(coro, timeout=120):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_run.side_effect = run_coro
|
||||
|
||||
from tools.mcp_tool import probe_mcp_server_tools
|
||||
result = probe_mcp_server_tools()
|
||||
|
||||
assert "github" in result
|
||||
assert "disabled_one" not in result
|
||||
assert "disabled_one" not in connect_calls
|
||||
@@ -2596,19 +2596,17 @@ class TestMCPSelectiveToolLoading:
|
||||
|
||||
async def run():
|
||||
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
|
||||
patch.dict("tools.mcp_tool._servers", {}, clear=True), \
|
||||
patch("tools.registry.registry", mock_registry), \
|
||||
patch("toolsets.create_custom_toolset"):
|
||||
registered = await _discover_and_register_server(
|
||||
return await _discover_and_register_server(
|
||||
"ink_existing",
|
||||
{"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
|
||||
)
|
||||
return registered, _existing_tool_names()
|
||||
|
||||
try:
|
||||
registered, existing = asyncio.run(run())
|
||||
registered = asyncio.run(run())
|
||||
assert registered == ["mcp_ink_existing_create_service"]
|
||||
assert existing == ["mcp_ink_existing_create_service"]
|
||||
assert _existing_tool_names() == ["mcp_ink_existing_create_service"]
|
||||
finally:
|
||||
_servers.pop("ink_existing", None)
|
||||
|
||||
|
||||
@@ -294,61 +294,6 @@ class TestCheckpoint:
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 0
|
||||
|
||||
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"):
|
||||
s = _make_session()
|
||||
s.watcher_platform = "telegram"
|
||||
s.watcher_chat_id = "999"
|
||||
s.watcher_thread_id = "42"
|
||||
s.watcher_interval = 60
|
||||
registry._running[s.id] = s
|
||||
registry._write_checkpoint()
|
||||
|
||||
data = json.loads((tmp_path / "procs.json").read_text())
|
||||
assert len(data) == 1
|
||||
assert data[0]["watcher_platform"] == "telegram"
|
||||
assert data[0]["watcher_chat_id"] == "999"
|
||||
assert data[0]["watcher_thread_id"] == "42"
|
||||
assert data[0]["watcher_interval"] == 60
|
||||
|
||||
def test_recover_enqueues_watchers(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(), # current process — guaranteed alive
|
||||
"task_id": "t1",
|
||||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 60,
|
||||
}]))
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 1
|
||||
w = registry.pending_watchers[0]
|
||||
assert w["session_id"] == "proc_live"
|
||||
assert w["platform"] == "telegram"
|
||||
assert w["chat_id"] == "123"
|
||||
assert w["thread_id"] == "42"
|
||||
assert w["check_interval"] == 60
|
||||
|
||||
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(),
|
||||
"task_id": "t1",
|
||||
"watcher_interval": 0,
|
||||
}]))
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 0
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kill process
|
||||
|
||||
@@ -25,7 +25,7 @@ def _make_config():
|
||||
|
||||
|
||||
def _install_telegram_mock(monkeypatch, bot):
|
||||
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2", HTML="HTML")
|
||||
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
|
||||
constants_mod = SimpleNamespace(ParseMode=parse_mode)
|
||||
telegram_mod = SimpleNamespace(Bot=lambda token: bot, constants=constants_mod)
|
||||
monkeypatch.setitem(sys.modules, "telegram", telegram_mod)
|
||||
@@ -391,97 +391,3 @@ class TestSendToPlatformChunking:
|
||||
assert len(sent_calls) >= 3
|
||||
assert all(call == [] for call in sent_calls[:-1])
|
||||
assert sent_calls[-1] == media
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTML auto-detection in Telegram send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendTelegramHtmlDetection:
|
||||
"""Verify that messages containing HTML tags are sent with parse_mode=HTML
|
||||
and that plain / markdown messages use MarkdownV2."""
|
||||
|
||||
def _make_bot(self):
|
||||
bot = MagicMock()
|
||||
bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=1))
|
||||
bot.send_photo = AsyncMock()
|
||||
bot.send_video = AsyncMock()
|
||||
bot.send_voice = AsyncMock()
|
||||
bot.send_audio = AsyncMock()
|
||||
bot.send_document = AsyncMock()
|
||||
return bot
|
||||
|
||||
def test_html_message_uses_html_parse_mode(self, monkeypatch):
|
||||
bot = self._make_bot()
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
asyncio.run(
|
||||
_send_telegram("tok", "123", "<b>Hello</b> world")
|
||||
)
|
||||
|
||||
bot.send_message.assert_awaited_once()
|
||||
kwargs = bot.send_message.await_args.kwargs
|
||||
assert kwargs["parse_mode"] == "HTML"
|
||||
assert kwargs["text"] == "<b>Hello</b> world"
|
||||
|
||||
def test_plain_text_uses_markdown_v2(self, monkeypatch):
|
||||
bot = self._make_bot()
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
asyncio.run(
|
||||
_send_telegram("tok", "123", "Just plain text, no tags")
|
||||
)
|
||||
|
||||
bot.send_message.assert_awaited_once()
|
||||
kwargs = bot.send_message.await_args.kwargs
|
||||
assert kwargs["parse_mode"] == "MarkdownV2"
|
||||
|
||||
def test_html_with_code_and_pre_tags(self, monkeypatch):
|
||||
bot = self._make_bot()
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
html = "<pre>code block</pre> and <code>inline</code>"
|
||||
asyncio.run(_send_telegram("tok", "123", html))
|
||||
|
||||
kwargs = bot.send_message.await_args.kwargs
|
||||
assert kwargs["parse_mode"] == "HTML"
|
||||
|
||||
def test_closing_tag_detected(self, monkeypatch):
|
||||
bot = self._make_bot()
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
asyncio.run(_send_telegram("tok", "123", "text </div> more"))
|
||||
|
||||
kwargs = bot.send_message.await_args.kwargs
|
||||
assert kwargs["parse_mode"] == "HTML"
|
||||
|
||||
def test_angle_brackets_in_math_not_detected(self, monkeypatch):
|
||||
"""Expressions like 'x < 5' or '3 > 2' should not trigger HTML mode."""
|
||||
bot = self._make_bot()
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
asyncio.run(_send_telegram("tok", "123", "if x < 5 then y > 2"))
|
||||
|
||||
kwargs = bot.send_message.await_args.kwargs
|
||||
assert kwargs["parse_mode"] == "MarkdownV2"
|
||||
|
||||
def test_html_parse_failure_falls_back_to_plain(self, monkeypatch):
|
||||
"""If Telegram rejects the HTML, fall back to plain text."""
|
||||
bot = self._make_bot()
|
||||
bot.send_message = AsyncMock(
|
||||
side_effect=[
|
||||
Exception("Bad Request: can't parse entities: unsupported html tag"),
|
||||
SimpleNamespace(message_id=2), # plain fallback succeeds
|
||||
]
|
||||
)
|
||||
_install_telegram_mock(monkeypatch, bot)
|
||||
|
||||
result = asyncio.run(
|
||||
_send_telegram("tok", "123", "<invalid>broken html</invalid>")
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert bot.send_message.await_count == 2
|
||||
second_call = bot.send_message.await_args_list[1].kwargs
|
||||
assert second_call["parse_mode"] is None
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
"""Tests for web backend client configuration and singleton behavior.
|
||||
"""Tests for Firecrawl client configuration and singleton behavior.
|
||||
|
||||
Coverage:
|
||||
_get_firecrawl_client() — configuration matrix, singleton caching,
|
||||
constructor failure recovery, return value verification, edge cases.
|
||||
_get_backend() — backend selection logic with env var combinations.
|
||||
_get_parallel_client() — Parallel client configuration, singleton caching.
|
||||
check_web_api_key() — unified availability check.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -120,157 +117,3 @@ class TestFirecrawlClientConfig:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
with pytest.raises(ValueError):
|
||||
_get_firecrawl_client()
|
||||
|
||||
|
||||
class TestBackendSelection:
|
||||
"""Test suite for _get_backend() backend selection logic.
|
||||
|
||||
The backend is configured via config.yaml (web.backend), set by
|
||||
``hermes tools``. Falls back to key-based detection for legacy/manual
|
||||
setups.
|
||||
"""
|
||||
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL")
|
||||
|
||||
def setup_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def teardown_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# ── Config-based selection (web.backend in config.yaml) ───────────
|
||||
|
||||
def test_config_parallel(self):
|
||||
"""web.backend=parallel in config → 'parallel' regardless of keys."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "parallel"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
def test_config_firecrawl(self):
|
||||
"""web.backend=firecrawl in config → 'firecrawl' even if Parallel key set."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "firecrawl"}), \
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_config_case_insensitive(self):
|
||||
"""web.backend=Parallel (mixed case) → 'parallel'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "Parallel"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
# ── Fallback (no web.backend in config) ───────────────────────────
|
||||
|
||||
def test_fallback_parallel_only_key(self):
|
||||
"""Only PARALLEL_API_KEY set → 'parallel'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
def test_fallback_both_keys_defaults_to_firecrawl(self):
|
||||
"""Both keys set, no config → 'firecrawl' (backward compat)."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key", "FIRECRAWL_API_KEY": "fc-test"}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_fallback_firecrawl_only_key(self):
|
||||
"""Only FIRECRAWL_API_KEY set → 'firecrawl'."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}), \
|
||||
patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_fallback_no_keys_defaults_to_firecrawl(self):
|
||||
"""No keys, no config → 'firecrawl' (will fail at client init)."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={}):
|
||||
assert _get_backend() == "firecrawl"
|
||||
|
||||
def test_invalid_config_falls_through_to_fallback(self):
|
||||
"""web.backend=invalid → ignored, uses key-based fallback."""
|
||||
from tools.web_tools import _get_backend
|
||||
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \
|
||||
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
assert _get_backend() == "parallel"
|
||||
|
||||
|
||||
class TestParallelClientConfig:
|
||||
"""Test suite for Parallel client initialization."""
|
||||
|
||||
def setup_method(self):
|
||||
import tools.web_tools
|
||||
tools.web_tools._parallel_client = None
|
||||
os.environ.pop("PARALLEL_API_KEY", None)
|
||||
|
||||
def teardown_method(self):
|
||||
import tools.web_tools
|
||||
tools.web_tools._parallel_client = None
|
||||
os.environ.pop("PARALLEL_API_KEY", None)
|
||||
|
||||
def test_creates_client_with_key(self):
|
||||
"""PARALLEL_API_KEY set → creates Parallel client."""
|
||||
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
from tools.web_tools import _get_parallel_client
|
||||
from parallel import Parallel
|
||||
client = _get_parallel_client()
|
||||
assert client is not None
|
||||
assert isinstance(client, Parallel)
|
||||
|
||||
def test_no_key_raises_with_helpful_message(self):
|
||||
"""No PARALLEL_API_KEY → ValueError with guidance."""
|
||||
from tools.web_tools import _get_parallel_client
|
||||
with pytest.raises(ValueError, match="PARALLEL_API_KEY"):
|
||||
_get_parallel_client()
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
"""Second call returns cached client."""
|
||||
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
from tools.web_tools import _get_parallel_client
|
||||
client1 = _get_parallel_client()
|
||||
client2 = _get_parallel_client()
|
||||
assert client1 is client2
|
||||
|
||||
|
||||
class TestCheckWebApiKey:
|
||||
"""Test suite for check_web_api_key() unified availability check."""
|
||||
|
||||
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL")
|
||||
|
||||
def setup_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def teardown_method(self):
|
||||
for key in self._ENV_KEYS:
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_parallel_key_only(self):
|
||||
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_firecrawl_key_only(self):
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_firecrawl_url_only(self):
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
def test_no_keys_returns_false(self):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is False
|
||||
|
||||
def test_both_keys_returns_true(self):
|
||||
with patch.dict(os.environ, {
|
||||
"PARALLEL_API_KEY": "test-key",
|
||||
"FIRECRAWL_API_KEY": "fc-test",
|
||||
}):
|
||||
from tools.web_tools import check_web_api_key
|
||||
assert check_web_api_key() is True
|
||||
|
||||
@@ -426,8 +426,6 @@ async def test_web_extract_blocks_redirected_final_url(monkeypatch):
|
||||
async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
@@ -455,9 +453,6 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
|
||||
async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
|
||||
@@ -555,11 +555,6 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
|
||||
session_info = provider.create_session(task_id)
|
||||
|
||||
with _cleanup_lock:
|
||||
# Double-check: another thread may have created a session while we
|
||||
# were doing the network call. Use the existing one to avoid leaking
|
||||
# orphan cloud sessions.
|
||||
if task_id in _active_sessions:
|
||||
return _active_sessions[task_id]
|
||||
_active_sessions[task_id] = session_info
|
||||
|
||||
return session_info
|
||||
|
||||
@@ -82,9 +82,6 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
|
||||
@@ -1624,72 +1624,6 @@ def get_mcp_status() -> List[dict]:
|
||||
return result
|
||||
|
||||
|
||||
def probe_mcp_server_tools() -> Dict[str, List[tuple]]:
|
||||
"""Temporarily connect to configured MCP servers and list their tools.
|
||||
|
||||
Designed for ``hermes tools`` interactive configuration — connects to each
|
||||
enabled server, grabs tool names and descriptions, then disconnects.
|
||||
Does NOT register tools in the Hermes registry.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of (tool_name, description) tuples.
|
||||
Servers that fail to connect are omitted from the result.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
return {}
|
||||
|
||||
servers_config = _load_mcp_config()
|
||||
if not servers_config:
|
||||
return {}
|
||||
|
||||
enabled = {
|
||||
k: v for k, v in servers_config.items()
|
||||
if _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
if not enabled:
|
||||
return {}
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
result: Dict[str, List[tuple]] = {}
|
||||
probed_servers: List[MCPServerTask] = []
|
||||
|
||||
async def _probe_all():
|
||||
names = list(enabled.keys())
|
||||
coros = []
|
||||
for name, cfg in enabled.items():
|
||||
ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct))
|
||||
|
||||
outcomes = await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
||||
for name, outcome in zip(names, outcomes):
|
||||
if isinstance(outcome, Exception):
|
||||
logger.debug("Probe: failed to connect to '%s': %s", name, outcome)
|
||||
continue
|
||||
probed_servers.append(outcome)
|
||||
tools = []
|
||||
for t in outcome._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
tools.append((t.name, desc))
|
||||
result[name] = tools
|
||||
|
||||
# Shut down all probed connections
|
||||
await asyncio.gather(
|
||||
*(s.shutdown() for s in probed_servers),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe_all(), timeout=120)
|
||||
except Exception as exc:
|
||||
logger.debug("MCP probe failed: %s", exc)
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
|
||||
"""Close all MCP server connections and stop the background loop.
|
||||
|
||||
|
||||
@@ -23,13 +23,11 @@ Design:
|
||||
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
@@ -122,43 +120,14 @@ class MemoryStore:
|
||||
"user": self._render_block("user", self.user_entries),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _file_lock(path: Path):
|
||||
"""Acquire an exclusive file lock for read-modify-write safety.
|
||||
|
||||
Uses a separate .lock file so the memory file itself can still be
|
||||
atomically replaced via os.replace().
|
||||
"""
|
||||
lock_path = path.with_suffix(path.suffix + ".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = open(lock_path, "w")
|
||||
try:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
fd.close()
|
||||
|
||||
@staticmethod
|
||||
def _path_for(target: str) -> Path:
|
||||
if target == "user":
|
||||
return MEMORY_DIR / "USER.md"
|
||||
return MEMORY_DIR / "MEMORY.md"
|
||||
|
||||
def _reload_target(self, target: str):
|
||||
"""Re-read entries from disk into in-memory state.
|
||||
|
||||
Called under file lock to get the latest state before mutating.
|
||||
"""
|
||||
fresh = self._read_file(self._path_for(target))
|
||||
fresh = list(dict.fromkeys(fresh)) # deduplicate
|
||||
self._set_entries(target, fresh)
|
||||
|
||||
def save_to_disk(self, target: str):
|
||||
"""Persist entries to the appropriate file. Called after every mutation."""
|
||||
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
self._write_file(self._path_for(target), self._entries_for(target))
|
||||
|
||||
if target == "memory":
|
||||
self._write_file(MEMORY_DIR / "MEMORY.md", self.memory_entries)
|
||||
elif target == "user":
|
||||
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
|
||||
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
if target == "user":
|
||||
@@ -193,37 +162,33 @@ class MemoryStore:
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
# Re-read from disk under lock to pick up writes from other sessions
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
@@ -241,47 +206,44 @@ class MemoryStore:
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
@@ -291,31 +253,28 @@ class MemoryStore:
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
|
||||
@@ -78,11 +78,6 @@ class ProcessSession:
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
# Watcher/notification metadata (persisted for crash recovery)
|
||||
watcher_platform: str = ""
|
||||
watcher_chat_id: str = ""
|
||||
watcher_thread_id: str = ""
|
||||
watcher_interval: int = 0 # 0 = no watcher configured
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
@@ -714,10 +709,6 @@ class ProcessRegistry:
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
"watcher_platform": s.watcher_platform,
|
||||
"watcher_chat_id": s.watcher_chat_id,
|
||||
"watcher_thread_id": s.watcher_thread_id,
|
||||
"watcher_interval": s.watcher_interval,
|
||||
})
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
@@ -764,27 +755,12 @@ class ProcessRegistry:
|
||||
cwd=entry.get("cwd"),
|
||||
started_at=entry.get("started_at", time.time()),
|
||||
detached=True, # Can't read output, but can report status + kill
|
||||
watcher_platform=entry.get("watcher_platform", ""),
|
||||
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||
watcher_interval=entry.get("watcher_interval", 0),
|
||||
)
|
||||
with self._lock:
|
||||
self._running[session.id] = session
|
||||
recovered += 1
|
||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||
|
||||
# Re-enqueue watcher so gateway can resume notifications
|
||||
if session.watcher_interval > 0:
|
||||
self.pending_watchers.append({
|
||||
"session_id": session.id,
|
||||
"check_interval": session.watcher_interval,
|
||||
"session_key": session.session_key,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
|
||||
@@ -355,31 +355,20 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
||||
"""Send via Telegram Bot API (one-shot, no polling needed).
|
||||
|
||||
Applies markdown→MarkdownV2 formatting (same as the gateway adapter)
|
||||
so that bold, links, and headers render correctly. If the message
|
||||
already contains HTML tags, it is sent with ``parse_mode='HTML'``
|
||||
instead, bypassing MarkdownV2 conversion.
|
||||
so that bold, links, and headers render correctly.
|
||||
"""
|
||||
try:
|
||||
from telegram import Bot
|
||||
from telegram.constants import ParseMode
|
||||
|
||||
# Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML.
|
||||
# Inspired by github.com/ashaney — PR #1568.
|
||||
_has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message))
|
||||
|
||||
if _has_html:
|
||||
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
|
||||
try:
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
|
||||
_adapter = TelegramAdapter.__new__(TelegramAdapter)
|
||||
formatted = _adapter.format_message(message)
|
||||
except Exception:
|
||||
# Fallback: send as-is if formatting unavailable
|
||||
formatted = message
|
||||
send_parse_mode = ParseMode.HTML
|
||||
else:
|
||||
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
|
||||
try:
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
|
||||
_adapter = TelegramAdapter.__new__(TelegramAdapter)
|
||||
formatted = _adapter.format_message(message)
|
||||
except Exception:
|
||||
# Fallback: send as-is if formatting unavailable
|
||||
formatted = message
|
||||
send_parse_mode = ParseMode.MARKDOWN_V2
|
||||
|
||||
bot = Bot(token=token)
|
||||
int_chat_id = int(chat_id)
|
||||
@@ -395,19 +384,16 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
||||
try:
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=formatted,
|
||||
parse_mode=send_parse_mode, **thread_kwargs
|
||||
parse_mode=ParseMode.MARKDOWN_V2, **thread_kwargs
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Parse failed, fall back to plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower():
|
||||
logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error)
|
||||
if not _has_html:
|
||||
try:
|
||||
from gateway.platforms.telegram import _strip_mdv2
|
||||
plain = _strip_mdv2(formatted)
|
||||
except Exception:
|
||||
plain = message
|
||||
else:
|
||||
# MarkdownV2 failed, fall back to plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("MarkdownV2 parse failed in _send_telegram, falling back to plain text: %s", md_error)
|
||||
try:
|
||||
from gateway.platforms.telegram import _strip_mdv2
|
||||
plain = _strip_mdv2(formatted)
|
||||
except Exception:
|
||||
plain = message
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=plain,
|
||||
|
||||
@@ -1082,23 +1082,13 @@ def terminal_tool(
|
||||
result_data["check_interval_note"] = (
|
||||
f"Requested {check_interval}s raised to minimum 30s"
|
||||
)
|
||||
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||
|
||||
# Store on session for checkpoint persistence
|
||||
proc_session.watcher_platform = watcher_platform
|
||||
proc_session.watcher_chat_id = watcher_chat_id
|
||||
proc_session.watcher_thread_id = watcher_thread_id
|
||||
proc_session.watcher_interval = effective_interval
|
||||
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": effective_interval,
|
||||
"session_key": session_key,
|
||||
"platform": watcher_platform,
|
||||
"chat_id": watcher_chat_id,
|
||||
"thread_id": watcher_thread_id,
|
||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
|
||||
})
|
||||
|
||||
return json.dumps(result_data, ensure_ascii=False)
|
||||
|
||||
@@ -3,16 +3,16 @@
|
||||
Standalone Web Tools Module
|
||||
|
||||
This module provides generic web tools that work with multiple backend providers.
|
||||
Backend is selected during ``hermes tools`` setup (web.backend in config.yaml).
|
||||
Currently uses Firecrawl as the backend, and the interface makes it easy to swap
|
||||
providers without changing the function signatures.
|
||||
|
||||
Available tools:
|
||||
- web_search_tool: Search the web for information
|
||||
- web_extract_tool: Extract content from specific web pages
|
||||
- web_crawl_tool: Crawl websites with specific instructions (Firecrawl only)
|
||||
- web_crawl_tool: Crawl websites with specific instructions
|
||||
|
||||
Backend compatibility:
|
||||
- Firecrawl: https://docs.firecrawl.dev/introduction (search, extract, crawl)
|
||||
- Parallel: https://docs.parallel.ai (search, extract)
|
||||
- Firecrawl: https://docs.firecrawl.dev/introduction
|
||||
|
||||
LLM Processing:
|
||||
- Uses OpenRouter API with Gemini 3 Flash Preview for intelligent content extraction
|
||||
@@ -53,39 +53,6 @@ from tools.website_policy import check_website_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Backend Selection ────────────────────────────────────────────────────────
|
||||
|
||||
def _load_web_config() -> dict:
|
||||
"""Load the ``web:`` section from ~/.hermes/config.yaml."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
return load_config().get("web", {})
|
||||
except (ImportError, Exception):
|
||||
return {}
|
||||
|
||||
|
||||
def _get_backend() -> str:
|
||||
"""Determine which web backend to use.
|
||||
|
||||
Reads ``web.backend`` from config.yaml (set by ``hermes tools``).
|
||||
Falls back to whichever API key is present for users who configured
|
||||
keys manually without running setup.
|
||||
"""
|
||||
configured = _load_web_config().get("backend", "").lower().strip()
|
||||
if configured in ("parallel", "firecrawl"):
|
||||
return configured
|
||||
# Fallback for manual / legacy config — use whichever key is present.
|
||||
has_firecrawl = bool(os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL"))
|
||||
has_parallel = bool(os.getenv("PARALLEL_API_KEY"))
|
||||
if has_parallel and not has_firecrawl:
|
||||
return "parallel"
|
||||
# Default to firecrawl (backward compat, or when both are set)
|
||||
return "firecrawl"
|
||||
|
||||
|
||||
# ─── Firecrawl Client ────────────────────────────────────────────────────────
|
||||
|
||||
_firecrawl_client = None
|
||||
|
||||
def _get_firecrawl_client():
|
||||
@@ -114,47 +81,6 @@ def _get_firecrawl_client():
|
||||
_firecrawl_client = Firecrawl(**kwargs)
|
||||
return _firecrawl_client
|
||||
|
||||
|
||||
# ─── Parallel Client ─────────────────────────────────────────────────────────
|
||||
|
||||
_parallel_client = None
|
||||
_async_parallel_client = None
|
||||
|
||||
def _get_parallel_client():
|
||||
"""Get or create the Parallel sync client (lazy initialization).
|
||||
|
||||
Requires PARALLEL_API_KEY environment variable.
|
||||
"""
|
||||
from parallel import Parallel
|
||||
global _parallel_client
|
||||
if _parallel_client is None:
|
||||
api_key = os.getenv("PARALLEL_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"PARALLEL_API_KEY environment variable not set. "
|
||||
"Get your API key at https://parallel.ai"
|
||||
)
|
||||
_parallel_client = Parallel(api_key=api_key)
|
||||
return _parallel_client
|
||||
|
||||
|
||||
def _get_async_parallel_client():
|
||||
"""Get or create the Parallel async client (lazy initialization).
|
||||
|
||||
Requires PARALLEL_API_KEY environment variable.
|
||||
"""
|
||||
from parallel import AsyncParallel
|
||||
global _async_parallel_client
|
||||
if _async_parallel_client is None:
|
||||
api_key = os.getenv("PARALLEL_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"PARALLEL_API_KEY environment variable not set. "
|
||||
"Get your API key at https://parallel.ai"
|
||||
)
|
||||
_async_parallel_client = AsyncParallel(api_key=api_key)
|
||||
return _async_parallel_client
|
||||
|
||||
DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000
|
||||
|
||||
# Allow per-task override via env var
|
||||
@@ -502,89 +428,13 @@ def clean_base64_images(text: str) -> str:
|
||||
return cleaned_text
|
||||
|
||||
|
||||
# ─── Parallel Search & Extract Helpers ────────────────────────────────────────
|
||||
|
||||
def _parallel_search(query: str, limit: int = 5) -> dict:
|
||||
"""Search using the Parallel SDK and return results as a dict."""
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return {"error": "Interrupted", "success": False}
|
||||
|
||||
mode = os.getenv("PARALLEL_SEARCH_MODE", "agentic").lower().strip()
|
||||
if mode not in ("fast", "one-shot", "agentic"):
|
||||
mode = "agentic"
|
||||
|
||||
logger.info("Parallel search: '%s' (mode=%s, limit=%d)", query, mode, limit)
|
||||
response = _get_parallel_client().beta.search(
|
||||
search_queries=[query],
|
||||
objective=query,
|
||||
mode=mode,
|
||||
max_results=min(limit, 20),
|
||||
)
|
||||
|
||||
web_results = []
|
||||
for i, result in enumerate(response.results or []):
|
||||
excerpts = result.excerpts or []
|
||||
web_results.append({
|
||||
"url": result.url or "",
|
||||
"title": result.title or "",
|
||||
"description": " ".join(excerpts) if excerpts else "",
|
||||
"position": i + 1,
|
||||
})
|
||||
|
||||
return {"success": True, "data": {"web": web_results}}
|
||||
|
||||
|
||||
async def _parallel_extract(urls: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Extract content from URLs using the Parallel async SDK.
|
||||
|
||||
Returns a list of result dicts matching the structure expected by the
|
||||
LLM post-processing pipeline (url, title, content, metadata).
|
||||
"""
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return [{"url": u, "error": "Interrupted", "title": ""} for u in urls]
|
||||
|
||||
logger.info("Parallel extract: %d URL(s)", len(urls))
|
||||
response = await _get_async_parallel_client().beta.extract(
|
||||
urls=urls,
|
||||
full_content=True,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in response.results or []:
|
||||
content = result.full_content or ""
|
||||
if not content:
|
||||
content = "\n\n".join(result.excerpts or [])
|
||||
url = result.url or ""
|
||||
title = result.title or ""
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"raw_content": content,
|
||||
"metadata": {"sourceURL": url, "title": title},
|
||||
})
|
||||
|
||||
for error in response.errors or []:
|
||||
results.append({
|
||||
"url": error.url or "",
|
||||
"title": "",
|
||||
"content": "",
|
||||
"error": error.content or error.error_type or "extraction failed",
|
||||
"metadata": {"sourceURL": error.url or ""},
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
"""
|
||||
Search the web for information using available search API backend.
|
||||
|
||||
|
||||
This function provides a generic interface for web search that can work
|
||||
with multiple backends (Parallel or Firecrawl).
|
||||
|
||||
with multiple backends. Currently uses Firecrawl.
|
||||
|
||||
Note: This function returns search result metadata only (URLs, titles, descriptions).
|
||||
Use web_extract_tool to get full content from specific URLs.
|
||||
|
||||
@@ -628,28 +478,17 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
if is_interrupted():
|
||||
return json.dumps({"error": "Interrupted", "success": False})
|
||||
|
||||
# Dispatch to the configured backend
|
||||
backend = _get_backend()
|
||||
if backend == "parallel":
|
||||
response_data = _parallel_search(query, limit)
|
||||
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
|
||||
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
debug_call_data["final_response_size"] = len(result_json)
|
||||
_debug.log_call("web_search_tool", debug_call_data)
|
||||
_debug.save()
|
||||
return result_json
|
||||
|
||||
logger.info("Searching the web for: '%s' (limit: %d)", query, limit)
|
||||
|
||||
|
||||
response = _get_firecrawl_client().search(
|
||||
query=query,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
|
||||
# The response is a SearchData object with web, news, and images attributes
|
||||
# When not scraping, the results are directly in these attributes
|
||||
web_results = []
|
||||
|
||||
|
||||
# Check if response has web attribute (SearchData object)
|
||||
if hasattr(response, 'web'):
|
||||
# Response is a SearchData object with web attribute
|
||||
@@ -757,130 +596,123 @@ async def web_extract_tool(
|
||||
|
||||
try:
|
||||
logger.info("Extracting content from %d URL(s)", len(urls))
|
||||
|
||||
# Dispatch to the configured backend
|
||||
backend = _get_backend()
|
||||
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(urls)
|
||||
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
else:
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
results.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
results.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
results.append({
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
|
||||
response = {"results": results}
|
||||
|
||||
@@ -1055,14 +887,6 @@ async def web_crawl_tool(
|
||||
}
|
||||
|
||||
try:
|
||||
# web_crawl requires Firecrawl — Parallel has no crawl API
|
||||
if not (os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")):
|
||||
return json.dumps({
|
||||
"error": "web_crawl requires Firecrawl. Set FIRECRAWL_API_KEY, "
|
||||
"or use web_search + web_extract instead.",
|
||||
"success": False,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Ensure URL has protocol
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = f'https://{url}'
|
||||
@@ -1327,22 +1151,13 @@ async def web_crawl_tool(
|
||||
def check_firecrawl_api_key() -> bool:
|
||||
"""
|
||||
Check if the Firecrawl API key is available in environment variables.
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
return bool(os.getenv("FIRECRAWL_API_KEY"))
|
||||
|
||||
|
||||
def check_web_api_key() -> bool:
|
||||
"""Check if any web backend API key is available (Parallel or Firecrawl)."""
|
||||
return bool(
|
||||
os.getenv("PARALLEL_API_KEY")
|
||||
or os.getenv("FIRECRAWL_API_KEY")
|
||||
or os.getenv("FIRECRAWL_API_URL")
|
||||
)
|
||||
|
||||
|
||||
def check_auxiliary_model() -> bool:
|
||||
"""Check if an auxiliary text model is available for LLM content processing."""
|
||||
try:
|
||||
@@ -1369,30 +1184,26 @@ if __name__ == "__main__":
|
||||
print("=" * 40)
|
||||
|
||||
# Check if API keys are available
|
||||
web_available = check_web_api_key()
|
||||
firecrawl_available = check_firecrawl_api_key()
|
||||
nous_available = check_auxiliary_model()
|
||||
|
||||
if web_available:
|
||||
backend = _get_backend()
|
||||
print(f"✅ Web backend: {backend}")
|
||||
if backend == "parallel":
|
||||
print(" Using Parallel API (https://parallel.ai)")
|
||||
else:
|
||||
print(" Using Firecrawl API (https://firecrawl.dev)")
|
||||
|
||||
if not firecrawl_available:
|
||||
print("❌ FIRECRAWL_API_KEY environment variable not set")
|
||||
print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'")
|
||||
print("Get API key at: https://firecrawl.dev/")
|
||||
else:
|
||||
print("❌ No web search backend configured")
|
||||
print("Set PARALLEL_API_KEY (https://parallel.ai) or FIRECRAWL_API_KEY (https://firecrawl.dev)")
|
||||
|
||||
print("✅ Firecrawl API key found")
|
||||
|
||||
if not nous_available:
|
||||
print("❌ No auxiliary model available for LLM content processing")
|
||||
print("Set OPENROUTER_API_KEY, configure Nous Portal, or set OPENAI_BASE_URL + OPENAI_API_KEY")
|
||||
print("⚠️ Without an auxiliary model, LLM content processing will be disabled")
|
||||
else:
|
||||
print(f"✅ Auxiliary model available: {DEFAULT_SUMMARIZER_MODEL}")
|
||||
|
||||
if not web_available:
|
||||
|
||||
if not firecrawl_available:
|
||||
exit(1)
|
||||
|
||||
|
||||
print("🛠️ Web tools ready for use!")
|
||||
|
||||
if nous_available:
|
||||
@@ -1490,8 +1301,8 @@ registry.register(
|
||||
toolset="web",
|
||||
schema=WEB_SEARCH_SCHEMA,
|
||||
handler=lambda args, **kw: web_search_tool(args.get("query", ""), limit=5),
|
||||
check_fn=check_web_api_key,
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY"],
|
||||
check_fn=check_firecrawl_api_key,
|
||||
requires_env=["FIRECRAWL_API_KEY"],
|
||||
emoji="🔍",
|
||||
)
|
||||
registry.register(
|
||||
@@ -1500,8 +1311,8 @@ registry.register(
|
||||
schema=WEB_EXTRACT_SCHEMA,
|
||||
handler=lambda args, **kw: web_extract_tool(
|
||||
args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"),
|
||||
check_fn=check_web_api_key,
|
||||
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY"],
|
||||
check_fn=check_firecrawl_api_key,
|
||||
requires_env=["FIRECRAWL_API_KEY"],
|
||||
is_async=True,
|
||||
emoji="📄",
|
||||
)
|
||||
|
||||
@@ -130,12 +130,6 @@ TOOLSETS = {
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"messaging": {
|
||||
"description": "Cross-platform messaging: send messages to Telegram, Discord, Slack, SMS, etc.",
|
||||
"tools": ["send_message"],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"rl": {
|
||||
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
|
||||
"tools": [
|
||||
|
||||
@@ -61,7 +61,6 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `PARALLEL_API_KEY` | AI-native web search ([parallel.ai](https://parallel.ai/)) |
|
||||
| `FIRECRAWL_API_KEY` | Web scraping ([firecrawl.dev](https://firecrawl.dev/)) |
|
||||
| `FIRECRAWL_API_URL` | Custom Firecrawl API endpoint for self-hosted instances (optional) |
|
||||
| `BROWSERBASE_API_KEY` | Browser automation ([browserbase.com](https://browserbase.com/)) |
|
||||
|
||||
Reference in New Issue
Block a user