Compare commits


7 Commits

Author SHA1 Message Date
Teknium
bc5a67dbf2 fix(gateway): prevent stale memory overwrites by flush agent (#2670)
The gateway memory flush agent reviews old conversation history on session
reset/expiry and writes to memory. It had no awareness of memory changes
made after that conversation ended (by the live agent, cron jobs, or other
sessions), causing silent overwrites of newer entries.

Two fixes:

1. Skip memory flush entirely for cron sessions (session IDs starting with
   'cron_'). Cron sessions are headless with no meaningful user conversation
   to extract memories from.

2. Inject the current live memory state (MEMORY.md + USER.md) directly into
   the flush prompt. The flush agent can now see what's already saved and
   make informed decisions — only adding genuinely new information rather
   than blindly overwriting entries that may have been updated since the
   conversation ended.

Addresses the root cause identified in #2670: the flush agent was making
memory decisions blind to the current state of memory, causing stale
context to overwrite newer entries on gateway restarts and session resets.

Co-authored-by: devorun <devorun@users.noreply.github.com>
Co-authored-by: dlkakbs <dlkakbs@users.noreply.github.com>
2026-03-23 16:05:35 -07:00
Teknium
f9c2565ab4 fix(config): log warning instead of silently swallowing config.yaml errors (#2683)
A bare `except Exception: pass` meant any YAML syntax error, bad value,
or unexpected structure in config.yaml was silently ignored and the
gateway fell back to .env / gateway.json without any indication.
Users had no way to know why their config changes had no effect.

Co-authored-by: sprmn24 <oncuevtv@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 15:54:11 -07:00
Teknium
ad5f973a8d fix(vision): make SSRF redirect guard async for httpx.AsyncClient
httpx.AsyncClient awaits event hooks. The sync _ssrf_redirect_guard
returned None, causing 'object NoneType can't be used in await
expression' on any vision_analyze call that followed redirects.

Caught during live PTY testing of the merged SSRF protection.
2026-03-23 15:44:52 -07:00
Teknium
0791efe2c3 fix(security): add SSRF protection to vision_tools and web_tools (hardened)
* fix(security): add SSRF protection to vision_tools and web_tools

Both vision_analyze and web_extract/web_crawl accept arbitrary URLs
without checking if they target private/internal network addresses.
A prompt-injected or malicious skill could use this to access cloud
metadata endpoints (169.254.169.254), localhost services, or private
network hosts.

Adds a shared url_safety.is_safe_url() that resolves hostnames and
blocks private, loopback, link-local, and reserved IP ranges. Also
blocks known internal hostnames (metadata.google.internal).

Integrated at the URL validation layer in vision_tools and before
each website_policy check in web_tools (extract, crawl).

* test(vision): update localhost test to reflect SSRF protection

The existing test_valid_url_with_port asserted localhost URLs pass
validation. With SSRF protection, localhost is now correctly blocked.
Update the test to verify the block, and add a separate test for
valid URLs with ports using a public hostname.

* fix(security): harden SSRF protection — fail-closed, CGNAT, multicast, redirect guard

Follow-up hardening on top of dieutx's SSRF protection (PR #2630):

- Change fail-open to fail-closed: DNS errors and unexpected exceptions
  now block the request instead of allowing it (OWASP best practice)
- Block CGNAT range (100.64.0.0/10): Python's ipaddress.is_private
  does NOT cover this range (returns False for both is_private and
  is_global). Used by Tailscale/WireGuard and carrier infrastructure.
- Add is_multicast and is_unspecified checks: multicast (224.0.0.0/4)
  and unspecified (0.0.0.0) addresses were not caught by the original
  four-check chain
- Add redirect guard for vision_tools: httpx event hook re-validates
  each redirect target against SSRF checks, preventing the classic
  redirect-based SSRF bypass (302 to internal IP)
- Move SSRF filtering before backend dispatch in web_extract: now
  covers Parallel and Tavily backends, not just Firecrawl
- Extract _is_blocked_ip() helper for cleaner IP range checking
- Add 24 new tests (CGNAT, multicast, IPv4-mapped IPv6, fail-closed
  behavior, parametrized blocked/allowed IP lists)
- Fix existing tests to mock DNS resolution for test hostnames

---------

Co-authored-by: dieutx <dangtc94@gmail.com>
2026-03-23 15:40:42 -07:00
Teknium
934fbe3c06 fix: strip ANSI at the source — clean terminal output before it reaches the model
Root cause: terminal_tool, execute_code, and process_registry returned raw
subprocess output with ANSI escape sequences intact. The model saw these
in tool results and copied them into file writes.

Previous fix (PR #2532) stripped ANSI at the write point in file_tools.py,
but this was a band-aid — regex on file content risks corrupting legitimate
content, and doesn't prevent ANSI from wasting tokens in the model context.

Source-level fix:
- New tools/ansi_strip.py with comprehensive ECMA-48 regex covering CSI
  (incl. private-mode, colon-separated, intermediate bytes), OSC (both
  terminators), DCS/SOS/PM/APC strings, Fp/Fe/Fs/nF escapes, 8-bit C1
- terminal_tool.py: strip output before returning to model
- code_execution_tool.py: strip stdout/stderr before returning
- process_registry.py: strip output in poll/read_log/wait
- file_tools.py: remove _strip_ansi band-aid (no longer needed)

Verified: `ls --color=always` output returned as clean text to model,
file written from that output contains zero ESC bytes.
2026-03-23 07:43:12 -07:00
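That verification can be reproduced offline. A self-contained check, with the regex copied from the new `tools/ansi_strip.py` (the full file appears in the diff below) and sample output standing in for a real `ls --color=always` run:

```python
import re

# Regex copied from tools/ansi_strip.py as added in this commit.
_ANSI_ESCAPE_RE = re.compile(
    r"\x1b"
    r"(?:"
    r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]"      # CSI sequence
    r"|\][\s\S]*?(?:\x07|\x1b\\)"                  # OSC (BEL or ST)
    r"|[PX^_][\s\S]*?(?:\x1b\\)"                   # DCS/SOS/PM/APC
    r"|[\x20-\x2f]+[\x30-\x7e]"                    # nF escapes
    r"|[\x30-\x7e]"                                # Fp/Fe/Fs single-byte
    r")"
    r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]"    # 8-bit CSI
    r"|\x9d[\s\S]*?(?:\x07|\x9c)"                  # 8-bit OSC
    r"|[\x80-\x9f]",                               # other 8-bit C1
    re.DOTALL,
)
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")

def strip_ansi(text: str) -> str:
    if not text or not _HAS_ESCAPE.search(text):
        return text
    return _ANSI_ESCAPE_RE.sub("", text)

# Sample colored output (SGR reset + bold green), as ls would emit it:
colored = "\x1b[0m\x1b[01;32mscript.sh\x1b[0m\n"
assert strip_ansi(colored) == "script.sh\n"
assert "\x1b" not in strip_ansi(colored)
```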
Teknium
6302e56e7c fix(gateway): add all missing platform allowlist env vars to startup warning check (#2628)
* fix(gateway): added MATRIX_ALLOWED_USERS to list of env vars checked by gateway

* fix(gateway): add all missing platform allowlist env vars to startup check

The startup warning for 'No user allowlists configured' was only checking
TELEGRAM, DISCORD, WHATSAPP, SLACK, and SMS — missing SIGNAL, EMAIL,
MATTERMOST, and DINGTALK. Users of those platforms would see a spurious
warning even with their platform-specific allowlist configured.

Now matches the canonical platform_env_map in _is_user_authorized().

---------

Co-authored-by: SteelPh0enix <wojciech_olech@hotmail.com>
2026-03-23 07:19:14 -07:00
Teknium
868b3c07e3 fix: platform default toolsets silently override tool deselection in hermes tools (#2624)
Cherry-picked from PR #2576 by ereid7, plus read-side fix from 173a5c62.

Both fixes were originally landed in 173a5c62 but were inadvertently
reverted by commit 34be3f8b (a squash-merge that bundled unrelated
tools_config.py changes).

Save side (_save_platform_tools): exclude platform default toolset
names (hermes-cli, hermes-telegram) from preserved entries so they
don't silently re-enable everything.

Read side (_get_platform_tools): when the saved list contains explicit
configurable keys, use direct membership instead of subset inference.
The subset approach is broken when composite toolsets like hermes-cli
resolve to ALL tools.
2026-03-23 07:06:51 -07:00
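One plausible reading of the subset-inference failure can be sketched in a few lines. Toolset and tool names here are invented for illustration; the real resolution logic lives in `_get_platform_tools`:

```python
# Invented data — not the actual tools_config.py contents.
ALL_TOOLS = {"terminal", "web_extract", "vision_analyze", "memory"}
TOOLSETS = {"hermes-cli": ALL_TOOLS}

saved = {"terminal", "memory"}  # user explicitly deselected the rest

# Subset inference: treat a composite toolset as selected when the saved
# tools are a subset of what it resolves to. Because hermes-cli resolves
# to ALL tools, every possible selection passes this check...
inferred = {name for name, tools in TOOLSETS.items() if saved <= tools}
assert inferred == {"hermes-cli"}  # ...which re-enables everything

# Direct membership: only honor what the saved list explicitly names.
direct = {t for t in ALL_TOOLS if t in saved}
assert direct == {"terminal", "memory"}
```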
16 changed files with 924 additions and 148 deletions

View File

@@ -523,8 +523,13 @@ def load_gateway_config() -> GatewayConfig:
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
except Exception:
pass
except Exception as e:
logger.warning(
"Failed to process config.yaml — falling back to .env / gateway.json values. "
"Check %s for syntax errors. Error: %s",
_home / "config.yaml",
e,
)
config = GatewayConfig.from_dict(gw_data)

View File

@@ -525,6 +525,12 @@ class GatewayRunner:
Synchronous worker — meant to be called via run_in_executor from
an async context so it doesn't block the event loop.
"""
# Skip cron sessions — they run headless with no meaningful user
# conversation to extract memories from.
if old_session_id and old_session_id.startswith("cron_"):
logger.debug("Skipping memory flush for cron session: %s", old_session_id)
return
try:
history = self.session_store.load_transcript(old_session_id)
if not history or len(history) < 4:
@@ -557,6 +563,23 @@ class GatewayRunner:
if m.get("role") in ("user", "assistant") and m.get("content")
]
# Read live memory state from disk so the flush agent can see
# what's already saved and avoid overwriting newer entries.
_current_memory = ""
try:
from tools.memory_tool import MEMORY_DIR
for fname, label in [
("MEMORY.md", "MEMORY (your personal notes)"),
("USER.md", "USER PROFILE (who the user is)"),
]:
fpath = MEMORY_DIR / fname
if fpath.exists():
content = fpath.read_text(encoding="utf-8").strip()
if content:
_current_memory += f"\n\n## Current {label}:\n{content}"
except Exception:
pass # Non-fatal — flush still works, just without the guard
# Give the agent a real turn to think about what to save
flush_prompt = (
"[System: This session is about to be automatically reset due to "
@@ -568,6 +591,20 @@ class GatewayRunner:
"2. If you discovered a reusable workflow or solved a non-trivial "
"problem, consider saving it as a skill.\n"
"3. If nothing is worth saving, that's fine — just skip.\n\n"
)
if _current_memory:
flush_prompt += (
"IMPORTANT — here is the current live state of memory. Other "
"sessions, cron jobs, or the user may have updated it since this "
"conversation ended. Do NOT overwrite or remove entries unless "
"the conversation above reveals something that genuinely "
"supersedes them. Only add new information that is not already "
"captured below."
f"{_current_memory}\n\n"
)
flush_prompt += (
"Do NOT respond to the user. Just use the memory and skill_manage "
"tools if needed, then stop.]"
)
@@ -904,7 +941,9 @@ class GatewayRunner:
os.getenv(v)
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
"SMS_ALLOWED_USERS",
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
"GATEWAY_ALLOWED_USERS")
)
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")

View File

@@ -0,0 +1,167 @@
"""Tests for memory flush stale-overwrite prevention (#2670).
Verifies that:
1. Cron sessions are skipped (no flush for headless cron runs)
2. Current memory state is injected into the flush prompt so the
flush agent can see what's already saved and avoid overwrites
3. The flush still works normally when memory files don't exist
"""
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch, call
def _make_runner():
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner._honcho_managers = {}
runner._honcho_configs = {}
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner.adapters = {}
runner.hooks = MagicMock()
runner.session_store = MagicMock()
return runner
_TRANSCRIPT_4_MSGS = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi there"},
{"role": "user", "content": "remember my name is Alice"},
{"role": "assistant", "content": "Got it, Alice!"},
]
class TestCronSessionBypass:
"""Cron sessions should never trigger a memory flush."""
def test_cron_session_skipped(self):
runner = _make_runner()
runner._flush_memories_for_session("cron_job123_20260323_120000")
# session_store.load_transcript should never be called
runner.session_store.load_transcript.assert_not_called()
def test_cron_session_with_honcho_key_skipped(self):
runner = _make_runner()
runner._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
runner.session_store.load_transcript.assert_not_called()
def test_non_cron_session_proceeds(self):
"""Non-cron sessions should still attempt the flush."""
runner = _make_runner()
runner.session_store.load_transcript.return_value = []
runner._flush_memories_for_session("session_abc123")
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
class TestMemoryInjection:
"""The flush prompt should include current memory state from disk."""
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
"""When memory files exist, their content appears in the flush prompt."""
runner = _make_runner()
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
tmp_agent = MagicMock()
memory_dir = tmp_path / "memories"
memory_dir.mkdir()
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
with (
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
patch("run_agent.AIAgent", return_value=tmp_agent),
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
):
runner._flush_memories_for_session("session_123")
tmp_agent.run_conversation.assert_called_once()
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
flush_prompt = call_kwargs.get("user_message", "")
# Verify both memory sections appear in the prompt
assert "Agent knows Python" in flush_prompt
assert "User prefers dark mode" in flush_prompt
assert "Name: Alice" in flush_prompt
assert "Timezone: PST" in flush_prompt
# Verify the stale-overwrite warning is present
assert "Do NOT overwrite or remove entries" in flush_prompt
assert "current live state of memory" in flush_prompt
def test_flush_works_without_memory_files(self, tmp_path):
"""When no memory files exist, flush still runs without the guard."""
runner = _make_runner()
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
tmp_agent = MagicMock()
empty_dir = tmp_path / "no_memories"
empty_dir.mkdir()
with (
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
patch("run_agent.AIAgent", return_value=tmp_agent),
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
):
runner._flush_memories_for_session("session_456")
# Should still run, just without the memory guard section
tmp_agent.run_conversation.assert_called_once()
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
assert "Do NOT overwrite or remove entries" not in flush_prompt
assert "Review the conversation above" in flush_prompt
def test_empty_memory_files_no_injection(self, tmp_path):
"""Empty memory files should not trigger the guard section."""
runner = _make_runner()
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
tmp_agent = MagicMock()
memory_dir = tmp_path / "memories"
memory_dir.mkdir()
(memory_dir / "MEMORY.md").write_text("")
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
with (
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
patch("run_agent.AIAgent", return_value=tmp_agent),
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
):
runner._flush_memories_for_session("session_789")
tmp_agent.run_conversation.assert_called_once()
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
# No memory content → no guard section
assert "current live state of memory" not in flush_prompt
class TestFlushPromptStructure:
"""Verify the flush prompt retains its core instructions."""
def test_core_instructions_present(self):
"""The flush prompt should still contain the original guidance."""
runner = _make_runner()
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
tmp_agent = MagicMock()
with (
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
patch("run_agent.AIAgent", return_value=tmp_agent),
# Make the import fail gracefully so we test without memory files
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
):
runner._flush_memories_for_session("session_struct")
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
assert "automatically reset" in flush_prompt
assert "Save any important facts" in flush_prompt
assert "consider saving it as a skill" in flush_prompt
assert "Do NOT respond to the user" in flush_prompt

View File

@@ -0,0 +1,168 @@
"""Comprehensive tests for ANSI escape sequence stripping (ECMA-48).
The strip_ansi function in tools/ansi_strip.py is the source-level fix for
ANSI codes leaking into the model's context via terminal/execute_code output.
It must strip ALL terminal escape sequences while preserving legitimate text.
"""
from tools.ansi_strip import strip_ansi
class TestStripAnsiBasicSGR:
"""Select Graphic Rendition — the most common ANSI sequences."""
def test_reset(self):
assert strip_ansi("\x1b[0m") == ""
def test_color(self):
assert strip_ansi("\x1b[31;1m") == ""
def test_truecolor_semicolon(self):
assert strip_ansi("\x1b[38;2;255;0;0m") == ""
def test_truecolor_colon_separated(self):
"""Modern terminals use colon-separated SGR params."""
assert strip_ansi("\x1b[38:2:255:0:0m") == ""
assert strip_ansi("\x1b[48:2:0:255:0m") == ""
class TestStripAnsiCSIPrivateMode:
"""CSI sequences with ? prefix (DEC private modes)."""
def test_cursor_show_hide(self):
assert strip_ansi("\x1b[?25h") == ""
assert strip_ansi("\x1b[?25l") == ""
def test_alt_screen(self):
assert strip_ansi("\x1b[?1049h") == ""
assert strip_ansi("\x1b[?1049l") == ""
def test_bracketed_paste(self):
assert strip_ansi("\x1b[?2004h") == ""
class TestStripAnsiCSIIntermediate:
"""CSI sequences with intermediate bytes (space, etc.)."""
def test_cursor_shape(self):
assert strip_ansi("\x1b[0 q") == ""
assert strip_ansi("\x1b[2 q") == ""
assert strip_ansi("\x1b[6 q") == ""
class TestStripAnsiOSC:
"""Operating System Command sequences."""
def test_bel_terminator(self):
assert strip_ansi("\x1b]0;title\x07") == ""
def test_st_terminator(self):
assert strip_ansi("\x1b]0;title\x1b\\") == ""
def test_hyperlink_preserves_text(self):
assert strip_ansi(
"\x1b]8;;https://example.com\x1b\\click\x1b]8;;\x1b\\"
) == "click"
class TestStripAnsiDECPrivate:
"""DEC private / Fp escape sequences."""
def test_save_restore_cursor(self):
assert strip_ansi("\x1b7") == ""
assert strip_ansi("\x1b8") == ""
def test_keypad_modes(self):
assert strip_ansi("\x1b=") == ""
assert strip_ansi("\x1b>") == ""
class TestStripAnsiFe:
"""Fe (C1 as 7-bit) escape sequences."""
def test_reverse_index(self):
assert strip_ansi("\x1bM") == ""
def test_reset_terminal(self):
assert strip_ansi("\x1bc") == ""
def test_index_and_newline(self):
assert strip_ansi("\x1bD") == ""
assert strip_ansi("\x1bE") == ""
class TestStripAnsiNF:
"""nF (character set selection) sequences."""
def test_charset_selection(self):
assert strip_ansi("\x1b(A") == ""
assert strip_ansi("\x1b(B") == ""
assert strip_ansi("\x1b(0") == ""
class TestStripAnsiDCS:
"""Device Control String sequences."""
def test_dcs(self):
assert strip_ansi("\x1bP+q\x1b\\") == ""
class TestStripAnsi8BitC1:
"""8-bit C1 control characters."""
def test_8bit_csi(self):
assert strip_ansi("\x9b31m") == ""
assert strip_ansi("\x9b38;2;255;0;0m") == ""
def test_8bit_standalone(self):
assert strip_ansi("\x9c") == ""
assert strip_ansi("\x9d") == ""
assert strip_ansi("\x90") == ""
class TestStripAnsiRealWorld:
"""Real-world contamination scenarios from bug reports."""
def test_colored_shebang(self):
"""The original reported bug: shebang corrupted by color codes."""
assert strip_ansi(
"\x1b[32m#!/usr/bin/env python3\x1b[0m\nprint('hello')"
) == "#!/usr/bin/env python3\nprint('hello')"
def test_stacked_sgr(self):
assert strip_ansi(
"\x1b[1m\x1b[31m\x1b[42mhello\x1b[0m"
) == "hello"
def test_ansi_mid_code(self):
assert strip_ansi(
"def foo(\x1b[33m):\x1b[0m\n return 42"
) == "def foo():\n return 42"
class TestStripAnsiPassthrough:
"""Clean content must pass through unmodified."""
def test_plain_text(self):
assert strip_ansi("normal text") == "normal text"
def test_empty(self):
assert strip_ansi("") == ""
def test_none(self):
assert strip_ansi(None) is None
def test_whitespace_preserved(self):
assert strip_ansi("line1\nline2\ttab") == "line1\nline2\ttab"
def test_unicode_safe(self):
assert strip_ansi("emoji 🎉 and ñ café") == "emoji 🎉 and ñ café"
def test_backslash_in_code(self):
code = "path = 'C:\\\\Users\\\\test'"
assert strip_ansi(code) == code
def test_square_brackets_in_code(self):
"""Array indexing must not be confused with CSI."""
code = "arr[0] = arr[31]"
assert strip_ansi(code) == code

View File

@@ -309,3 +309,6 @@ class TestSearchHints:
raw = search_tool(pattern="foo", offset=50, limit=50)
assert "[Hint:" in raw
assert "offset=100" in raw

View File

@@ -0,0 +1,176 @@
"""Tests for SSRF protection in url_safety module."""
import socket
from unittest.mock import patch
from tools.url_safety import is_safe_url, _is_blocked_ip
import ipaddress
import pytest
class TestIsSafeUrl:
def test_public_url_allowed(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert is_safe_url("https://example.com/image.png") is True
def test_localhost_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("127.0.0.1", 0)),
]):
assert is_safe_url("http://localhost:8080/secret") is False
def test_loopback_ip_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("127.0.0.1", 0)),
]):
assert is_safe_url("http://127.0.0.1/admin") is False
def test_private_10_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("10.0.0.1", 0)),
]):
assert is_safe_url("http://internal-service.local/api") is False
def test_private_172_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("172.16.0.1", 0)),
]):
assert is_safe_url("http://private.corp/data") is False
def test_private_192_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("192.168.1.1", 0)),
]):
assert is_safe_url("http://router.local") is False
def test_link_local_169_254_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("169.254.169.254", 0)),
]):
assert is_safe_url("http://169.254.169.254/latest/meta-data/") is False
def test_metadata_google_internal_blocked(self):
assert is_safe_url("http://metadata.google.internal/computeMetadata/v1/") is False
def test_ipv6_loopback_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(10, 1, 6, "", ("::1", 0, 0, 0)),
]):
assert is_safe_url("http://[::1]:8080/") is False
def test_dns_failure_blocked(self):
"""DNS failures now fail closed — block the request."""
with patch("socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed")):
assert is_safe_url("https://nonexistent.example.com") is False
def test_empty_url_blocked(self):
assert is_safe_url("") is False
def test_no_hostname_blocked(self):
assert is_safe_url("http://") is False
def test_public_ip_allowed(self):
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert is_safe_url("https://example.com") is True
# ── New tests for hardened SSRF protection ──
def test_cgnat_100_64_blocked(self):
"""100.64.0.0/10 (CGNAT/Shared Address Space) is NOT covered by
ipaddress.is_private — must be blocked explicitly."""
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("100.64.0.1", 0)),
]):
assert is_safe_url("http://some-cgnat-host.example/") is False
def test_cgnat_100_127_blocked(self):
"""Upper end of CGNAT range (100.127.255.255)."""
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("100.127.255.254", 0)),
]):
assert is_safe_url("http://tailscale-peer.example/") is False
def test_multicast_blocked(self):
"""Multicast addresses (224.0.0.0/4) not caught by is_private."""
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("224.0.0.251", 0)),
]):
assert is_safe_url("http://mdns-host.local/") is False
def test_multicast_ipv6_blocked(self):
with patch("socket.getaddrinfo", return_value=[
(10, 1, 6, "", ("ff02::1", 0, 0, 0)),
]):
assert is_safe_url("http://[ff02::1]/") is False
def test_ipv4_mapped_ipv6_loopback_blocked(self):
"""::ffff:127.0.0.1 — IPv4-mapped IPv6 loopback."""
with patch("socket.getaddrinfo", return_value=[
(10, 1, 6, "", ("::ffff:127.0.0.1", 0, 0, 0)),
]):
assert is_safe_url("http://[::ffff:127.0.0.1]/") is False
def test_ipv4_mapped_ipv6_metadata_blocked(self):
"""::ffff:169.254.169.254 — IPv4-mapped IPv6 cloud metadata."""
with patch("socket.getaddrinfo", return_value=[
(10, 1, 6, "", ("::ffff:169.254.169.254", 0, 0, 0)),
]):
assert is_safe_url("http://[::ffff:169.254.169.254]/") is False
def test_unspecified_address_blocked(self):
"""0.0.0.0 — unspecified address, can bind to all interfaces."""
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("0.0.0.0", 0)),
]):
assert is_safe_url("http://0.0.0.0/") is False
def test_unexpected_error_fails_closed(self):
"""Unexpected exceptions should block, not allow."""
with patch("tools.url_safety.urlparse", side_effect=ValueError("bad url")):
assert is_safe_url("http://evil.com/") is False
def test_metadata_goog_blocked(self):
assert is_safe_url("http://metadata.goog/computeMetadata/v1/") is False
def test_ipv6_unique_local_blocked(self):
"""fc00::/7 — IPv6 unique local addresses."""
with patch("socket.getaddrinfo", return_value=[
(10, 1, 6, "", ("fd12::1", 0, 0, 0)),
]):
assert is_safe_url("http://[fd12::1]/internal") is False
def test_non_cgnat_100_allowed(self):
"""100.0.0.1 is NOT in CGNAT range (100.64.0.0/10), should be allowed."""
with patch("socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("100.0.0.1", 0)),
]):
# 100.0.0.1 is a global IP, not in CGNAT range
assert is_safe_url("http://legit-host.example/") is True
class TestIsBlockedIp:
"""Direct tests for the _is_blocked_ip helper."""
@pytest.mark.parametrize("ip_str", [
"127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1",
"169.254.169.254", "0.0.0.0", "224.0.0.1", "255.255.255.255",
"100.64.0.1", "100.100.100.100", "100.127.255.254",
"::1", "fe80::1", "fc00::1", "fd12::1", "ff02::1",
"::ffff:127.0.0.1", "::ffff:169.254.169.254",
])
def test_blocked_ips(self, ip_str):
ip = ipaddress.ip_address(ip_str)
assert _is_blocked_ip(ip) is True, f"{ip_str} should be blocked"
@pytest.mark.parametrize("ip_str", [
"8.8.8.8", "93.184.216.34", "1.1.1.1", "100.0.0.1",
"2606:4700::1", "2001:4860:4860::8888",
])
def test_allowed_ips(self, ip_str):
ip = ipaddress.ip_address(ip_str)
assert _is_blocked_ip(ip) is False, f"{ip_str} should be allowed"

View File

@@ -33,17 +33,30 @@ class TestValidateImageUrl:
assert _validate_image_url("https://example.com/image.jpg") is True
def test_valid_http_url(self):
assert _validate_image_url("http://cdn.example.org/photo.png") is True
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("http://cdn.example.org/photo.png") is True
def test_valid_url_without_extension(self):
"""CDN endpoints that redirect to images should still pass."""
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("https://cdn.example.com/abcdef123") is True
def test_valid_url_with_query_params(self):
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
with patch("tools.url_safety.socket.getaddrinfo", return_value=[
(2, 1, 6, "", ("93.184.216.34", 0)),
]):
assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
def test_localhost_url_blocked_by_ssrf(self):
"""localhost URLs are now blocked by SSRF protection."""
assert _validate_image_url("http://localhost:8080/image.png") is False
def test_valid_url_with_port(self):
assert _validate_image_url("http://localhost:8080/image.png") is True
assert _validate_image_url("http://example.com:8080/image.png") is True
def test_valid_url_with_path_only(self):
assert _validate_image_url("https://example.com/") is True

View File

@@ -343,6 +343,8 @@ def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path)
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
from tools import web_tools
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
monkeypatch.setattr(
web_tools,
"check_website_access",
@@ -389,6 +391,9 @@ def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypat
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
from tools import web_tools
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
def fake_check(url):
if url == "https://allowed.test":
return None
@@ -428,6 +433,8 @@ async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
monkeypatch.setattr(
web_tools,
"check_website_access",
@@ -457,6 +464,8 @@ async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
def fake_check(url):
if url == "https://allowed.test":

tools/ansi_strip.py (new file, 44 lines)
View File

@@ -0,0 +1,44 @@
"""Strip ANSI escape sequences from subprocess output.
Used by terminal_tool, code_execution_tool, and process_registry to clean
command output before returning it to the model. This prevents ANSI codes
from entering the model's context — which is the root cause of models
copying escape sequences into file writes.
Covers the full ECMA-48 spec: CSI (including private-mode ``?`` prefix,
colon-separated params, intermediate bytes), OSC (BEL and ST terminators),
DCS/SOS/PM/APC string sequences, nF multi-byte escapes, Fp/Fe/Fs
single-byte escapes, and 8-bit C1 control characters.
"""
import re
_ANSI_ESCAPE_RE = re.compile(
r"\x1b"
r"(?:"
r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # CSI sequence
r"|\][\s\S]*?(?:\x07|\x1b\\)" # OSC (BEL or ST terminator)
r"|[PX^_][\s\S]*?(?:\x1b\\)" # DCS/SOS/PM/APC strings
r"|[\x20-\x2f]+[\x30-\x7e]" # nF escape sequences
r"|[\x30-\x7e]" # Fp/Fe/Fs single-byte
r")"
r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # 8-bit CSI
r"|\x9d[\s\S]*?(?:\x07|\x9c)" # 8-bit OSC
r"|[\x80-\x9f]", # Other 8-bit C1 controls
re.DOTALL,
)
# Fast-path check — skip full regex when no escape-like bytes are present.
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")
def strip_ansi(text: str) -> str:
"""Remove ANSI escape sequences from text.
Returns the input unchanged (fast path) when no ESC or C1 bytes are
present. Safe to call on any string — clean text passes through
with negligible overhead.
"""
if not text or not _HAS_ESCAPE.search(text):
return text
return _ANSI_ESCAPE_RE.sub("", text)

View File

@@ -577,6 +577,12 @@ def execute_code(
server_sock = None # prevent double close in finally
rpc_thread.join(timeout=3)
# Strip ANSI escape sequences so the model never sees terminal
# formatting — prevents it from copying escapes into file writes.
from tools.ansi_strip import strip_ansi
stdout_text = strip_ansi(stdout_text)
stderr_text = strip_ansi(stderr_text)
# Build response
result: Dict[str, Any] = {
"status": status,

View File

@@ -5,7 +5,6 @@ import errno
import json
import logging
import os
import re
import threading
from typing import Optional
from tools.file_operations import ShellFileOperations
@@ -13,17 +12,6 @@ from agent.redact import redact_sensitive_text
logger = logging.getLogger(__name__)
# Regex to match ANSI escape sequences (CSI codes, OSC codes, simple escapes).
# Models occasionally copy these from terminal output into file content.
_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]|\x1b\][^\x07]*\x07|\x1b[()][A-B012]|\x1b[=>]")
def _strip_ansi(text: str) -> str:
"""Remove ANSI escape sequences from text destined for file writes."""
if not text or "\x1b" not in text:
return text
return _ANSI_ESCAPE_RE.sub("", text)
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
@@ -301,7 +289,6 @@ def notify_other_tool_call(task_id: str = "default"):
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
"""Write content to a file."""
try:
content = _strip_ansi(content)
file_ops = _get_file_ops(task_id)
result = file_ops.write_file(path, content)
return json.dumps(result.to_dict(), ensure_ascii=False)
@@ -325,13 +312,10 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
return json.dumps({"error": "path required"})
if old_string is None or new_string is None:
return json.dumps({"error": "old_string and new_string required"})
old_string = _strip_ansi(old_string)
new_string = _strip_ansi(new_string)
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
elif mode == "patch":
if not patch:
return json.dumps({"error": "patch content required"})
patch = _strip_ansi(patch)
result = file_ops.patch_v4a(patch)
else:
return json.dumps({"error": f"Unknown mode: {mode}"})

View File

@@ -426,12 +426,14 @@ class ProcessRegistry:
def poll(self, session_id: str) -> dict:
"""Check status and get new output for a background process."""
from tools.ansi_strip import strip_ansi
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
with session._lock:
output_preview = session.output_buffer[-1000:] if session.output_buffer else ""
output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else ""
result = {
"session_id": session.id,
@@ -450,12 +452,14 @@ class ProcessRegistry:
def read_log(self, session_id: str, offset: int = 0, limit: int = 200) -> dict:
"""Read the full output log with optional pagination by lines."""
from tools.ansi_strip import strip_ansi
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
with session._lock:
full_output = session.output_buffer
full_output = strip_ansi(session.output_buffer)
lines = full_output.splitlines()
total_lines = len(lines)
@@ -486,6 +490,7 @@ class ProcessRegistry:
dict with status ("exited", "timeout", "interrupted", "not_found")
and output snapshot.
"""
from tools.ansi_strip import strip_ansi
from tools.terminal_tool import _interrupt_event
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
@@ -513,7 +518,7 @@ class ProcessRegistry:
result = {
"status": "exited",
"exit_code": session.exit_code,
"output": session.output_buffer[-2000:],
"output": strip_ansi(session.output_buffer[-2000:]),
}
if timeout_note:
result["timeout_note"] = timeout_note
@@ -522,7 +527,7 @@ class ProcessRegistry:
if _interrupt_event.is_set():
result = {
"status": "interrupted",
"output": session.output_buffer[-1000:],
"output": strip_ansi(session.output_buffer[-1000:]),
"note": "User sent a new message -- wait interrupted",
}
if timeout_note:
@@ -533,7 +538,7 @@ class ProcessRegistry:
result = {
"status": "timeout",
"output": session.output_buffer[-1000:],
"output": strip_ansi(session.output_buffer[-1000:]),
}
if timeout_note:
result["timeout_note"] = timeout_note

View File

@@ -1163,6 +1163,11 @@ def terminal_tool(
)
output = output[:head_chars] + truncated_notice + output[-tail_chars:]
# Strip ANSI escape sequences so the model never sees terminal
# formatting — prevents it from copying escapes into file writes.
from tools.ansi_strip import strip_ansi
output = strip_ansi(output)
# Redact secrets from command output (catches env/printenv leaking keys)
from agent.redact import redact_sensitive_text
output = redact_sensitive_text(output.strip()) if output else ""

tools/url_safety.py Normal file
View File

@@ -0,0 +1,96 @@
"""URL safety checks — blocks requests to private/internal network addresses.
Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or
skill could trick the agent into fetching internal resources like cloud
metadata endpoints (169.254.169.254), localhost services, or private
network hosts.
Limitations (documented, not fixable at pre-flight level):
- DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0
can return a public IP for the check, then a private IP for the actual
connection. Fixing this requires connection-level validation (e.g.
Python's Champion library or an egress proxy like Stripe's Smokescreen).
- Redirect-based bypass in vision_tools is mitigated by an httpx event
hook that re-validates each redirect target. Web tools use third-party
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
"""
import ipaddress
import logging
import socket
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

# Hostnames that should always be blocked regardless of IP resolution
_BLOCKED_HOSTNAMES = frozenset({
    "metadata.google.internal",
    "metadata.goog",
})

# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
# ipaddress.is_private — it returns False for both is_private and is_global.
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
# VPNs, and some cloud internal networks.
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")


def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Return True if the IP should be blocked for SSRF protection."""
    if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
        return True
    if ip.is_multicast or ip.is_unspecified:
        return True
    # CGNAT range not covered by is_private
    if ip in _CGNAT_NETWORK:
        return True
    return False


def is_safe_url(url: str) -> bool:
    """Return True if the URL target is not a private/internal address.

    Resolves the hostname to an IP and checks against private ranges.
    Fails closed: DNS errors and unexpected exceptions block the request.
    """
    try:
        parsed = urlparse(url)
        hostname = (parsed.hostname or "").strip().lower()
        if not hostname:
            return False
        # Block known internal hostnames
        if hostname in _BLOCKED_HOSTNAMES:
            logger.warning("Blocked request to internal hostname: %s", hostname)
            return False
        # Try to resolve and check IP
        try:
            addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
        except socket.gaierror:
            # DNS resolution failed — fail closed. If DNS can't resolve it,
            # the HTTP client will also fail, so blocking loses nothing.
            logger.warning("Blocked request — DNS resolution failed for: %s", hostname)
            return False
        for family, _, _, _, sockaddr in addr_info:
            ip_str = sockaddr[0]
            try:
                ip = ipaddress.ip_address(ip_str)
            except ValueError:
                continue
            if _is_blocked_ip(ip):
                logger.warning(
                    "Blocked request to private/internal address: %s -> %s",
                    hostname, ip_str,
                )
                return False
        return True
    except Exception as exc:
        # Fail closed on unexpected errors — don't let parsing edge cases
        # become SSRF bypass vectors
        logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
        return False
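The blocked-range logic can be sanity-checked in isolation. This sketch inlines the same `ipaddress` checks as `_is_blocked_ip` for a standalone demo; it is not the module's public API:

```python
import ipaddress

# Same checks as _is_blocked_ip, inlined for a standalone sanity check.
_CGNAT = ipaddress.ip_network("100.64.0.0/10")

def is_blocked(ip_str: str) -> bool:
    ip = ipaddress.ip_address(ip_str)
    return (ip.is_private or ip.is_loopback or ip.is_link_local
            or ip.is_reserved or ip.is_multicast or ip.is_unspecified
            or ip in _CGNAT)  # containment is False across IP versions

for addr in ("169.254.169.254", "127.0.0.1", "100.64.0.1", "8.8.8.8"):
    print(addr, "->", "blocked" if is_blocked(addr) else "allowed")
```

The interesting case is `100.64.0.1`: as the module comment notes, Python's `ipaddress` reports it as neither private nor global, so only the explicit CGNAT containment check catches it.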

View File

@@ -69,7 +69,12 @@ def _validate_image_url(url: str) -> bool:
if not parsed.netloc:
return False
return True # Allow all well-formed HTTP/HTTPS URLs for flexibility
# Block private/internal addresses to prevent SSRF
from tools.url_safety import is_safe_url
if not is_safe_url(url):
return False
return True
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
@@ -92,12 +97,33 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
# Create parent directories if they don't exist
destination.parent.mkdir(parents=True, exist_ok=True)
async def _ssrf_redirect_guard(response):
    """Re-validate each redirect target to prevent redirect-based SSRF.

    Without this, an attacker can host a public URL that 302-redirects
    to http://169.254.169.254/ and bypass the pre-flight is_safe_url check.
    Must be async because httpx.AsyncClient awaits event hooks.
    """
    if response.is_redirect and response.next_request:
        redirect_url = str(response.next_request.url)
        from tools.url_safety import is_safe_url
        if not is_safe_url(redirect_url):
            raise ValueError(
                f"Blocked redirect to private/internal address: {redirect_url}"
            )
last_error = None
for attempt in range(max_retries):
try:
# Download the image with appropriate headers using async httpx
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
# SSRF: event_hooks validates each redirect target against private IP ranges
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
response = await client.get(
image_url,
headers={

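The commit message notes that `httpx.AsyncClient` awaits its event hooks, so a sync hook (which returns `None`) breaks at await time. A minimal stdlib sketch of why — the hook names here are illustrative, and `dispatch` only approximates what the client does internally:

```python
import asyncio

def sync_hook(response):
    # Returns None — awaiting this result fails inside an async client.
    pass

async def async_hook(response):
    # Coroutine function — calling it returns an awaitable, so this is safe.
    pass

async def dispatch(hook):
    # httpx.AsyncClient does roughly this for each registered response hook.
    await hook("response")

try:
    asyncio.run(dispatch(sync_hook))
except TypeError as exc:
    print(exc)  # object NoneType can't be used in 'await' expression

asyncio.run(dispatch(async_hook))  # completes without error
```

This is exactly the failure mode fixed in the vision-tools commit: the guard had to become `async def` even though its body contains no awaits.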
View File

@@ -46,6 +46,7 @@ import httpx
from firecrawl import Firecrawl
from agent.auxiliary_client import async_call_llm
from tools.debug_helpers import DebugSession
from tools.url_safety import is_safe_url
from tools.website_policy import check_website_access
logger = logging.getLogger(__name__)
@@ -861,136 +862,155 @@ async def web_extract_tool(
try:
logger.info("Extracting content from %d URL(s)", len(urls))
        # ── SSRF protection — filter out private/internal URLs before any backend ──
        safe_urls = []
        ssrf_blocked: List[Dict[str, Any]] = []
        for url in urls:
            if not is_safe_url(url):
                ssrf_blocked.append({
                    "url": url, "title": "", "content": "",
                    "error": "Blocked: URL targets a private or internal network address",
                })
            else:
                safe_urls.append(url)
        # Dispatch only safe URLs to the configured backend
        if not safe_urls:
            results = []
        else:
            backend = _get_backend()
            if backend == "parallel":
                results = await _parallel_extract(safe_urls)
            elif backend == "tavily":
                logger.info("Tavily extract: %d URL(s)", len(safe_urls))
                raw = _tavily_request("extract", {
                    "urls": safe_urls,
                    "include_images": False,
                })
                results = _normalize_tavily_documents(raw, fallback_url=safe_urls[0] if safe_urls else "")
            else:
                # ── Firecrawl extraction ──
                # Determine requested formats for Firecrawl v2
                formats: List[str] = []
                if format == "markdown":
                    formats = ["markdown"]
                elif format == "html":
                    formats = ["html"]
                else:
                    # Default: request markdown for LLM-readiness and include html as backup
                    formats = ["markdown", "html"]
                # Always use individual scraping for simplicity and reliability
                # Batch scraping adds complexity without much benefit for small numbers of URLs
                results: List[Dict[str, Any]] = []
                from tools.interrupt import is_interrupted as _is_interrupted
                for url in safe_urls:
                    if _is_interrupted():
                        results.append({"url": url, "error": "Interrupted", "title": ""})
                        continue
                    # Website policy check — block before fetching
                    blocked = check_website_access(url)
                    if blocked:
                        logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
                        results.append({
                            "url": url, "title": "", "content": "",
                            "error": blocked["message"],
                            "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
                        })
                        continue
                    try:
                        logger.info("Scraping: %s", url)
                        scrape_result = _get_firecrawl_client().scrape(
                            url=url,
                            formats=formats
                        )
                        # Process the result - properly handle object serialization
                        metadata = {}
                        title = ""
                        content_markdown = None
                        content_html = None
                        # Extract data from the scrape result
                        if hasattr(scrape_result, 'model_dump'):
                            # Pydantic model - use model_dump to get dict
                            result_dict = scrape_result.model_dump()
                            content_markdown = result_dict.get('markdown')
                            content_html = result_dict.get('html')
                            metadata = result_dict.get('metadata', {})
                        elif hasattr(scrape_result, '__dict__'):
                            # Regular object with attributes
                            content_markdown = getattr(scrape_result, 'markdown', None)
                            content_html = getattr(scrape_result, 'html', None)
                            # Handle metadata - convert to dict if it's an object
                            metadata_obj = getattr(scrape_result, 'metadata', {})
                            if hasattr(metadata_obj, 'model_dump'):
                                metadata = metadata_obj.model_dump()
                            elif hasattr(metadata_obj, '__dict__'):
                                metadata = metadata_obj.__dict__
                            elif isinstance(metadata_obj, dict):
                                metadata = metadata_obj
                            else:
                                metadata = {}
                        elif isinstance(scrape_result, dict):
                            # Already a dictionary
                            content_markdown = scrape_result.get('markdown')
                            content_html = scrape_result.get('html')
                            metadata = scrape_result.get('metadata', {})
                        # Ensure metadata is a dict (not an object)
                        if not isinstance(metadata, dict):
                            if hasattr(metadata, 'model_dump'):
                                metadata = metadata.model_dump()
                            elif hasattr(metadata, '__dict__'):
                                metadata = metadata.__dict__
                            else:
                                metadata = {}
                        # Get title from metadata
                        title = metadata.get("title", "")
                        # Re-check final URL after redirect
                        final_url = metadata.get("sourceURL", url)
                        final_blocked = check_website_access(final_url)
                        if final_blocked:
                            logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
                            results.append({
                                "url": final_url, "title": title, "content": "", "raw_content": "",
                                "error": final_blocked["message"],
                                "blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
                            })
                            continue
                        # Choose content based on requested format
                        chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
                        results.append({
                            "url": final_url,
                            "title": title,
                            "content": chosen_content,
                            "raw_content": chosen_content,
                            "metadata": metadata  # Now guaranteed to be a dict
                        })
                    except Exception as scrape_err:
                        logger.debug("Scrape failed for %s: %s", url, scrape_err)
                        results.append({
                            "url": url,
                            "title": "",
                            "content": "",
                            "raw_content": "",
                            "error": str(scrape_err)
                        })
        # Merge any SSRF-blocked results back in
        if ssrf_blocked:
            results = ssrf_blocked + results
response = {"results": results}
@@ -1173,6 +1193,11 @@ async def web_crawl_tool(
if not url.startswith(('http://', 'https://')):
url = f'https://{url}'
# SSRF protection — block private/internal addresses
if not is_safe_url(url):
return json.dumps({"results": [{"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
# Website policy check
blocked = check_website_access(url)
if blocked:
@@ -1258,6 +1283,11 @@ async def web_crawl_tool(
instructions_text = f" with instructions: '{instructions}'" if instructions else ""
logger.info("Crawling %s%s", url, instructions_text)
# SSRF protection — block private/internal addresses
if not is_safe_url(url):
return json.dumps({"results": [{"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address"}]}, ensure_ascii=False)
# Website policy check — block before crawling
blocked = check_website_access(url)
if blocked: