mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 19:57:07 +08:00
Compare commits
7 Commits
fix/toolse
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc5a67dbf2 | ||
|
|
f9c2565ab4 | ||
|
|
ad5f973a8d | ||
|
|
0791efe2c3 | ||
|
|
934fbe3c06 | ||
|
|
6302e56e7c | ||
|
|
868b3c07e3 |
@@ -523,8 +523,13 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
"Check %s for syntax errors. Error: %s",
|
||||
_home / "config.yaml",
|
||||
e,
|
||||
)
|
||||
|
||||
config = GatewayConfig.from_dict(gw_data)
|
||||
|
||||
|
||||
@@ -525,6 +525,12 @@ class GatewayRunner:
|
||||
Synchronous worker — meant to be called via run_in_executor from
|
||||
an async context so it doesn't block the event loop.
|
||||
"""
|
||||
# Skip cron sessions — they run headless with no meaningful user
|
||||
# conversation to extract memories from.
|
||||
if old_session_id and old_session_id.startswith("cron_"):
|
||||
logger.debug("Skipping memory flush for cron session: %s", old_session_id)
|
||||
return
|
||||
|
||||
try:
|
||||
history = self.session_store.load_transcript(old_session_id)
|
||||
if not history or len(history) < 4:
|
||||
@@ -557,6 +563,23 @@ class GatewayRunner:
|
||||
if m.get("role") in ("user", "assistant") and m.get("content")
|
||||
]
|
||||
|
||||
# Read live memory state from disk so the flush agent can see
|
||||
# what's already saved and avoid overwriting newer entries.
|
||||
_current_memory = ""
|
||||
try:
|
||||
from tools.memory_tool import MEMORY_DIR
|
||||
for fname, label in [
|
||||
("MEMORY.md", "MEMORY (your personal notes)"),
|
||||
("USER.md", "USER PROFILE (who the user is)"),
|
||||
]:
|
||||
fpath = MEMORY_DIR / fname
|
||||
if fpath.exists():
|
||||
content = fpath.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
_current_memory += f"\n\n## Current {label}:\n{content}"
|
||||
except Exception:
|
||||
pass # Non-fatal — flush still works, just without the guard
|
||||
|
||||
# Give the agent a real turn to think about what to save
|
||||
flush_prompt = (
|
||||
"[System: This session is about to be automatically reset due to "
|
||||
@@ -568,6 +591,20 @@ class GatewayRunner:
|
||||
"2. If you discovered a reusable workflow or solved a non-trivial "
|
||||
"problem, consider saving it as a skill.\n"
|
||||
"3. If nothing is worth saving, that's fine — just skip.\n\n"
|
||||
)
|
||||
|
||||
if _current_memory:
|
||||
flush_prompt += (
|
||||
"IMPORTANT — here is the current live state of memory. Other "
|
||||
"sessions, cron jobs, or the user may have updated it since this "
|
||||
"conversation ended. Do NOT overwrite or remove entries unless "
|
||||
"the conversation above reveals something that genuinely "
|
||||
"supersedes them. Only add new information that is not already "
|
||||
"captured below."
|
||||
f"{_current_memory}\n\n"
|
||||
)
|
||||
|
||||
flush_prompt += (
|
||||
"Do NOT respond to the user. Just use the memory and skill_manage "
|
||||
"tools if needed, then stop.]"
|
||||
)
|
||||
@@ -904,7 +941,9 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
|
||||
167
tests/gateway/test_flush_memory_stale_guard.py
Normal file
167
tests/gateway/test_flush_memory_stale_guard.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for memory flush stale-overwrite prevention (#2670).
|
||||
|
||||
Verifies that:
|
||||
1. Cron sessions are skipped (no flush for headless cron runs)
|
||||
2. Current memory state is injected into the flush prompt so the
|
||||
flush agent can see what's already saved and avoid overwrites
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
def _make_runner():
    """Build a bare GatewayRunner skeleton, bypassing __init__ (no I/O)."""
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    # Each bookkeeping attribute gets its own fresh dict.
    for attr in (
        "_honcho_managers",
        "_honcho_configs",
        "_running_agents",
        "_pending_messages",
        "_pending_approvals",
        "adapters",
    ):
        setattr(runner, attr, {})
    runner.hooks = MagicMock()
    runner.session_store = MagicMock()
    return runner
|
||||
|
||||
|
||||
# Minimal transcript that clears the 4-message threshold for a flush.
_TRANSCRIPT_4_MSGS = [
    {"role": role, "content": content}
    for role, content in (
        ("user", "hello"),
        ("assistant", "hi there"),
        ("user", "remember my name is Alice"),
        ("assistant", "Got it, Alice!"),
    )
]
|
||||
|
||||
|
||||
class TestCronSessionBypass:
    """Cron sessions should never trigger a memory flush."""

    def test_cron_session_skipped(self):
        runner = _make_runner()
        runner._flush_memories_for_session("cron_job123_20260323_120000")
        # The transcript must never even be loaded for a cron session.
        assert runner.session_store.load_transcript.call_count == 0

    def test_cron_session_with_honcho_key_skipped(self):
        runner = _make_runner()
        runner._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
        assert runner.session_store.load_transcript.call_count == 0

    def test_non_cron_session_proceeds(self):
        """Non-cron sessions should still attempt the flush."""
        runner = _make_runner()
        runner.session_store.load_transcript.return_value = []
        runner._flush_memories_for_session("session_abc123")
        # Exactly one load, with the session id passed through.
        assert runner.session_store.load_transcript.call_count == 1
        assert runner.session_store.load_transcript.call_args == call("session_abc123")
|
||||
|
||||
|
||||
class TestMemoryInjection:
    """The flush prompt should include current memory state from disk."""

    def _run_flush(self, tmp_agent, memory_dir, session_id):
        """Run a flush with model plumbing patched out and a fake MEMORY_DIR."""
        runner = _make_runner()
        runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
            patch("gateway.run._resolve_gateway_model", return_value="test-model"),
            patch("run_agent.AIAgent", return_value=tmp_agent),
            # Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
            patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
        ):
            runner._flush_memories_for_session(session_id)

    @staticmethod
    def _prompt_of(tmp_agent):
        """Extract the user_message keyword from the single flush call."""
        return tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")

    def test_memory_content_injected_into_flush_prompt(self, tmp_path):
        """When memory files exist, their content appears in the flush prompt."""
        tmp_agent = MagicMock()
        memory_dir = tmp_path / "memories"
        memory_dir.mkdir()
        (memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
        (memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")

        self._run_flush(tmp_agent, memory_dir, "session_123")

        tmp_agent.run_conversation.assert_called_once()
        flush_prompt = self._prompt_of(tmp_agent)
        # Both memory sections plus the stale-overwrite warning must be present.
        for fragment in (
            "Agent knows Python",
            "User prefers dark mode",
            "Name: Alice",
            "Timezone: PST",
            "Do NOT overwrite or remove entries",
            "current live state of memory",
        ):
            assert fragment in flush_prompt

    def test_flush_works_without_memory_files(self, tmp_path):
        """When no memory files exist, flush still runs without the guard."""
        tmp_agent = MagicMock()
        empty_dir = tmp_path / "no_memories"
        empty_dir.mkdir()

        self._run_flush(tmp_agent, empty_dir, "session_456")

        # Still runs — just without the memory-guard section.
        tmp_agent.run_conversation.assert_called_once()
        flush_prompt = self._prompt_of(tmp_agent)
        assert "Do NOT overwrite or remove entries" not in flush_prompt
        assert "Review the conversation above" in flush_prompt

    def test_empty_memory_files_no_injection(self, tmp_path):
        """Empty memory files should not trigger the guard section."""
        tmp_agent = MagicMock()
        memory_dir = tmp_path / "memories"
        memory_dir.mkdir()
        (memory_dir / "MEMORY.md").write_text("")
        (memory_dir / "USER.md").write_text("   \n  ")  # whitespace only

        self._run_flush(tmp_agent, memory_dir, "session_789")

        tmp_agent.run_conversation.assert_called_once()
        # No memory content → no guard section.
        assert "current live state of memory" not in self._prompt_of(tmp_agent)
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
    """Verify the flush prompt retains its core instructions."""

    def test_core_instructions_present(self):
        """The flush prompt should still contain the original guidance."""
        runner = _make_runner()
        runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
        tmp_agent = MagicMock()

        with (
            patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
            patch("gateway.run._resolve_gateway_model", return_value="test-model"),
            patch("run_agent.AIAgent", return_value=tmp_agent),
            # Make the import fail gracefully so we test without memory files
            patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
        ):
            runner._flush_memories_for_session("session_struct")

        flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
        for fragment in (
            "automatically reset",
            "Save any important facts",
            "consider saving it as a skill",
            "Do NOT respond to the user",
        ):
            assert fragment in flush_prompt
|
||||
168
tests/tools/test_ansi_strip.py
Normal file
168
tests/tools/test_ansi_strip.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Comprehensive tests for ANSI escape sequence stripping (ECMA-48).
|
||||
|
||||
The strip_ansi function in tools/ansi_strip.py is the source-level fix for
|
||||
ANSI codes leaking into the model's context via terminal/execute_code output.
|
||||
It must strip ALL terminal escape sequences while preserving legitimate text.
|
||||
"""
|
||||
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
|
||||
class TestStripAnsiBasicSGR:
    """Select Graphic Rendition — the most common ANSI sequences."""

    def test_reset(self):
        cleaned = strip_ansi("\x1b[0m")
        assert cleaned == ""

    def test_color(self):
        cleaned = strip_ansi("\x1b[31;1m")
        assert cleaned == ""

    def test_truecolor_semicolon(self):
        cleaned = strip_ansi("\x1b[38;2;255;0;0m")
        assert cleaned == ""

    def test_truecolor_colon_separated(self):
        """Modern terminals use colon-separated SGR params."""
        for seq in ("\x1b[38:2:255:0:0m", "\x1b[48:2:0:255:0m"):
            assert strip_ansi(seq) == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIPrivateMode:
    """CSI sequences with ? prefix (DEC private modes)."""

    def test_cursor_show_hide(self):
        for seq in ("\x1b[?25h", "\x1b[?25l"):
            assert strip_ansi(seq) == ""

    def test_alt_screen(self):
        for seq in ("\x1b[?1049h", "\x1b[?1049l"):
            assert strip_ansi(seq) == ""

    def test_bracketed_paste(self):
        cleaned = strip_ansi("\x1b[?2004h")
        assert cleaned == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIIntermediate:
    """CSI sequences with intermediate bytes (space, etc.)."""

    def test_cursor_shape(self):
        # DECSCUSR cursor-shape sequences carry a space intermediate byte.
        for seq in ("\x1b[0 q", "\x1b[2 q", "\x1b[6 q"):
            assert strip_ansi(seq) == ""
|
||||
|
||||
|
||||
class TestStripAnsiOSC:
    """Operating System Command sequences."""

    def test_bel_terminator(self):
        cleaned = strip_ansi("\x1b]0;title\x07")
        assert cleaned == ""

    def test_st_terminator(self):
        cleaned = strip_ansi("\x1b]0;title\x1b\\")
        assert cleaned == ""

    def test_hyperlink_preserves_text(self):
        # OSC 8 hyperlinks wrap visible text; only the wrapper is removed.
        link = "\x1b]8;;https://example.com\x1b\\click\x1b]8;;\x1b\\"
        assert strip_ansi(link) == "click"
|
||||
|
||||
|
||||
class TestStripAnsiDECPrivate:
    """DEC private / Fp escape sequences."""

    def test_save_restore_cursor(self):
        for seq in ("\x1b7", "\x1b8"):
            assert strip_ansi(seq) == ""

    def test_keypad_modes(self):
        for seq in ("\x1b=", "\x1b>"):
            assert strip_ansi(seq) == ""
|
||||
|
||||
|
||||
class TestStripAnsiFe:
    """Fe (C1 as 7-bit) escape sequences."""

    def test_reverse_index(self):
        cleaned = strip_ansi("\x1bM")
        assert cleaned == ""

    def test_reset_terminal(self):
        cleaned = strip_ansi("\x1bc")
        assert cleaned == ""

    def test_index_and_newline(self):
        for seq in ("\x1bD", "\x1bE"):
            assert strip_ansi(seq) == ""
|
||||
|
||||
|
||||
class TestStripAnsiNF:
    """nF (character set selection) sequences."""

    def test_charset_selection(self):
        for seq in ("\x1b(A", "\x1b(B", "\x1b(0"):
            assert strip_ansi(seq) == ""
|
||||
|
||||
|
||||
class TestStripAnsiDCS:
    """Device Control String sequences."""

    def test_dcs(self):
        cleaned = strip_ansi("\x1bP+q\x1b\\")
        assert cleaned == ""
|
||||
|
||||
|
||||
class TestStripAnsi8BitC1:
    """8-bit C1 control characters."""

    def test_8bit_csi(self):
        for seq in ("\x9b31m", "\x9b38;2;255;0;0m"):
            assert strip_ansi(seq) == ""

    def test_8bit_standalone(self):
        for byte in ("\x9c", "\x9d", "\x90"):
            assert strip_ansi(byte) == ""
|
||||
|
||||
|
||||
class TestStripAnsiRealWorld:
    """Real-world contamination scenarios from bug reports."""

    def test_colored_shebang(self):
        """The original reported bug: shebang corrupted by color codes."""
        dirty = "\x1b[32m#!/usr/bin/env python3\x1b[0m\nprint('hello')"
        assert strip_ansi(dirty) == "#!/usr/bin/env python3\nprint('hello')"

    def test_stacked_sgr(self):
        dirty = "\x1b[1m\x1b[31m\x1b[42mhello\x1b[0m"
        assert strip_ansi(dirty) == "hello"

    def test_ansi_mid_code(self):
        dirty = "def foo(\x1b[33m):\x1b[0m\n return 42"
        assert strip_ansi(dirty) == "def foo():\n return 42"
|
||||
|
||||
|
||||
class TestStripAnsiPassthrough:
    """Clean content must pass through unmodified."""

    def test_plain_text(self):
        text = "normal text"
        assert strip_ansi(text) == text

    def test_empty(self):
        assert strip_ansi("") == ""

    def test_none(self):
        # Falsy inputs, including None, are returned untouched.
        assert strip_ansi(None) is None

    def test_whitespace_preserved(self):
        text = "line1\nline2\ttab"
        assert strip_ansi(text) == text

    def test_unicode_safe(self):
        text = "emoji 🎉 and ñ café"
        assert strip_ansi(text) == text

    def test_backslash_in_code(self):
        code = "path = 'C:\\\\Users\\\\test'"
        assert strip_ansi(code) == code

    def test_square_brackets_in_code(self):
        """Array indexing must not be confused with CSI."""
        code = "arr[0] = arr[31]"
        assert strip_ansi(code) == code
|
||||
@@ -309,3 +309,6 @@ class TestSearchHints:
|
||||
raw = search_tool(pattern="foo", offset=50, limit=50)
|
||||
assert "[Hint:" in raw
|
||||
assert "offset=100" in raw
|
||||
|
||||
|
||||
|
||||
|
||||
176
tests/tools/test_url_safety.py
Normal file
176
tests/tools/test_url_safety.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Tests for SSRF protection in url_safety module."""
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.url_safety import is_safe_url, _is_blocked_ip
|
||||
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIsSafeUrl:

    @staticmethod
    def _dns(addr, family=2):
        """Patch context: make every DNS lookup resolve to *addr*.

        family=2 (AF_INET) yields a 2-tuple sockaddr, family=10 (AF_INET6)
        a 4-tuple, mirroring what socket.getaddrinfo really returns.
        """
        sockaddr = (addr, 0) if family == 2 else (addr, 0, 0, 0)
        return patch("socket.getaddrinfo", return_value=[(family, 1, 6, "", sockaddr)])

    def test_public_url_allowed(self):
        with self._dns("93.184.216.34"):
            assert is_safe_url("https://example.com/image.png") is True

    def test_localhost_blocked(self):
        with self._dns("127.0.0.1"):
            assert is_safe_url("http://localhost:8080/secret") is False

    def test_loopback_ip_blocked(self):
        with self._dns("127.0.0.1"):
            assert is_safe_url("http://127.0.0.1/admin") is False

    def test_private_10_blocked(self):
        with self._dns("10.0.0.1"):
            assert is_safe_url("http://internal-service.local/api") is False

    def test_private_172_blocked(self):
        with self._dns("172.16.0.1"):
            assert is_safe_url("http://private.corp/data") is False

    def test_private_192_blocked(self):
        with self._dns("192.168.1.1"):
            assert is_safe_url("http://router.local") is False

    def test_link_local_169_254_blocked(self):
        with self._dns("169.254.169.254"):
            assert is_safe_url("http://169.254.169.254/latest/meta-data/") is False

    def test_metadata_google_internal_blocked(self):
        # Hostname denylist hit — no DNS stub required.
        assert is_safe_url("http://metadata.google.internal/computeMetadata/v1/") is False

    def test_ipv6_loopback_blocked(self):
        with self._dns("::1", family=10):
            assert is_safe_url("http://[::1]:8080/") is False

    def test_dns_failure_blocked(self):
        """DNS failures now fail closed — block the request."""
        with patch("socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed")):
            assert is_safe_url("https://nonexistent.example.com") is False

    def test_empty_url_blocked(self):
        assert is_safe_url("") is False

    def test_no_hostname_blocked(self):
        assert is_safe_url("http://") is False

    def test_public_ip_allowed(self):
        with self._dns("93.184.216.34"):
            assert is_safe_url("https://example.com") is True

    # ── New tests for hardened SSRF protection ──

    def test_cgnat_100_64_blocked(self):
        """100.64.0.0/10 (CGNAT/Shared Address Space) is NOT covered by
        ipaddress.is_private — must be blocked explicitly."""
        with self._dns("100.64.0.1"):
            assert is_safe_url("http://some-cgnat-host.example/") is False

    def test_cgnat_100_127_blocked(self):
        """Upper end of CGNAT range (100.127.255.255)."""
        with self._dns("100.127.255.254"):
            assert is_safe_url("http://tailscale-peer.example/") is False

    def test_multicast_blocked(self):
        """Multicast addresses (224.0.0.0/4) not caught by is_private."""
        with self._dns("224.0.0.251"):
            assert is_safe_url("http://mdns-host.local/") is False

    def test_multicast_ipv6_blocked(self):
        with self._dns("ff02::1", family=10):
            assert is_safe_url("http://[ff02::1]/") is False

    def test_ipv4_mapped_ipv6_loopback_blocked(self):
        """::ffff:127.0.0.1 — IPv4-mapped IPv6 loopback."""
        with self._dns("::ffff:127.0.0.1", family=10):
            assert is_safe_url("http://[::ffff:127.0.0.1]/") is False

    def test_ipv4_mapped_ipv6_metadata_blocked(self):
        """::ffff:169.254.169.254 — IPv4-mapped IPv6 cloud metadata."""
        with self._dns("::ffff:169.254.169.254", family=10):
            assert is_safe_url("http://[::ffff:169.254.169.254]/") is False

    def test_unspecified_address_blocked(self):
        """0.0.0.0 — unspecified address, can bind to all interfaces."""
        with self._dns("0.0.0.0"):
            assert is_safe_url("http://0.0.0.0/") is False

    def test_unexpected_error_fails_closed(self):
        """Unexpected exceptions should block, not allow."""
        with patch("tools.url_safety.urlparse", side_effect=ValueError("bad url")):
            assert is_safe_url("http://evil.com/") is False

    def test_metadata_goog_blocked(self):
        assert is_safe_url("http://metadata.goog/computeMetadata/v1/") is False

    def test_ipv6_unique_local_blocked(self):
        """fc00::/7 — IPv6 unique local addresses."""
        with self._dns("fd12::1", family=10):
            assert is_safe_url("http://[fd12::1]/internal") is False

    def test_non_cgnat_100_allowed(self):
        """100.0.0.1 is NOT in CGNAT range (100.64.0.0/10), should be allowed."""
        with self._dns("100.0.0.1"):
            # 100.0.0.1 is a global IP, not in CGNAT range
            assert is_safe_url("http://legit-host.example/") is True
|
||||
|
||||
|
||||
class TestIsBlockedIp:
    """Direct tests for the _is_blocked_ip helper."""

    # Addresses that must be rejected: loopback, RFC1918, link-local,
    # unspecified, multicast, broadcast, CGNAT, ULA, and IPv4-mapped forms.
    _BLOCKED = (
        "127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1",
        "169.254.169.254", "0.0.0.0", "224.0.0.1", "255.255.255.255",
        "100.64.0.1", "100.100.100.100", "100.127.255.254",
        "::1", "fe80::1", "fc00::1", "fd12::1", "ff02::1",
        "::ffff:127.0.0.1", "::ffff:169.254.169.254",
    )

    # Ordinary global addresses that must pass.
    _ALLOWED = (
        "8.8.8.8", "93.184.216.34", "1.1.1.1", "100.0.0.1",
        "2606:4700::1", "2001:4860:4860::8888",
    )

    @pytest.mark.parametrize("ip_str", _BLOCKED)
    def test_blocked_ips(self, ip_str):
        assert _is_blocked_ip(ipaddress.ip_address(ip_str)) is True, f"{ip_str} should be blocked"

    @pytest.mark.parametrize("ip_str", _ALLOWED)
    def test_allowed_ips(self, ip_str):
        assert _is_blocked_ip(ipaddress.ip_address(ip_str)) is False, f"{ip_str} should be allowed"
|
||||
@@ -33,17 +33,30 @@ class TestValidateImageUrl:
|
||||
assert _validate_image_url("https://example.com/image.jpg") is True
|
||||
|
||||
def test_valid_http_url(self):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("http://cdn.example.org/photo.png") is True
|
||||
|
||||
def test_valid_url_without_extension(self):
|
||||
"""CDN endpoints that redirect to images should still pass."""
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
|
||||
|
||||
def test_valid_url_with_query_params(self):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
|
||||
|
||||
def test_localhost_url_blocked_by_ssrf(self):
|
||||
"""localhost URLs are now blocked by SSRF protection."""
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is False
|
||||
|
||||
def test_valid_url_with_port(self):
|
||||
assert _validate_image_url("http://localhost:8080/image.png") is True
|
||||
assert _validate_image_url("http://example.com:8080/image.png") is True
|
||||
|
||||
def test_valid_url_with_path_only(self):
|
||||
assert _validate_image_url("https://example.com/") is True
|
||||
|
||||
@@ -343,6 +343,8 @@ def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path)
|
||||
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
@@ -389,6 +391,9 @@ def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypat
|
||||
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
@@ -428,6 +433,8 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
@@ -457,6 +464,8 @@ async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
|
||||
44
tools/ansi_strip.py
Normal file
44
tools/ansi_strip.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Strip ANSI escape sequences from subprocess output.
|
||||
|
||||
Used by terminal_tool, code_execution_tool, and process_registry to clean
|
||||
command output before returning it to the model. This prevents ANSI codes
|
||||
from entering the model's context — which is the root cause of models
|
||||
copying escape sequences into file writes.
|
||||
|
||||
Covers the full ECMA-48 spec: CSI (including private-mode ``?`` prefix,
|
||||
colon-separated params, intermediate bytes), OSC (BEL and ST terminators),
|
||||
DCS/SOS/PM/APC string sequences, nF multi-byte escapes, Fp/Fe/Fs
|
||||
single-byte escapes, and 8-bit C1 control characters.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
_ANSI_ESCAPE_RE = re.compile(
|
||||
r"\x1b"
|
||||
r"(?:"
|
||||
r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # CSI sequence
|
||||
r"|\][\s\S]*?(?:\x07|\x1b\\)" # OSC (BEL or ST terminator)
|
||||
r"|[PX^_][\s\S]*?(?:\x1b\\)" # DCS/SOS/PM/APC strings
|
||||
r"|[\x20-\x2f]+[\x30-\x7e]" # nF escape sequences
|
||||
r"|[\x30-\x7e]" # Fp/Fe/Fs single-byte
|
||||
r")"
|
||||
r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # 8-bit CSI
|
||||
r"|\x9d[\s\S]*?(?:\x07|\x9c)" # 8-bit OSC
|
||||
r"|[\x80-\x9f]", # Other 8-bit C1 controls
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Fast-path check — skip full regex when no escape-like bytes are present.
|
||||
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")
|
||||
|
||||
|
||||
def strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences from text.
|
||||
|
||||
Returns the input unchanged (fast path) when no ESC or C1 bytes are
|
||||
present. Safe to call on any string — clean text passes through
|
||||
with negligible overhead.
|
||||
"""
|
||||
if not text or not _HAS_ESCAPE.search(text):
|
||||
return text
|
||||
return _ANSI_ESCAPE_RE.sub("", text)
|
||||
@@ -577,6 +577,12 @@ def execute_code(
|
||||
server_sock = None # prevent double close in finally
|
||||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Strip ANSI escape sequences so the model never sees terminal
|
||||
# formatting — prevents it from copying escapes into file writes.
|
||||
from tools.ansi_strip import strip_ansi
|
||||
stdout_text = strip_ansi(stdout_text)
|
||||
stderr_text = strip_ansi(stderr_text)
|
||||
|
||||
# Build response
|
||||
result: Dict[str, Any] = {
|
||||
"status": status,
|
||||
|
||||
@@ -5,7 +5,6 @@ import errno
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
@@ -13,17 +12,6 @@ from agent.redact import redact_sensitive_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regex to match ANSI escape sequences (CSI codes, OSC codes, simple escapes).
|
||||
# Models occasionally copy these from terminal output into file content.
|
||||
_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]|\x1b\][^\x07]*\x07|\x1b[()][A-B012]|\x1b[=>]")
|
||||
|
||||
|
||||
def _strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences from text destined for file writes."""
|
||||
if not text or "\x1b" not in text:
|
||||
return text
|
||||
return _ANSI_ESCAPE_RE.sub("", text)
|
||||
|
||||
|
||||
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
|
||||
|
||||
@@ -301,7 +289,6 @@ def notify_other_tool_call(task_id: str = "default"):
|
||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
"""Write content to a file."""
|
||||
try:
|
||||
content = _strip_ansi(content)
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
@@ -325,13 +312,10 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
return json.dumps({"error": "path required"})
|
||||
if old_string is None or new_string is None:
|
||||
return json.dumps({"error": "old_string and new_string required"})
|
||||
old_string = _strip_ansi(old_string)
|
||||
new_string = _strip_ansi(new_string)
|
||||
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
|
||||
elif mode == "patch":
|
||||
if not patch:
|
||||
return json.dumps({"error": "patch content required"})
|
||||
patch = _strip_ansi(patch)
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
@@ -426,12 +426,14 @@ class ProcessRegistry:
|
||||
|
||||
def poll(self, session_id: str) -> dict:
|
||||
"""Check status and get new output for a background process."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
output_preview = session.output_buffer[-1000:] if session.output_buffer else ""
|
||||
output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else ""
|
||||
|
||||
result = {
|
||||
"session_id": session.id,
|
||||
@@ -450,12 +452,14 @@ class ProcessRegistry:
|
||||
|
||||
def read_log(self, session_id: str, offset: int = 0, limit: int = 200) -> dict:
|
||||
"""Read the full output log with optional pagination by lines."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
full_output = session.output_buffer
|
||||
full_output = strip_ansi(session.output_buffer)
|
||||
|
||||
lines = full_output.splitlines()
|
||||
total_lines = len(lines)
|
||||
@@ -486,6 +490,7 @@ class ProcessRegistry:
|
||||
dict with status ("exited", "timeout", "interrupted", "not_found")
|
||||
and output snapshot.
|
||||
"""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
@@ -513,7 +518,7 @@ class ProcessRegistry:
|
||||
result = {
|
||||
"status": "exited",
|
||||
"exit_code": session.exit_code,
|
||||
"output": session.output_buffer[-2000:],
|
||||
"output": strip_ansi(session.output_buffer[-2000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
@@ -522,7 +527,7 @@ class ProcessRegistry:
|
||||
if _interrupt_event.is_set():
|
||||
result = {
|
||||
"status": "interrupted",
|
||||
"output": session.output_buffer[-1000:],
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
"note": "User sent a new message -- wait interrupted",
|
||||
}
|
||||
if timeout_note:
|
||||
@@ -533,7 +538,7 @@ class ProcessRegistry:
|
||||
|
||||
result = {
|
||||
"status": "timeout",
|
||||
"output": session.output_buffer[-1000:],
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
|
||||
@@ -1163,6 +1163,11 @@ def terminal_tool(
|
||||
)
|
||||
output = output[:head_chars] + truncated_notice + output[-tail_chars:]
|
||||
|
||||
# Strip ANSI escape sequences so the model never sees terminal
|
||||
# formatting — prevents it from copying escapes into file writes.
|
||||
from tools.ansi_strip import strip_ansi
|
||||
output = strip_ansi(output)
|
||||
|
||||
# Redact secrets from command output (catches env/printenv leaking keys)
|
||||
from agent.redact import redact_sensitive_text
|
||||
output = redact_sensitive_text(output.strip()) if output else ""
|
||||
|
||||
96
tools/url_safety.py
Normal file
96
tools/url_safety.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""URL safety checks — blocks requests to private/internal network addresses.
|
||||
|
||||
Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or
|
||||
skill could trick the agent into fetching internal resources like cloud
|
||||
metadata endpoints (169.254.169.254), localhost services, or private
|
||||
network hosts.
|
||||
|
||||
Limitations (documented, not fixable at pre-flight level):
|
||||
- DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0
|
||||
can return a public IP for the check, then a private IP for the actual
|
||||
connection. Fixing this requires connection-level validation (e.g.
Python's Advocate library or an egress proxy like Stripe's Smokescreen).
|
||||
- Redirect-based bypass in vision_tools is mitigated by an httpx event
|
||||
hook that re-validates each redirect target. Web tools use third-party
|
||||
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hostnames that should always be blocked regardless of IP resolution
|
||||
_BLOCKED_HOSTNAMES = frozenset({
|
||||
"metadata.google.internal",
|
||||
"metadata.goog",
|
||||
})
|
||||
|
||||
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
|
||||
# ipaddress.is_private — it returns False for both is_private and is_global.
|
||||
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
|
||||
# VPNs, and some cloud internal networks.
|
||||
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
|
||||
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
"""Return True if the IP should be blocked for SSRF protection."""
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
return True
|
||||
if ip.is_multicast or ip.is_unspecified:
|
||||
return True
|
||||
# CGNAT range not covered by is_private
|
||||
if ip in _CGNAT_NETWORK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Return True if the URL target is not a private/internal address.
|
||||
|
||||
Resolves the hostname to an IP and checks against private ranges.
|
||||
Fails closed: DNS errors and unexpected exceptions block the request.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = (parsed.hostname or "").strip().lower()
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
# Block known internal hostnames
|
||||
if hostname in _BLOCKED_HOSTNAMES:
|
||||
logger.warning("Blocked request to internal hostname: %s", hostname)
|
||||
return False
|
||||
|
||||
# Try to resolve and check IP
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
# DNS resolution failed — fail closed. If DNS can't resolve it,
|
||||
# the HTTP client will also fail, so blocking loses nothing.
|
||||
logger.warning("Blocked request — DNS resolution failed for: %s", hostname)
|
||||
return False
|
||||
|
||||
for family, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if _is_blocked_ip(ip):
|
||||
logger.warning(
|
||||
"Blocked request to private/internal address: %s -> %s",
|
||||
hostname, ip_str,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
# Fail closed on unexpected errors — don't let parsing edge cases
|
||||
# become SSRF bypass vectors
|
||||
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
|
||||
return False
|
||||
@@ -69,7 +69,12 @@ def _validate_image_url(url: str) -> bool:
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
return True # Allow all well-formed HTTP/HTTPS URLs for flexibility
|
||||
# Block private/internal addresses to prevent SSRF
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
@@ -92,12 +97,33 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
||||
# Create parent directories if they don't exist
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-validate each redirect target to prevent redirect-based SSRF.
|
||||
|
||||
Without this, an attacker can host a public URL that 302-redirects
|
||||
to http://169.254.169.254/ and bypass the pre-flight is_safe_url check.
|
||||
|
||||
Must be async because httpx.AsyncClient awaits event hooks.
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {redirect_url}"
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Download the image with appropriate headers using async httpx
|
||||
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
# SSRF: event_hooks validates each redirect target against private IP ranges
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
response = await client.get(
|
||||
image_url,
|
||||
headers={
|
||||
|
||||
@@ -46,6 +46,7 @@ import httpx
|
||||
from firecrawl import Firecrawl
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.url_safety import is_safe_url
|
||||
from tools.website_policy import check_website_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -861,136 +862,155 @@ async def web_extract_tool(
|
||||
try:
|
||||
logger.info("Extracting content from %d URL(s)", len(urls))
|
||||
|
||||
# Dispatch to the configured backend
|
||||
backend = _get_backend()
|
||||
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(urls)
|
||||
elif backend == "tavily":
|
||||
logger.info("Tavily extract: %d URL(s)", len(urls))
|
||||
raw = _tavily_request("extract", {
|
||||
"urls": urls,
|
||||
"include_images": False,
|
||||
})
|
||||
results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "")
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
# ── SSRF protection — filter out private/internal URLs before any backend ──
|
||||
safe_urls = []
|
||||
ssrf_blocked: List[Dict[str, Any]] = []
|
||||
for url in urls:
|
||||
if not is_safe_url(url):
|
||||
ssrf_blocked.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address",
|
||||
})
|
||||
else:
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
safe_urls.append(url)
|
||||
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
# Dispatch only safe URLs to the configured backend
|
||||
if not safe_urls:
|
||||
results = []
|
||||
else:
|
||||
backend = _get_backend()
|
||||
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
if backend == "parallel":
|
||||
results = await _parallel_extract(safe_urls)
|
||||
elif backend == "tavily":
|
||||
logger.info("Tavily extract: %d URL(s)", len(safe_urls))
|
||||
raw = _tavily_request("extract", {
|
||||
"urls": safe_urls,
|
||||
"include_images": False,
|
||||
})
|
||||
results = _normalize_tavily_documents(raw, fallback_url=safe_urls[0] if safe_urls else "")
|
||||
else:
|
||||
# ── Firecrawl extraction ──
|
||||
# Determine requested formats for Firecrawl v2
|
||||
formats: List[str] = []
|
||||
if format == "markdown":
|
||||
formats = ["markdown"]
|
||||
elif format == "html":
|
||||
formats = ["html"]
|
||||
else:
|
||||
# Default: request markdown for LLM-readiness and include html as backup
|
||||
formats = ["markdown", "html"]
|
||||
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
results.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
})
|
||||
continue
|
||||
# Always use individual scraping for simplicity and reliability
|
||||
# Batch scraping adds complexity without much benefit for small numbers of URLs
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
for url in safe_urls:
|
||||
if _is_interrupted():
|
||||
results.append({"url": url, "error": "Interrupted", "title": ""})
|
||||
continue
|
||||
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
# Website policy check — block before fetching
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
|
||||
results.append({
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": blocked["message"],
|
||||
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
try:
|
||||
logger.info("Scraping: %s", url)
|
||||
scrape_result = _get_firecrawl_client().scrape(
|
||||
url=url,
|
||||
formats=formats
|
||||
)
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
# Process the result - properly handle object serialization
|
||||
metadata = {}
|
||||
title = ""
|
||||
content_markdown = None
|
||||
content_html = None
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
# Extract data from the scrape result
|
||||
if hasattr(scrape_result, 'model_dump'):
|
||||
# Pydantic model - use model_dump to get dict
|
||||
result_dict = scrape_result.model_dump()
|
||||
content_markdown = result_dict.get('markdown')
|
||||
content_html = result_dict.get('html')
|
||||
metadata = result_dict.get('metadata', {})
|
||||
elif hasattr(scrape_result, '__dict__'):
|
||||
# Regular object with attributes
|
||||
content_markdown = getattr(scrape_result, 'markdown', None)
|
||||
content_html = getattr(scrape_result, 'html', None)
|
||||
|
||||
# Handle metadata - convert to dict if it's an object
|
||||
metadata_obj = getattr(scrape_result, 'metadata', {})
|
||||
if hasattr(metadata_obj, 'model_dump'):
|
||||
metadata = metadata_obj.model_dump()
|
||||
elif hasattr(metadata_obj, '__dict__'):
|
||||
metadata = metadata_obj.__dict__
|
||||
elif isinstance(metadata_obj, dict):
|
||||
metadata = metadata_obj
|
||||
else:
|
||||
metadata = {}
|
||||
elif isinstance(scrape_result, dict):
|
||||
# Already a dictionary
|
||||
content_markdown = scrape_result.get('markdown')
|
||||
content_html = scrape_result.get('html')
|
||||
metadata = scrape_result.get('metadata', {})
|
||||
|
||||
# Ensure metadata is a dict (not an object)
|
||||
if not isinstance(metadata, dict):
|
||||
if hasattr(metadata, 'model_dump'):
|
||||
metadata = metadata.model_dump()
|
||||
elif hasattr(metadata, '__dict__'):
|
||||
metadata = metadata.__dict__
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Get title from metadata
|
||||
title = metadata.get("title", "")
|
||||
|
||||
# Re-check final URL after redirect
|
||||
final_url = metadata.get("sourceURL", url)
|
||||
final_blocked = check_website_access(final_url)
|
||||
if final_blocked:
|
||||
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
|
||||
results.append({
|
||||
"url": final_url, "title": title, "content": "", "raw_content": "",
|
||||
"error": final_blocked["message"],
|
||||
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
|
||||
})
|
||||
continue
|
||||
|
||||
# Choose content based on requested format
|
||||
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
|
||||
|
||||
results.append({
|
||||
"url": final_url,
|
||||
"title": title,
|
||||
"content": chosen_content,
|
||||
"raw_content": chosen_content,
|
||||
"metadata": metadata # Now guaranteed to be a dict
|
||||
})
|
||||
|
||||
except Exception as scrape_err:
|
||||
logger.debug("Scrape failed for %s: %s", url, scrape_err)
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": str(scrape_err)
|
||||
})
|
||||
|
||||
# Merge any SSRF-blocked results back in
|
||||
if ssrf_blocked:
|
||||
results = ssrf_blocked + results
|
||||
|
||||
response = {"results": results}
|
||||
|
||||
@@ -1173,6 +1193,11 @@ async def web_crawl_tool(
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = f'https://{url}'
|
||||
|
||||
# SSRF protection — block private/internal addresses
|
||||
if not is_safe_url(url):
|
||||
return json.dumps({"results": [{"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
|
||||
|
||||
# Website policy check
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
@@ -1258,6 +1283,11 @@ async def web_crawl_tool(
|
||||
instructions_text = f" with instructions: '{instructions}'" if instructions else ""
|
||||
logger.info("Crawling %s%s", url, instructions_text)
|
||||
|
||||
# SSRF protection — block private/internal addresses
|
||||
if not is_safe_url(url):
|
||||
return json.dumps({"results": [{"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
|
||||
|
||||
# Website policy check — block before crawling
|
||||
blocked = check_website_access(url)
|
||||
if blocked:
|
||||
|
||||
Reference in New Issue
Block a user