mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-26 03:43:37 +08:00
Compare commits
8 Commits
feat/cache
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5b85531f9 | ||
|
|
b430b5acfe | ||
|
|
2001b88c23 | ||
|
|
79975692a5 | ||
|
|
77608c90ac | ||
|
|
e00064c58f | ||
|
|
230506a3ef | ||
|
|
861685684c |
@@ -7,7 +7,7 @@ protecting head and tail context.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.auxiliary_client import call_llm
|
||||
from agent.model_metadata import (
|
||||
@@ -17,24 +17,6 @@ from agent.model_metadata import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NEVER_PRUNE_TOOLS = {"clarify", "memory", "skill_view", "todo", "read_file"}
|
||||
|
||||
|
||||
def _adaptive_prune_protect(context_length: int) -> int:
|
||||
"""Scale the recent-tool-output protection window to the model context size."""
|
||||
if context_length >= 500_000:
|
||||
return 100_000
|
||||
if context_length >= 128_000:
|
||||
return 40_000
|
||||
if context_length >= 64_000:
|
||||
return 20_000
|
||||
return 10_000
|
||||
|
||||
|
||||
def _adaptive_prune_minimum(context_length: int) -> int:
|
||||
"""Only prune when it reclaims a meaningful amount of prompt budget."""
|
||||
return max(5_000, context_length // 20)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compresses conversation context when approaching the model's context limit.
|
||||
@@ -72,10 +54,6 @@ class ContextCompressor:
|
||||
self.last_total_tokens = 0
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
self._prune_protect_tokens = _adaptive_prune_protect(self.context_length)
|
||||
self._prune_minimum_tokens = _adaptive_prune_minimum(self.context_length)
|
||||
self._prune_runway_tokens = max(self._prune_minimum_tokens, int(self.threshold_tokens * 0.15))
|
||||
self._prune_target_tokens = max(0, self.threshold_tokens - self._prune_runway_tokens)
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
"""Update tracked token usage from API response."""
|
||||
@@ -103,58 +81,6 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _is_protected_tool(self, message: Dict[str, Any]) -> bool:
|
||||
"""Return True when a tool output should never be pruned."""
|
||||
return (message.get("name") or "") in NEVER_PRUNE_TOOLS
|
||||
|
||||
def _prune_tool_outputs(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""Replace older middle tool outputs with compact placeholders.
|
||||
|
||||
Only prunes tool outputs from the same middle region that would be eligible
|
||||
for summarization. The head/tail protected windows are left untouched.
|
||||
|
||||
Returns:
|
||||
(messages_after_prune, chars_saved)
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
compress_start = self.protect_first_n
|
||||
compress_end = n_messages - self.protect_last_n
|
||||
if compress_start >= compress_end:
|
||||
return messages, 0
|
||||
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
if compress_start >= compress_end:
|
||||
return messages, 0
|
||||
|
||||
pruned = [msg.copy() for msg in messages]
|
||||
chars_saved = 0
|
||||
recent_tool_tokens = 0
|
||||
|
||||
for i in range(compress_end - 1, compress_start - 1, -1):
|
||||
msg = pruned[i]
|
||||
if msg.get("role") != "tool" or self._is_protected_tool(msg):
|
||||
continue
|
||||
|
||||
content = msg.get("content")
|
||||
content_text = content if isinstance(content, str) else str(content or "")
|
||||
token_estimate = max(1, len(content_text) // 4)
|
||||
|
||||
if recent_tool_tokens < self._prune_protect_tokens:
|
||||
recent_tool_tokens += token_estimate
|
||||
continue
|
||||
|
||||
original_len = len(content_text)
|
||||
placeholder = f"[Tool output pruned — was {original_len:,} chars]"
|
||||
pruned[i]["content"] = placeholder
|
||||
chars_saved += max(0, original_len - len(placeholder))
|
||||
|
||||
tokens_saved = chars_saved // 4
|
||||
if tokens_saved < self._prune_minimum_tokens:
|
||||
return messages, 0
|
||||
|
||||
return pruned, chars_saved
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
|
||||
@@ -206,7 +132,11 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if self.summary_model:
|
||||
call_kwargs["model"] = self.summary_model
|
||||
response = call_llm(**call_kwargs)
|
||||
summary = response.choices[0].message.content.strip()
|
||||
content = response.choices[0].message.content
|
||||
# Handle cases where content is not a string (e.g., dict from llama.cpp)
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content else ""
|
||||
summary = content.strip()
|
||||
if not summary.startswith("[CONTEXT SUMMARY]:"):
|
||||
summary = "[CONTEXT SUMMARY]: " + summary
|
||||
return summary
|
||||
@@ -341,49 +271,13 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
display_tokens = current_tokens if current_tokens is not None else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
|
||||
pruned_messages, chars_saved = self._prune_tool_outputs(messages)
|
||||
if chars_saved > 0:
|
||||
pruned_tokens = estimate_messages_tokens_rough(pruned_messages)
|
||||
tokens_saved_phase1 = max(0, display_tokens - pruned_tokens)
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f" ✂️ Phase 1 (prune): removed {chars_saved:,} chars of old tool outputs "
|
||||
f"(~{tokens_saved_phase1:,} tokens saved)"
|
||||
)
|
||||
if pruned_tokens <= self._prune_target_tokens:
|
||||
self.compression_count += 1
|
||||
pruned_messages = self._sanitize_tool_pairs(pruned_messages)
|
||||
if not self.quiet_mode:
|
||||
print(
|
||||
f" ✅ Phase 1 sufficient: {n_messages} → {len(pruned_messages)} messages, "
|
||||
f"now {pruned_tokens:,} tokens"
|
||||
)
|
||||
print(f" 💡 Compression #{self.compression_count} complete (prune only — no LLM call needed)")
|
||||
return pruned_messages
|
||||
if not self.quiet_mode and pruned_tokens < self.threshold_tokens:
|
||||
print(
|
||||
f" ↪️ Phase 1 recovered tokens but not enough runway "
|
||||
f"({pruned_tokens:,} > target {self._prune_target_tokens:,}); continuing to compaction"
|
||||
)
|
||||
messages = pruned_messages
|
||||
n_messages = len(messages)
|
||||
compress_start = self.protect_first_n
|
||||
compress_end = n_messages - self.protect_last_n
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
# Cache-Aware Context Compaction Design Note
|
||||
|
||||
> For Hermes: this note is a design/implementation sketch for revisiting prune-first compaction without optimizing token spend at the expense of prompt-cache stability.
|
||||
|
||||
Goal: reduce compression cost while keeping cache-break frequency as low as possible.
|
||||
|
||||
Architecture: keep Hermes' current invariant that conversation history is only mutated during context compression, then make prune-first compaction conservative enough that it only short-circuits when it buys meaningful runway. If pruning only gets us barely below threshold, fall through to the existing summary compaction immediately.
|
||||
|
||||
Tech Stack: `agent/context_compressor.py`, existing `call_llm()`-based summary path, pytest coverage in `tests/agent/test_context_compressor.py`.
|
||||
|
||||
---
|
||||
|
||||
## 1. Baseline behavior on current main
|
||||
|
||||
Today Hermes behaves like this:
|
||||
|
||||
1. Prompt crosses the compression threshold.
|
||||
2. We mutate transcript history once by summarizing the middle region with an LLM.
|
||||
3. We preserve role alternation and tool-call/tool-result integrity.
|
||||
4. We continue the conversation from the compressed transcript.
|
||||
|
||||
This is expensive in two ways:
|
||||
- an auxiliary summary call is often required
|
||||
- the entire compressed middle region is rewritten even when the real problem was just a few huge old tool outputs
|
||||
|
||||
But it has one strong cache property:
|
||||
- it tends to reclaim a lot of headroom per compression event, so the next compression is usually farther away
|
||||
|
||||
---
|
||||
|
||||
## 2. Why naive prune-first compaction is not enough
|
||||
|
||||
A naive prune-first policy says:
|
||||
- prune old tool outputs
|
||||
- if prompt is now below threshold, stop
|
||||
|
||||
This improves per-event token cost, but it can hurt cache economics:
|
||||
- prune-only may reclaim less headroom than full compaction
|
||||
- smaller headroom means the next compression may happen sooner
|
||||
- each compression event is still a cache-breaking transcript mutation
|
||||
|
||||
So there is a real failure mode:
|
||||
- fewer tokens per compression
|
||||
- more compression events overall
|
||||
- worse cache break cadence
|
||||
|
||||
That is exactly the tradeoff we want to avoid.
|
||||
|
||||
---
|
||||
|
||||
## 3. Cache-aware principle
|
||||
|
||||
Prune-first compaction should only short-circuit when it buys real runway, not when it merely dips under threshold.
|
||||
|
||||
Rule of thumb:
|
||||
- compression frequency matters as much as compression size
|
||||
- a smaller mutation is not automatically cheaper if it causes another mutation a few turns later
|
||||
|
||||
So the design target is:
|
||||
- fewer auxiliary summary calls
|
||||
- without materially increasing compression frequency
|
||||
|
||||
---
|
||||
|
||||
## 4. Conservative prototype policy
|
||||
|
||||
The conservative prototype keeps all existing compression invariants and only changes the acceptance rule for prune-only compaction.
|
||||
|
||||
### Phase 1: prune old middle tool outputs
|
||||
|
||||
Only prune tool outputs that are:
|
||||
- in the compressible middle region
|
||||
- not in protected head/tail windows
|
||||
- not from protected tools (`read_file`, `memory`, `clarify`, `skill_view`, `todo`)
|
||||
|
||||
### Phase 2: require a low-water mark
|
||||
|
||||
Do not accept prune-only just because it lands below threshold.
|
||||
|
||||
Instead require:
|
||||
- `post_prune_tokens <= prune_target_tokens`
|
||||
|
||||
Where:
|
||||
- `prune_runway_tokens = max(prune_minimum_tokens, 15% of threshold_tokens)`
|
||||
- `prune_target_tokens = threshold_tokens - prune_runway_tokens`
|
||||
|
||||
Interpretation:
|
||||
- pruning must get us comfortably below threshold
|
||||
- otherwise we immediately fall through to normal LLM summary compaction
|
||||
|
||||
Why this helps:
|
||||
- protects cache by avoiding "micro-compactions" that would be followed by another compression shortly after
|
||||
- still avoids the summary call when pruning truly buys useful runway
|
||||
|
||||
---
|
||||
|
||||
## 5. What the prototype currently does
|
||||
|
||||
The prototype branch currently:
|
||||
- keeps prune-first compaction
|
||||
- adds the low-water / runway requirement above
|
||||
- preserves current main behavior for summary role alternation
|
||||
- preserves the centralized `call_llm()` summary path
|
||||
- keeps head/tail and tool-call/result integrity handling unchanged
|
||||
|
||||
This means the branch is no longer optimizing only for token reduction per event; it is explicitly biased toward fewer compression events.
|
||||
|
||||
---
|
||||
|
||||
## 6. Metrics we should evaluate before merging any future version
|
||||
|
||||
A serious cache-aware review should measure all of these, not just token savings:
|
||||
|
||||
1. Compression events per 100 conversation turns
|
||||
2. Average turns between compressions
|
||||
3. Auxiliary summary calls per session
|
||||
4. Average tokens reclaimed per compression event
|
||||
5. Total prompt+auxiliary tokens spent over a long session
|
||||
6. Earliest changed message index during compression
|
||||
7. Ratio of prune-only compressions to full summary compressions
|
||||
|
||||
The most important comparison is:
|
||||
- baseline main vs conservative prune-first
|
||||
|
||||
Success is not:
|
||||
- "fewer tokens in one compression"
|
||||
|
||||
Success is:
|
||||
- "equal or better total session cost without increasing compression/cache-break cadence in a meaningful way"
|
||||
|
||||
---
|
||||
|
||||
## 7. Better long-term directions
|
||||
|
||||
If we want a stronger cache story than conservative prune-first, these are the real next-step options:
|
||||
|
||||
### A. Insertion-time trimming
|
||||
|
||||
Best cache-preserving option.
|
||||
|
||||
Idea:
|
||||
- trim or summarize giant tool outputs before they become durable transcript history
|
||||
- keep a compact representation from the start instead of mutating history later
|
||||
|
||||
Pros:
|
||||
- avoids later cache-breaking rewrites for those blobs
|
||||
- makes transcript size stable earlier
|
||||
|
||||
Cons:
|
||||
- more invasive design change
|
||||
- requires careful UX and provenance handling
|
||||
|
||||
### B. Provider/backend-aware compaction policy
|
||||
|
||||
Different providers may reward:
|
||||
- preserving a longer stable prefix
|
||||
- or simply reducing total prompt size
|
||||
|
||||
We may eventually want backend-specific heuristics for:
|
||||
- prune runway targets
|
||||
- compression thresholds
|
||||
- when to prefer summary vs pruning
|
||||
|
||||
### C. Explicit compression telemetry
|
||||
|
||||
If compression remains a core feature, `ContextCompressor` should expose enough telemetry to understand real-world cadence:
|
||||
- prune-only count
|
||||
- full summary count
|
||||
- average recovered tokens
|
||||
- last compression mode
|
||||
|
||||
This is not required for the conservative prototype, but it would make future tuning much easier.
|
||||
|
||||
---
|
||||
|
||||
## 8. Recommended next steps
|
||||
|
||||
1. Keep the conservative prototype local for review.
|
||||
2. Run targeted tests plus long-session manual trials.
|
||||
3. If it looks promising, add telemetry before opening another PR.
|
||||
4. If cache stability remains the top priority, pursue insertion-time trimming instead of further read-time pruning tweaks.
|
||||
|
||||
---
|
||||
|
||||
## 9. Review question for Teknium
|
||||
|
||||
The key product question is:
|
||||
|
||||
"Should Hermes optimize compression primarily for per-event token cost, or for minimizing the number of transcript mutations over the lifetime of a session?"
|
||||
|
||||
This prototype assumes the answer is:
|
||||
- prioritize fewer transcript mutations unless pruning buys substantial runway.
|
||||
@@ -83,10 +83,13 @@ class SessionResetPolicy:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
# Handle both missing keys and explicit null values (YAML null → None)
|
||||
at_hour = data.get("at_hour")
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=data.get("at_hour", 4),
|
||||
idle_minutes=data.get("idle_minutes", 1440),
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -43,23 +43,6 @@ from gateway.platforms.base import (
|
||||
)
|
||||
|
||||
|
||||
def _clean_discord_id(entry: str) -> str:
|
||||
"""Strip common prefixes from a Discord user ID or username entry.
|
||||
|
||||
Users sometimes paste IDs with prefixes like ``user:123``, ``<@123>``,
|
||||
or ``<@!123>`` from Discord's UI or other tools. This normalises the
|
||||
entry to just the bare ID or username.
|
||||
"""
|
||||
entry = entry.strip()
|
||||
# Strip Discord mention syntax: <@123> or <@!123>
|
||||
if entry.startswith("<@") and entry.endswith(">"):
|
||||
entry = entry.lstrip("<@!").rstrip(">")
|
||||
# Strip "user:" prefix (seen in some Discord tools / onboarding pastes)
|
||||
if entry.lower().startswith("user:"):
|
||||
entry = entry[5:]
|
||||
return entry.strip()
|
||||
|
||||
|
||||
def check_discord_requirements() -> bool:
|
||||
"""Check if Discord dependencies are available."""
|
||||
return DISCORD_AVAILABLE
|
||||
@@ -116,8 +99,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
self._allowed_user_ids = {
|
||||
_clean_discord_id(uid) for uid in allowed_env.split(",")
|
||||
if uid.strip()
|
||||
uid.strip() for uid in allowed_env.split(",") if uid.strip()
|
||||
}
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
@@ -1541,20 +1541,8 @@ def detect_external_credentials() -> List[Dict[str, Any]]:
|
||||
# CLI Commands — login / logout
|
||||
# =============================================================================
|
||||
|
||||
def _update_config_for_provider(
|
||||
provider_id: str,
|
||||
inference_base_url: str,
|
||||
default_model: Optional[str] = None,
|
||||
) -> Path:
|
||||
"""Update config.yaml and auth.json to reflect the active provider.
|
||||
|
||||
When *default_model* is provided the function also writes it as the
|
||||
``model.default`` value. This prevents a race condition where the
|
||||
gateway (which re-reads config per-message) picks up the new provider
|
||||
before the caller has finished model selection, resulting in a
|
||||
mismatched model/provider (e.g. ``anthropic/claude-opus-4.6`` sent to
|
||||
MiniMax's API).
|
||||
"""
|
||||
def _update_config_for_provider(provider_id: str, inference_base_url: str) -> Path:
|
||||
"""Update config.yaml and auth.json to reflect the active provider."""
|
||||
# Set active_provider in auth.json so auto-resolution picks this provider
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
@@ -1588,15 +1576,6 @@ def _update_config_for_provider(
|
||||
else:
|
||||
# Clear stale base_url to prevent contamination when switching providers
|
||||
model_cfg.pop("base_url", None)
|
||||
|
||||
# When switching to a non-OpenRouter provider, ensure model.default is
|
||||
# valid for the new provider. An OpenRouter-formatted name like
|
||||
# "anthropic/claude-opus-4.6" will fail on direct-API providers.
|
||||
if default_model:
|
||||
cur_default = model_cfg.get("default", "")
|
||||
if not cur_default or "/" in cur_default:
|
||||
model_cfg["default"] = default_model
|
||||
|
||||
config["model"] = model_cfg
|
||||
|
||||
config_path.write_text(yaml.safe_dump(config, sort_keys=False))
|
||||
|
||||
@@ -194,8 +194,13 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
"stt": {
|
||||
"enabled": True,
|
||||
"model": "whisper-1",
|
||||
"provider": "local", # "local" (free, faster-whisper) | "openai" (Whisper API)
|
||||
"local": {
|
||||
"model": "base", # tiny, base, small, medium, large-v3
|
||||
},
|
||||
"openai": {
|
||||
"model": "whisper-1", # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
},
|
||||
},
|
||||
|
||||
"human_delay": {
|
||||
|
||||
@@ -623,18 +623,6 @@ def _setup_standard_platform(platform: dict):
|
||||
value = prompt(f" {var['prompt']}", password=False)
|
||||
if value:
|
||||
cleaned = value.replace(" ", "")
|
||||
# For Discord, strip common prefixes (user:123, <@123>, <@!123>)
|
||||
if "DISCORD" in var["name"]:
|
||||
parts = []
|
||||
for uid in cleaned.split(","):
|
||||
uid = uid.strip()
|
||||
if uid.startswith("<@") and uid.endswith(">"):
|
||||
uid = uid.lstrip("<@!").rstrip(">")
|
||||
if uid.lower().startswith("user:"):
|
||||
uid = uid[5:]
|
||||
if uid:
|
||||
parts.append(uid)
|
||||
cleaned = ",".join(parts)
|
||||
save_env_value(var["name"], cleaned)
|
||||
print_success(f" Saved — only these users can interact with the bot.")
|
||||
allowed_val_set = cleaned
|
||||
|
||||
@@ -111,17 +111,7 @@ def _setup_provider_model_selection(config, provider_id, current_model, prompt_c
|
||||
custom = prompt_fn("Enter model name")
|
||||
if custom:
|
||||
_set_default_model(config, custom)
|
||||
else:
|
||||
# "Keep current" selected — validate it's compatible with the new
|
||||
# provider. OpenRouter-formatted names (containing "/") won't work
|
||||
# on direct-API providers and would silently break the gateway.
|
||||
if "/" in (current_model or "") and provider_models:
|
||||
print_warning(
|
||||
f"Current model \"{current_model}\" looks like an OpenRouter model "
|
||||
f"and won't work with {pconfig.name}. "
|
||||
f"Switching to {provider_models[0]}."
|
||||
)
|
||||
_set_default_model(config, provider_models[0])
|
||||
# else: keep current
|
||||
|
||||
|
||||
def _sync_model_from_disk(config: Dict[str, Any]) -> None:
|
||||
@@ -977,7 +967,7 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("zai", zai_base_url, default_model="glm-5")
|
||||
_update_config_for_provider("zai", zai_base_url)
|
||||
_set_model_provider(config, "zai", zai_base_url)
|
||||
|
||||
elif provider_idx == 5: # Kimi / Moonshot
|
||||
@@ -1010,7 +1000,7 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("kimi-coding", pconfig.inference_base_url, default_model="kimi-k2.5")
|
||||
_update_config_for_provider("kimi-coding", pconfig.inference_base_url)
|
||||
_set_model_provider(config, "kimi-coding", pconfig.inference_base_url)
|
||||
|
||||
elif provider_idx == 6: # MiniMax
|
||||
@@ -1043,7 +1033,7 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("minimax", pconfig.inference_base_url, default_model="MiniMax-M2.5")
|
||||
_update_config_for_provider("minimax", pconfig.inference_base_url)
|
||||
_set_model_provider(config, "minimax", pconfig.inference_base_url)
|
||||
|
||||
elif provider_idx == 7: # MiniMax China
|
||||
@@ -1076,7 +1066,7 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("minimax-cn", pconfig.inference_base_url, default_model="MiniMax-M2.5")
|
||||
_update_config_for_provider("minimax-cn", pconfig.inference_base_url)
|
||||
_set_model_provider(config, "minimax-cn", pconfig.inference_base_url)
|
||||
|
||||
elif provider_idx == 8: # Anthropic
|
||||
@@ -1180,7 +1170,7 @@ def setup_model_provider(config: dict):
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
# Don't save base_url for Anthropic — resolve_runtime_provider()
|
||||
# always hardcodes it. Stale base_urls contaminate other providers.
|
||||
_update_config_for_provider("anthropic", "", default_model="claude-opus-4-6")
|
||||
_update_config_for_provider("anthropic", "")
|
||||
_set_model_provider(config, "anthropic")
|
||||
|
||||
# else: provider_idx == 9 (Keep current) — only shown when a provider already exists
|
||||
@@ -1935,17 +1925,7 @@ def setup_gateway(config: dict):
|
||||
"Allowed user IDs or usernames (comma-separated, leave empty for open access)"
|
||||
)
|
||||
if allowed_users:
|
||||
# Clean up common prefixes (user:123, <@123>, <@!123>)
|
||||
cleaned_ids = []
|
||||
for uid in allowed_users.replace(" ", "").split(","):
|
||||
uid = uid.strip()
|
||||
if uid.startswith("<@") and uid.endswith(">"):
|
||||
uid = uid.lstrip("<@!").rstrip(">")
|
||||
if uid.lower().startswith("user:"):
|
||||
uid = uid[5:]
|
||||
if uid:
|
||||
cleaned_ids.append(uid)
|
||||
save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids))
|
||||
save_env_value("DISCORD_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Discord allowlist configured")
|
||||
else:
|
||||
print_info(
|
||||
@@ -1980,18 +1960,8 @@ def setup_gateway(config: dict):
|
||||
)
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated)")
|
||||
if allowed_users:
|
||||
# Clean up common prefixes (user:123, <@123>, <@!123>)
|
||||
cleaned_ids = []
|
||||
for uid in allowed_users.replace(" ", "").split(","):
|
||||
uid = uid.strip()
|
||||
if uid.startswith("<@") and uid.endswith(">"):
|
||||
uid = uid.lstrip("<@!").rstrip(">")
|
||||
if uid.lower().startswith("user:"):
|
||||
uid = uid[5:]
|
||||
if uid:
|
||||
cleaned_ids.append(uid)
|
||||
save_env_value(
|
||||
"DISCORD_ALLOWED_USERS", ",".join(cleaned_ids)
|
||||
"DISCORD_ALLOWED_USERS", allowed_users.replace(" ", "")
|
||||
)
|
||||
print_success("Discord allowlist configured")
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ dependencies = [
|
||||
"fal-client",
|
||||
# Text-to-speech (Edge TTS is free, no API key needed)
|
||||
"edge-tts",
|
||||
"faster-whisper>=1.0.0",
|
||||
# mini-swe-agent deps (terminal tool)
|
||||
"litellm>=1.75.5",
|
||||
"typer",
|
||||
|
||||
@@ -2729,7 +2729,7 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
"timeout": 900.0,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 900.0)),
|
||||
}
|
||||
|
||||
if self.max_tokens is not None:
|
||||
|
||||
@@ -153,6 +153,47 @@ class TestGenerateSummaryNoneContent:
|
||||
assert len(result) < len(msgs)
|
||||
|
||||
|
||||
class TestNonStringContent:
|
||||
"""Regression: content as dict (e.g., llama.cpp tool calls) must not crash."""
|
||||
|
||||
def test_dict_content_coerced_to_string(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = {"text": "some summary"}
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "do something"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]
|
||||
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
summary = c._generate_summary(messages)
|
||||
assert isinstance(summary, str)
|
||||
assert "CONTEXT SUMMARY" in summary
|
||||
|
||||
def test_none_content_coerced_to_empty(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = None
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "do something"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]
|
||||
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
summary = c._generate_summary(messages)
|
||||
# None content → empty string → "[CONTEXT SUMMARY]: " prefix added
|
||||
assert summary is not None
|
||||
assert "CONTEXT SUMMARY" in summary
|
||||
|
||||
|
||||
class TestCompressWithClient:
|
||||
def test_summarization_path(self):
|
||||
mock_client = MagicMock()
|
||||
@@ -314,143 +355,3 @@ class TestCompressWithClient:
|
||||
for msg in result:
|
||||
if msg.get("role") == "tool" and msg.get("tool_call_id"):
|
||||
assert msg["tool_call_id"] in called_ids
|
||||
|
||||
|
||||
class TestPruneToolOutputs:
|
||||
def _make_compressor(self, *, context_length=128000, protect_first_n=2, protect_last_n=2):
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=context_length):
|
||||
return ContextCompressor(
|
||||
model="test/model",
|
||||
threshold_percent=0.50,
|
||||
protect_first_n=protect_first_n,
|
||||
protect_last_n=protect_last_n,
|
||||
quiet_mode=True,
|
||||
)
|
||||
|
||||
def test_prune_replaces_old_middle_tool_outputs(self):
|
||||
c = self._make_compressor(protect_last_n=1)
|
||||
big_content = "x" * (c._prune_protect_tokens * 4)
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": big_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "newer"},
|
||||
{"role": "tool", "content": big_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail"},
|
||||
]
|
||||
|
||||
pruned, chars_saved = c._prune_tool_outputs(messages)
|
||||
|
||||
assert chars_saved > 0
|
||||
assert pruned[3]["content"].startswith("[Tool output pruned")
|
||||
assert pruned[5]["content"] == big_content
|
||||
|
||||
def test_protected_tools_are_never_pruned(self):
|
||||
c = self._make_compressor()
|
||||
big_content = "x" * (c._prune_protect_tokens * 8)
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": big_content, "name": "read_file"},
|
||||
{"role": "assistant", "content": "middle"},
|
||||
{"role": "tool", "content": big_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail"},
|
||||
]
|
||||
|
||||
pruned, _ = c._prune_tool_outputs(messages)
|
||||
read_file_msg = next(msg for msg in pruned if msg.get("name") == "read_file")
|
||||
assert read_file_msg["content"] == big_content
|
||||
|
||||
def test_prune_only_path_skips_summary_call_when_sufficient(self):
|
||||
c = self._make_compressor(protect_first_n=2, protect_last_n=1)
|
||||
huge_content = "x" * 180000
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "newer"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail"},
|
||||
]
|
||||
|
||||
with patch.object(ContextCompressor, "_generate_summary", side_effect=AssertionError("summary should not be called")):
|
||||
result = c.compress(messages, current_tokens=200000)
|
||||
|
||||
assert result[3]["content"].startswith("[Tool output pruned")
|
||||
assert result[5]["content"] == huge_content
|
||||
assert c.compression_count == 1
|
||||
|
||||
def test_prune_does_not_touch_protected_tail_messages(self):
|
||||
c = self._make_compressor(context_length=128000, protect_first_n=2, protect_last_n=3)
|
||||
huge_content = "x" * (c._prune_protect_tokens * 8)
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail assistant"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "latest"},
|
||||
]
|
||||
|
||||
pruned, _ = c._prune_tool_outputs(messages)
|
||||
|
||||
assert pruned[-2]["content"] == huge_content
|
||||
assert pruned[-1]["content"] == "latest"
|
||||
|
||||
|
||||
class TestPruneAcceptancePolicy:
|
||||
def _make_compressor(self, *, context_length=128000):
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=context_length):
|
||||
return ContextCompressor(
|
||||
model="test/model",
|
||||
threshold_percent=0.50,
|
||||
protect_first_n=2,
|
||||
protect_last_n=1,
|
||||
quiet_mode=True,
|
||||
)
|
||||
|
||||
def test_prune_near_threshold_still_falls_back_to_summary(self):
|
||||
c = self._make_compressor()
|
||||
huge_content = "x" * 180000
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "newer"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail"},
|
||||
]
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compacted"
|
||||
|
||||
with patch("agent.context_compressor.estimate_messages_tokens_rough", return_value=62000), \
|
||||
patch("agent.context_compressor.call_llm", return_value=mock_response):
|
||||
result = c.compress(messages, current_tokens=68000)
|
||||
|
||||
assert any("CONTEXT SUMMARY" in (msg.get("content") or "") for msg in result)
|
||||
|
||||
def test_prune_only_is_allowed_when_it_buys_real_runway(self):
|
||||
c = self._make_compressor()
|
||||
huge_content = "x" * 180000
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "task"},
|
||||
{"role": "assistant", "content": "older"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "newer"},
|
||||
{"role": "tool", "content": huge_content, "name": "terminal"},
|
||||
{"role": "assistant", "content": "tail"},
|
||||
]
|
||||
|
||||
with patch("agent.context_compressor.estimate_messages_tokens_rough", return_value=48000), \
|
||||
patch.object(ContextCompressor, "_generate_summary", side_effect=AssertionError("summary should not be called")):
|
||||
result = c.compress(messages, current_tokens=68000)
|
||||
|
||||
assert result[3]["content"].startswith("[Tool output pruned")
|
||||
assert result[5]["content"] == huge_content
|
||||
|
||||
223
tests/tools/test_transcription.py
Normal file
223
tests/tools/test_transcription.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for transcription_tools.py — local (faster-whisper) and OpenAI providers.
|
||||
|
||||
Tests cover provider selection, config loading, validation, and transcription
|
||||
dispatch. All external dependencies (faster_whisper, openai) are mocked.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, mock_open
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetProvider:
|
||||
"""_get_provider() picks the right backend based on config + availability."""
|
||||
|
||||
def test_local_when_available(self):
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "local"
|
||||
|
||||
def test_local_fallback_to_openai(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "openai"
|
||||
|
||||
def test_local_nothing_available(self, monkeypatch):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", False):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "local"}) == "none"
|
||||
|
||||
def test_openai_when_key_set(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "openai"
|
||||
|
||||
def test_openai_fallback_to_local(self, monkeypatch):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "local"
|
||||
|
||||
def test_default_provider_is_local(self):
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "local"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateAudioFile:
|
||||
|
||||
def test_missing_file(self, tmp_path):
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
result = _validate_audio_file(str(tmp_path / "nope.ogg"))
|
||||
assert result is not None
|
||||
assert "not found" in result["error"]
|
||||
|
||||
def test_unsupported_format(self, tmp_path):
|
||||
f = tmp_path / "test.xyz"
|
||||
f.write_bytes(b"data")
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
result = _validate_audio_file(str(f))
|
||||
assert result is not None
|
||||
assert "Unsupported" in result["error"]
|
||||
|
||||
def test_valid_file_returns_none(self, tmp_path):
|
||||
f = tmp_path / "test.ogg"
|
||||
f.write_bytes(b"fake audio data")
|
||||
from tools.transcription_tools import _validate_audio_file
|
||||
assert _validate_audio_file(str(f)) is None
|
||||
|
||||
def test_too_large(self, tmp_path):
|
||||
import stat as stat_mod
|
||||
f = tmp_path / "big.ogg"
|
||||
f.write_bytes(b"x")
|
||||
from tools.transcription_tools import _validate_audio_file, MAX_FILE_SIZE
|
||||
real_stat = f.stat()
|
||||
with patch.object(type(f), "stat", return_value=os.stat_result((
|
||||
real_stat.st_mode, real_stat.st_ino, real_stat.st_dev,
|
||||
real_stat.st_nlink, real_stat.st_uid, real_stat.st_gid,
|
||||
MAX_FILE_SIZE + 1, # st_size
|
||||
real_stat.st_atime, real_stat.st_mtime, real_stat.st_ctime,
|
||||
))):
|
||||
result = _validate_audio_file(str(f))
|
||||
assert result is not None
|
||||
assert "too large" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local transcription
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscribeLocal:
|
||||
|
||||
def test_successful_transcription(self, tmp_path):
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
mock_segment = MagicMock()
|
||||
mock_segment.text = "Hello world"
|
||||
mock_info = MagicMock()
|
||||
mock_info.language = "en"
|
||||
mock_info.duration = 2.5
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.transcribe.return_value = ([mock_segment], mock_info)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("tools.transcription_tools.WhisperModel", return_value=mock_model), \
|
||||
patch("tools.transcription_tools._local_model", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local(str(audio_file), "base")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "Hello world"
|
||||
|
||||
def test_not_installed(self):
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local("/tmp/test.ogg", "base")
|
||||
assert result["success"] is False
|
||||
assert "not installed" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI transcription
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscribeOpenAI:
|
||||
|
||||
def test_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai("/tmp/test.ogg", "whisper-1")
|
||||
assert result["success"] is False
|
||||
assert "VOICE_TOOLS_OPENAI_KEY" in result["error"]
|
||||
|
||||
def test_successful_transcription(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.audio.transcriptions.create.return_value = "Hello from OpenAI"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai(str(audio_file), "whisper-1")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "Hello from OpenAI"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main transcribe_audio() dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscribeAudio:
|
||||
|
||||
def test_dispatches_to_local(self, tmp_path):
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "local"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="local"), \
|
||||
patch("tools.transcription_tools._transcribe_local", return_value={"success": True, "transcript": "hi"}) as mock_local:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(str(audio_file))
|
||||
|
||||
assert result["success"] is True
|
||||
mock_local.assert_called_once()
|
||||
|
||||
def test_dispatches_to_openai(self, tmp_path):
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openai"), \
|
||||
patch("tools.transcription_tools._transcribe_openai", return_value={"success": True, "transcript": "hi"}) as mock_openai:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(str(audio_file))
|
||||
|
||||
assert result["success"] is True
|
||||
mock_openai.assert_called_once()
|
||||
|
||||
def test_no_provider_returns_error(self, tmp_path):
|
||||
audio_file = tmp_path / "test.ogg"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="none"):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(str(audio_file))
|
||||
|
||||
assert result["success"] is False
|
||||
assert "No STT provider" in result["error"]
|
||||
|
||||
def test_invalid_file_returns_error(self):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/nonexistent/file.ogg")
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
@@ -2,18 +2,19 @@
|
||||
"""
|
||||
Transcription Tools Module
|
||||
|
||||
Provides speech-to-text transcription using OpenAI's Whisper API.
|
||||
Used by the messaging gateway to automatically transcribe voice messages
|
||||
sent by users on Telegram, Discord, WhatsApp, and Slack.
|
||||
Provides speech-to-text transcription with two providers:
|
||||
|
||||
Supported models:
|
||||
- whisper-1 (cheapest, good quality)
|
||||
- gpt-4o-mini-transcribe (better quality, higher cost)
|
||||
- gpt-4o-transcribe (best quality, highest cost)
|
||||
- **local** (default, free) — faster-whisper running locally, no API key needed.
|
||||
Auto-downloads the model (~150 MB for ``base``) on first use.
|
||||
- **openai** — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
|
||||
|
||||
Used by the messaging gateway to automatically transcribe voice messages
|
||||
sent by users on Telegram, Discord, WhatsApp, Slack, and Signal.
|
||||
|
||||
Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg
|
||||
|
||||
Usage:
|
||||
Usage::
|
||||
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
|
||||
result = transcribe_audio("/path/to/audio.ogg")
|
||||
@@ -28,27 +29,205 @@ from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional imports — graceful degradation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default STT model -- cheapest and widely available
|
||||
DEFAULT_STT_MODEL = "whisper-1"
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
_HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
_HAS_FASTER_WHISPER = False
|
||||
WhisperModel = None # type: ignore[assignment,misc]
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
_HAS_OPENAI = True
|
||||
except ImportError:
|
||||
_HAS_OPENAI = False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_PROVIDER = "local"
|
||||
DEFAULT_LOCAL_MODEL = "base"
|
||||
DEFAULT_OPENAI_MODEL = "whisper-1"
|
||||
|
||||
# Supported audio formats
|
||||
SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"}
|
||||
MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB
|
||||
|
||||
# Maximum file size (25MB - OpenAI limit)
|
||||
MAX_FILE_SIZE = 25 * 1024 * 1024
|
||||
# Singleton for the local model — loaded once, reused across calls
|
||||
_local_model: Optional["WhisperModel"] = None
|
||||
_local_model_name: Optional[str] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_stt_config() -> dict:
|
||||
"""Load the ``stt`` section from user config, falling back to defaults."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
return load_config().get("stt", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_provider(stt_config: dict) -> str:
|
||||
"""Determine which STT provider to use.
|
||||
|
||||
Priority:
|
||||
1. Explicit config value (``stt.provider``)
|
||||
2. Auto-detect: local if faster-whisper available, else openai if key set
|
||||
3. Disabled (returns "none")
|
||||
"""
|
||||
provider = stt_config.get("provider", DEFAULT_PROVIDER)
|
||||
|
||||
if provider == "local":
|
||||
if _HAS_FASTER_WHISPER:
|
||||
return "local"
|
||||
# Local requested but not available — fall back to openai if possible
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
logger.info("faster-whisper not installed, falling back to OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
if provider == "openai":
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
return "openai"
|
||||
# OpenAI requested but no key — fall back to local if possible
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to local faster-whisper")
|
||||
return "local"
|
||||
return "none"
|
||||
|
||||
return provider # Unknown — let it fail downstream
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_audio_file(file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""Validate the audio file. Returns an error dict or None if OK."""
|
||||
audio_path = Path(file_path)
|
||||
|
||||
if not audio_path.exists():
|
||||
return {"success": False, "transcript": "", "error": f"Audio file not found: {file_path}"}
|
||||
if not audio_path.is_file():
|
||||
return {"success": False, "transcript": "", "error": f"Path is not a file: {file_path}"}
|
||||
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Unsupported format: {audio_path.suffix}. Supported: {', '.join(sorted(SUPPORTED_FORMATS))}",
|
||||
}
|
||||
try:
|
||||
file_size = audio_path.stat().st_size
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024):.0f}MB)",
|
||||
}
|
||||
except OSError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Failed to access file: {e}"}
|
||||
|
||||
return None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: local (faster-whisper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using faster-whisper (local, free)."""
|
||||
global _local_model, _local_model_name
|
||||
|
||||
if not _HAS_FASTER_WHISPER:
|
||||
return {"success": False, "transcript": "", "error": "faster-whisper not installed"}
|
||||
|
||||
try:
|
||||
# Lazy-load the model (downloads on first use, ~150 MB for 'base')
|
||||
if _local_model is None or _local_model_name != model_name:
|
||||
logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name)
|
||||
_local_model = WhisperModel(model_name, device="auto", compute_type="auto")
|
||||
_local_model_name = model_name
|
||||
|
||||
segments, info = _local_model.transcribe(file_path, beam_size=5)
|
||||
transcript = " ".join(segment.text.strip() for segment in segments)
|
||||
|
||||
logger.info(
|
||||
"Transcribed %s via local whisper (%s, lang=%s, %.1fs audio)",
|
||||
Path(file_path).name, model_name, info.language, info.duration,
|
||||
)
|
||||
|
||||
return {"success": True, "transcript": transcript}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Local transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: openai (Whisper API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using OpenAI Whisper API (paid)."""
|
||||
api_key = os.getenv("VOICE_TOOLS_OPENAI_KEY")
|
||||
if not api_key:
|
||||
return {"success": False, "transcript": "", "error": "VOICE_TOOLS_OPENAI_KEY not set"}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via OpenAI API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("OpenAI transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Transcribe an audio file using OpenAI's Whisper API.
|
||||
Transcribe an audio file using the configured STT provider.
|
||||
|
||||
This function calls the OpenAI Audio Transcriptions endpoint directly
|
||||
(not via OpenRouter, since Whisper isn't available there).
|
||||
Provider priority:
|
||||
1. User config (``stt.provider`` in config.yaml)
|
||||
2. Auto-detect: local faster-whisper if available, else OpenAI API
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the audio file to transcribe.
|
||||
model: Whisper model to use. Defaults to config or "whisper-1".
|
||||
model: Override the model. If None, uses config or provider default.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
@@ -56,114 +235,31 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
||||
- "transcript" (str): The transcribed text (empty on failure)
|
||||
- "error" (str, optional): Error message if success is False
|
||||
"""
|
||||
api_key = os.getenv("VOICE_TOOLS_OPENAI_KEY")
|
||||
if not api_key:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": "VOICE_TOOLS_OPENAI_KEY not set",
|
||||
}
|
||||
# Validate input
|
||||
error = _validate_audio_file(file_path)
|
||||
if error:
|
||||
return error
|
||||
|
||||
audio_path = Path(file_path)
|
||||
|
||||
# Validate file exists
|
||||
if not audio_path.exists():
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Audio file not found: {file_path}",
|
||||
}
|
||||
|
||||
if not audio_path.is_file():
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Path is not a file: {file_path}",
|
||||
}
|
||||
|
||||
# Validate file extension
|
||||
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Unsupported file format: {audio_path.suffix}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}",
|
||||
}
|
||||
|
||||
# Validate file size
|
||||
try:
|
||||
file_size = audio_path.stat().st_size
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024)}MB)",
|
||||
}
|
||||
except OSError as e:
|
||||
logger.error("Failed to get file size for %s: %s", file_path, e, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Failed to access file: {e}",
|
||||
}
|
||||
# Load config and determine provider
|
||||
stt_config = _load_stt_config()
|
||||
provider = _get_provider(stt_config)
|
||||
|
||||
# Use provided model, or fall back to default
|
||||
if model is None:
|
||||
model = DEFAULT_STT_MODEL
|
||||
if provider == "local":
|
||||
local_cfg = stt_config.get("local", {})
|
||||
model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
return _transcribe_local(file_path, model_name)
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
if provider == "openai":
|
||||
openai_cfg = stt_config.get("openai", {})
|
||||
model_name = model or openai_cfg.get("model", DEFAULT_OPENAI_MODEL)
|
||||
return _transcribe_openai(file_path, model_name)
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
# The response is a plain string when response_format="text"
|
||||
transcript_text = str(transcription).strip()
|
||||
|
||||
logger.info("Transcribed %s (%d chars)", audio_path.name, len(transcript_text))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": transcript_text,
|
||||
}
|
||||
|
||||
except PermissionError:
|
||||
logger.error("Permission denied accessing file: %s", file_path, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Permission denied: {file_path}",
|
||||
}
|
||||
except APIConnectionError as e:
|
||||
logger.error("API connection error during transcription: %s", e, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Connection error: {e}",
|
||||
}
|
||||
except APITimeoutError as e:
|
||||
logger.error("API timeout during transcription: %s", e, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Request timeout: {e}",
|
||||
}
|
||||
except APIError as e:
|
||||
logger.error("OpenAI API error during transcription: %s", e, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"API error: {e}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during transcription: %s", e, exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Transcription failed: {e}",
|
||||
}
|
||||
# No provider available
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": (
|
||||
"No STT provider available. Install faster-whisper for free local "
|
||||
"transcription, or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API."
|
||||
),
|
||||
}
|
||||
|
||||
@@ -131,6 +131,7 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config
|
||||
| `HERMES_HUMAN_DELAY_MIN_MS` | Custom delay range minimum (ms) |
|
||||
| `HERMES_HUMAN_DELAY_MAX_MS` | Custom delay range maximum (ms) |
|
||||
| `HERMES_QUIET` | Suppress non-essential output (`true`/`false`) |
|
||||
| `HERMES_API_TIMEOUT` | LLM API call timeout in seconds (default: `900`) |
|
||||
| `HERMES_EXEC_ASK` | Enable execution approval prompts in gateway mode (`true`/`false`) |
|
||||
|
||||
## Session Settings
|
||||
|
||||
@@ -67,23 +67,48 @@ Without ffmpeg, Edge TTS audio is sent as a regular audio file (playable, but sh
|
||||
If you want voice bubbles without installing ffmpeg, switch to the OpenAI or ElevenLabs provider.
|
||||
:::
|
||||
|
||||
## Voice Message Transcription
|
||||
## Voice Message Transcription (STT)
|
||||
|
||||
Voice messages sent on Telegram, Discord, WhatsApp, or Slack are automatically transcribed and injected as text into the conversation. The agent sees the transcript as normal text.
|
||||
Voice messages sent on Telegram, Discord, WhatsApp, Slack, or Signal are automatically transcribed and injected as text into the conversation. The agent sees the transcript as normal text.
|
||||
|
||||
| Provider | Model | Quality | Cost |
|
||||
|----------|-------|---------|------|
|
||||
| **OpenAI Whisper** | `whisper-1` (default) | Good | Low |
|
||||
| **OpenAI GPT-4o** | `gpt-4o-mini-transcribe` | Better | Medium |
|
||||
| **OpenAI GPT-4o** | `gpt-4o-transcribe` | Best | Higher |
|
||||
| Provider | Quality | Cost | API Key |
|
||||
|----------|---------|------|---------|
|
||||
| **Local Whisper** (default) | Good | Free | None needed |
|
||||
| **OpenAI Whisper API** | Good–Best | Paid | `VOICE_TOOLS_OPENAI_KEY` |
|
||||
|
||||
Requires `VOICE_TOOLS_OPENAI_KEY` in `~/.hermes/.env`.
|
||||
:::info Zero Config
|
||||
Local transcription works out of the box — no API key needed. The `faster-whisper` model (~150 MB for `base`) is auto-downloaded on first voice message.
|
||||
:::
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
# In ~/.hermes/config.yaml
|
||||
stt:
|
||||
enabled: true
|
||||
model: "whisper-1"
|
||||
provider: "local" # "local" (free, faster-whisper) | "openai" (API)
|
||||
local:
|
||||
model: "base" # tiny, base, small, medium, large-v3
|
||||
openai:
|
||||
model: "whisper-1" # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
```
|
||||
|
||||
### Provider Details
|
||||
|
||||
**Local (faster-whisper)** — Runs Whisper locally via [faster-whisper](https://github.com/SYSTRAN/faster-whisper). Uses CPU by default, GPU if available. Model sizes:
|
||||
|
||||
| Model | Size | Speed | Quality |
|
||||
|-------|------|-------|---------|
|
||||
| `tiny` | ~75 MB | Fastest | Basic |
|
||||
| `base` | ~150 MB | Fast | Good (default) |
|
||||
| `small` | ~500 MB | Medium | Better |
|
||||
| `medium` | ~1.5 GB | Slower | Great |
|
||||
| `large-v3` | ~3 GB | Slowest | Best |
|
||||
|
||||
**OpenAI API** — Requires `VOICE_TOOLS_OPENAI_KEY`. Supports `whisper-1`, `gpt-4o-mini-transcribe`, and `gpt-4o-transcribe`.
|
||||
|
||||
### Fallback Behavior
|
||||
|
||||
If your configured provider isn't available, Hermes automatically falls back:
|
||||
- **Local not installed** → Falls back to OpenAI API (if key is set)
|
||||
- **OpenAI key not set** → Falls back to local Whisper (if installed)
|
||||
- **Neither available** → Voice messages pass through with a note to the user
|
||||
|
||||
Reference in New Issue
Block a user