mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-04 09:47:54 +08:00
Compare commits
8 Commits
fix/plugin
...
salvage/bu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e1fc0ad536 | ||
|
|
177e42e90a | ||
|
|
224fc4c29f | ||
|
|
f9a94c65e6 | ||
|
|
e36917867a | ||
|
|
4aa97af895 | ||
|
|
ae9c1f22c1 | ||
|
|
2a9e115210 |
@@ -168,7 +168,7 @@ def _build_skill_message(
|
|||||||
subdir_path = skill_dir / subdir
|
subdir_path = skill_dir / subdir
|
||||||
if subdir_path.exists():
|
if subdir_path.exists():
|
||||||
for f in sorted(subdir_path.rglob("*")):
|
for f in sorted(subdir_path.rglob("*")):
|
||||||
if f.is_file():
|
if f.is_file() and not f.is_symlink():
|
||||||
rel = str(f.relative_to(skill_dir))
|
rel = str(f.relative_to(skill_dir))
|
||||||
supporting.append(rel)
|
supporting.append(rel)
|
||||||
|
|
||||||
|
|||||||
@@ -442,6 +442,14 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
|||||||
stdout = (result.stdout or "").strip()
|
stdout = (result.stdout or "").strip()
|
||||||
stderr = (result.stderr or "").strip()
|
stderr = (result.stderr or "").strip()
|
||||||
|
|
||||||
|
# Redact secrets from both stdout and stderr before any return path.
|
||||||
|
try:
|
||||||
|
from agent.redact import redact_sensitive_text
|
||||||
|
stdout = redact_sensitive_text(stdout)
|
||||||
|
stderr = redact_sensitive_text(stderr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
parts = [f"Script exited with code {result.returncode}"]
|
parts = [f"Script exited with code {result.returncode}"]
|
||||||
if stderr:
|
if stderr:
|
||||||
@@ -450,13 +458,6 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
|
|||||||
parts.append(f"stdout:\n{stdout}")
|
parts.append(f"stdout:\n{stdout}")
|
||||||
return False, "\n".join(parts)
|
return False, "\n".join(parts)
|
||||||
|
|
||||||
# Redact any secrets that may appear in script output before
|
|
||||||
# they are injected into the LLM prompt context.
|
|
||||||
try:
|
|
||||||
from agent.redact import redact_sensitive_text
|
|
||||||
stdout = redact_sensitive_text(stdout)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return True, stdout
|
return True, stdout
|
||||||
|
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ class HermesToolCallParser(ToolCallParser):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
tc_data = json.loads(raw_json)
|
tc_data = json.loads(raw_json)
|
||||||
|
if "name" not in tc_data:
|
||||||
|
continue
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ChatCompletionMessageToolCall(
|
ChatCompletionMessageToolCall(
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||||
|
|||||||
@@ -89,6 +89,8 @@ class MistralToolCallParser(ToolCallParser):
|
|||||||
parsed = [parsed]
|
parsed = [parsed]
|
||||||
|
|
||||||
for tc in parsed:
|
for tc in parsed:
|
||||||
|
if "name" not in tc:
|
||||||
|
continue
|
||||||
args = tc.get("arguments", {})
|
args = tc.get("arguments", {})
|
||||||
if isinstance(args, dict):
|
if isinstance(args, dict):
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
args = json.dumps(args, ensure_ascii=False)
|
||||||
|
|||||||
@@ -501,6 +501,10 @@ def _get_platform_tools(
|
|||||||
default_ts = PLATFORMS[platform]["default_toolset"]
|
default_ts = PLATFORMS[platform]["default_toolset"]
|
||||||
toolset_names = [default_ts]
|
toolset_names = [default_ts]
|
||||||
|
|
||||||
|
# YAML may parse bare numeric names (e.g. ``12306:``) as int.
|
||||||
|
# Normalise to str so downstream sorted() never mixes types.
|
||||||
|
toolset_names = [str(ts) for ts in toolset_names]
|
||||||
|
|
||||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||||
|
|
||||||
# If the saved list contains any configurable keys directly, the user
|
# If the saved list contains any configurable keys directly, the user
|
||||||
@@ -559,7 +563,7 @@ def _get_platform_tools(
|
|||||||
# Special sentinel: "no_mcp" in the toolset list disables all MCP servers.
|
# Special sentinel: "no_mcp" in the toolset list disables all MCP servers.
|
||||||
mcp_servers = config.get("mcp_servers") or {}
|
mcp_servers = config.get("mcp_servers") or {}
|
||||||
enabled_mcp_servers = {
|
enabled_mcp_servers = {
|
||||||
name
|
str(name)
|
||||||
for name, server_cfg in mcp_servers.items()
|
for name, server_cfg in mcp_servers.items()
|
||||||
if isinstance(server_cfg, dict)
|
if isinstance(server_cfg, dict)
|
||||||
and _parse_enabled_flag(server_cfg.get("enabled", True), default=True)
|
and _parse_enabled_flag(server_cfg.get("enabled", True), default=True)
|
||||||
|
|||||||
@@ -345,6 +345,11 @@ class TestBlockingApprovalE2E:
|
|||||||
|
|
||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
_clear_approval_state()
|
_clear_approval_state()
|
||||||
|
os.environ.pop("HERMES_YOLO_MODE", None)
|
||||||
|
os.environ.pop("HERMES_INTERACTIVE", None)
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
|
|
||||||
def test_blocking_approval_approve_once(self):
|
def test_blocking_approval_approve_once(self):
|
||||||
"""check_all_command_guards blocks until resolve_gateway_approval is called."""
|
"""check_all_command_guards blocks until resolve_gateway_approval is called."""
|
||||||
@@ -364,6 +369,7 @@ class TestBlockingApprovalE2E:
|
|||||||
from tools.approval import reset_current_session_key, set_current_session_key
|
from tools.approval import reset_current_session_key, set_current_session_key
|
||||||
|
|
||||||
token = set_current_session_key(session_key)
|
token = set_current_session_key(session_key)
|
||||||
|
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||||
try:
|
try:
|
||||||
@@ -371,6 +377,7 @@ class TestBlockingApprovalE2E:
|
|||||||
"rm -rf /important", "local"
|
"rm -rf /important", "local"
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
reset_current_session_key(token)
|
reset_current_session_key(token)
|
||||||
@@ -410,6 +417,7 @@ class TestBlockingApprovalE2E:
|
|||||||
from tools.approval import reset_current_session_key, set_current_session_key
|
from tools.approval import reset_current_session_key, set_current_session_key
|
||||||
|
|
||||||
token = set_current_session_key(session_key)
|
token = set_current_session_key(session_key)
|
||||||
|
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||||
try:
|
try:
|
||||||
@@ -417,6 +425,7 @@ class TestBlockingApprovalE2E:
|
|||||||
"rm -rf /important", "local"
|
"rm -rf /important", "local"
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
reset_current_session_key(token)
|
reset_current_session_key(token)
|
||||||
@@ -451,6 +460,7 @@ class TestBlockingApprovalE2E:
|
|||||||
from tools.approval import reset_current_session_key, set_current_session_key
|
from tools.approval import reset_current_session_key, set_current_session_key
|
||||||
|
|
||||||
token = set_current_session_key(session_key)
|
token = set_current_session_key(session_key)
|
||||||
|
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||||
try:
|
try:
|
||||||
@@ -460,6 +470,7 @@ class TestBlockingApprovalE2E:
|
|||||||
"rm -rf /important", "local"
|
"rm -rf /important", "local"
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
reset_current_session_key(token)
|
reset_current_session_key(token)
|
||||||
@@ -491,11 +502,13 @@ class TestBlockingApprovalE2E:
|
|||||||
from tools.approval import reset_current_session_key, set_current_session_key
|
from tools.approval import reset_current_session_key, set_current_session_key
|
||||||
|
|
||||||
token = set_current_session_key(session_key)
|
token = set_current_session_key(session_key)
|
||||||
|
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||||
try:
|
try:
|
||||||
results[idx] = check_all_command_guards(cmd, "local")
|
results[idx] = check_all_command_guards(cmd, "local")
|
||||||
finally:
|
finally:
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
reset_current_session_key(token)
|
reset_current_session_key(token)
|
||||||
@@ -546,11 +559,13 @@ class TestBlockingApprovalE2E:
|
|||||||
from tools.approval import reset_current_session_key, set_current_session_key
|
from tools.approval import reset_current_session_key, set_current_session_key
|
||||||
|
|
||||||
token = set_current_session_key(session_key)
|
token = set_current_session_key(session_key)
|
||||||
|
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||||
try:
|
try:
|
||||||
results[idx] = check_all_command_guards(cmd, "local")
|
results[idx] = check_all_command_guards(cmd, "local")
|
||||||
finally:
|
finally:
|
||||||
|
os.environ.pop("HERMES_GATEWAY_SESSION", None)
|
||||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||||
reset_current_session_key(token)
|
reset_current_session_key(token)
|
||||||
|
|||||||
@@ -428,3 +428,31 @@ class TestPlatformToolsetConsistency:
|
|||||||
f"Platform {platform!r} in tools_config but missing from "
|
f"Platform {platform!r} in tools_config but missing from "
|
||||||
f"skills_config PLATFORMS"
|
f"skills_config PLATFORMS"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_numeric_mcp_server_name_does_not_crash_sorted():
|
||||||
|
"""YAML parses bare numeric keys (e.g. ``12306:``) as int.
|
||||||
|
|
||||||
|
_get_platform_tools must normalise them to str so that sorted()
|
||||||
|
on the returned set never raises TypeError on mixed int/str.
|
||||||
|
|
||||||
|
Regression test for https://github.com/NousResearch/hermes-agent/issues/6901
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"platform_toolsets": {"cli": ["web", 12306]},
|
||||||
|
"mcp_servers": {
|
||||||
|
12306: {"url": "https://example.com/mcp"},
|
||||||
|
"normal-server": {"url": "https://example.com/mcp2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled = _get_platform_tools(config, "cli")
|
||||||
|
|
||||||
|
# All names must be str — no int leaking through
|
||||||
|
assert all(isinstance(name, str) for name in enabled), (
|
||||||
|
f"Non-string toolset names found: {enabled}"
|
||||||
|
)
|
||||||
|
assert "12306" in enabled
|
||||||
|
|
||||||
|
# sorted() must not raise TypeError
|
||||||
|
sorted(enabled)
|
||||||
|
|||||||
@@ -156,6 +156,8 @@ class TestSessionKeyContext:
|
|||||||
assert "reset_current_session_key" in called_names
|
assert "reset_current_session_key" in called_names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestRmFalsePositiveFix:
|
class TestRmFalsePositiveFix:
|
||||||
"""Regression tests: filenames starting with 'r' must NOT trigger recursive delete."""
|
"""Regression tests: filenames starting with 'r' must NOT trigger recursive delete."""
|
||||||
|
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ def test_openai_tts_uses_managed_audio_gateway_when_direct_key_absent(monkeypatc
|
|||||||
_install_fake_tools_package()
|
_install_fake_tools_package()
|
||||||
_install_fake_openai_module(captured)
|
_install_fake_openai_module(captured)
|
||||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
|
monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
|
||||||
monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")
|
monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")
|
||||||
|
|
||||||
@@ -256,6 +257,7 @@ def test_transcription_uses_model_specific_response_formats(monkeypatch, tmp_pat
|
|||||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
(tmp_path / "config.yaml").write_text("stt:\n provider: openai\n")
|
(tmp_path / "config.yaml").write_text("stt:\n provider: openai\n")
|
||||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
|
monkeypatch.setenv("TOOL_GATEWAY_DOMAIN", "nousresearch.com")
|
||||||
monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")
|
monkeypatch.setenv("TOOL_GATEWAY_USER_TOKEN", "nous-token")
|
||||||
|
|
||||||
|
|||||||
@@ -414,6 +414,7 @@ class TestVisionSafetyGuards:
|
|||||||
|
|
||||||
class FakeResponse:
|
class FakeResponse:
|
||||||
url = "https://blocked.test/final.png"
|
url = "https://blocked.test/final.png"
|
||||||
|
headers = {"content-length": "24"}
|
||||||
content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
|
content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
|
||||||
|
|
||||||
def raise_for_status(self):
|
def raise_for_status(self):
|
||||||
@@ -533,6 +534,133 @@ class TestTildeExpansion:
|
|||||||
assert data["success"] is False
|
assert data["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# file:// URI support
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileUriSupport:
|
||||||
|
"""Verify that file:// URIs resolve as local file paths."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_uri_resolved_as_local_path(self, tmp_path):
|
||||||
|
"""file:///absolute/path should be treated as a local file."""
|
||||||
|
img = tmp_path / "photo.png"
|
||||||
|
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_choice = MagicMock()
|
||||||
|
mock_choice.message.content = "A test image"
|
||||||
|
mock_response.choices = [mock_choice]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"tools.vision_tools._image_to_base64_data_url",
|
||||||
|
return_value="data:image/png;base64,abc",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"tools.vision_tools.async_call_llm",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await vision_analyze_tool(
|
||||||
|
f"file://{img}", "describe this", "test/model"
|
||||||
|
)
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["success"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_uri_nonexistent_gives_error(self, tmp_path):
|
||||||
|
"""file:// pointing to a missing file should fail gracefully."""
|
||||||
|
result = await vision_analyze_tool(
|
||||||
|
f"file://{tmp_path}/nonexistent.png", "describe this", "test/model"
|
||||||
|
)
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base64 size pre-flight check
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBase64SizeLimit:
|
||||||
|
"""Verify that oversized images are rejected before hitting the API."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oversized_image_rejected_before_api_call(self, tmp_path):
|
||||||
|
"""Images exceeding 5 MB base64 should fail with a clear size error."""
|
||||||
|
img = tmp_path / "huge.png"
|
||||||
|
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * (4 * 1024 * 1024))
|
||||||
|
|
||||||
|
with patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm:
|
||||||
|
result = json.loads(await vision_analyze_tool(str(img), "describe this"))
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "too large" in result["error"].lower()
|
||||||
|
mock_llm.assert_not_awaited()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_small_image_not_rejected(self, tmp_path):
|
||||||
|
"""Images well under the limit should pass the size check."""
|
||||||
|
img = tmp_path / "small.png"
|
||||||
|
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 64)
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_choice = MagicMock()
|
||||||
|
mock_choice.message.content = "Small image"
|
||||||
|
mock_response.choices = [mock_choice]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"tools.vision_tools.async_call_llm",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = json.loads(await vision_analyze_tool(str(img), "describe this", "test/model"))
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error classification for 400 responses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorClassification:
|
||||||
|
"""Verify that API 400 errors produce actionable guidance."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_request_error_gives_image_guidance(self, tmp_path):
|
||||||
|
"""An invalid_request_error from the API should mention image size/format."""
|
||||||
|
img = tmp_path / "test.png"
|
||||||
|
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||||
|
|
||||||
|
api_error = Exception(
|
||||||
|
"Error code: 400 - {'type': 'error', 'error': "
|
||||||
|
"{'type': 'invalid_request_error', 'message': 'Invalid request data'}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"tools.vision_tools._image_to_base64_data_url",
|
||||||
|
return_value="data:image/png;base64,abc",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"tools.vision_tools.async_call_llm",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=api_error,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = json.loads(await vision_analyze_tool(str(img), "describe", "test/model"))
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "rejected the image" in result["analysis"].lower()
|
||||||
|
assert "smaller" in result["analysis"].lower()
|
||||||
|
|
||||||
|
|
||||||
class TestVisionRegistration:
|
class TestVisionRegistration:
|
||||||
def test_vision_analyze_registered(self):
|
def test_vision_analyze_registered(self):
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
|
|||||||
@@ -396,8 +396,8 @@ class ProcessRegistry:
|
|||||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Process stdout reader ended: %s", e)
|
logger.debug("Process stdout reader ended: %s", e)
|
||||||
|
finally:
|
||||||
# Process exited
|
# Always reap the child to prevent zombie processes.
|
||||||
try:
|
try:
|
||||||
session.process.wait(timeout=5)
|
session.process.wait(timeout=5)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -67,6 +67,10 @@ def _resolve_download_timeout() -> float:
|
|||||||
|
|
||||||
_VISION_DOWNLOAD_TIMEOUT = _resolve_download_timeout()
|
_VISION_DOWNLOAD_TIMEOUT = _resolve_download_timeout()
|
||||||
|
|
||||||
|
# Hard cap on downloaded image file size (50 MB). Prevents OOM from
|
||||||
|
# attacker-hosted multi-gigabyte files or decompression bombs.
|
||||||
|
_VISION_MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
def _validate_image_url(url: str) -> bool:
|
def _validate_image_url(url: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -181,13 +185,25 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Reject overly large images early via Content-Length header.
|
||||||
|
cl = response.headers.get("content-length")
|
||||||
|
if cl and int(cl) > _VISION_MAX_DOWNLOAD_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image too large ({int(cl)} bytes, max {_VISION_MAX_DOWNLOAD_BYTES})"
|
||||||
|
)
|
||||||
|
|
||||||
final_url = str(response.url)
|
final_url = str(response.url)
|
||||||
blocked = check_website_access(final_url)
|
blocked = check_website_access(final_url)
|
||||||
if blocked:
|
if blocked:
|
||||||
raise PermissionError(blocked["message"])
|
raise PermissionError(blocked["message"])
|
||||||
|
|
||||||
# Save the image content
|
# Save the image content (double-check actual size)
|
||||||
destination.write_bytes(response.content)
|
body = response.content
|
||||||
|
if len(body) > _VISION_MAX_DOWNLOAD_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image too large ({len(body)} bytes, max {_VISION_MAX_DOWNLOAD_BYTES})"
|
||||||
|
)
|
||||||
|
destination.write_bytes(body)
|
||||||
|
|
||||||
return destination
|
return destination
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -326,7 +342,11 @@ async def vision_analyze_tool(
|
|||||||
logger.info("User prompt: %s", user_prompt[:100])
|
logger.info("User prompt: %s", user_prompt[:100])
|
||||||
|
|
||||||
# Determine if this is a local file path or a remote URL
|
# Determine if this is a local file path or a remote URL
|
||||||
local_path = Path(os.path.expanduser(image_url))
|
# Strip file:// scheme so file URIs resolve as local paths.
|
||||||
|
resolved_url = image_url
|
||||||
|
if resolved_url.startswith("file://"):
|
||||||
|
resolved_url = resolved_url[len("file://"):]
|
||||||
|
local_path = Path(os.path.expanduser(resolved_url))
|
||||||
if local_path.is_file():
|
if local_path.is_file():
|
||||||
# Local file path (e.g. from platform image cache) -- skip download
|
# Local file path (e.g. from platform image cache) -- skip download
|
||||||
logger.info("Using local image file: %s", image_url)
|
logger.info("Using local image file: %s", image_url)
|
||||||
@@ -363,6 +383,18 @@ async def vision_analyze_tool(
|
|||||||
data_size_kb = len(image_data_url) / 1024
|
data_size_kb = len(image_data_url) / 1024
|
||||||
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
|
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
|
||||||
|
|
||||||
|
# Pre-flight size check: most vision APIs cap base64 payloads at 5 MB.
|
||||||
|
# Reject early with a clear message instead of a cryptic provider 400.
|
||||||
|
_MAX_BASE64_BYTES = 5 * 1024 * 1024 # 5 MB
|
||||||
|
# The data URL includes the header (e.g. "data:image/jpeg;base64,") which
|
||||||
|
# is negligible, but measure the full string to be safe.
|
||||||
|
if len(image_data_url) > _MAX_BASE64_BYTES:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image too large for vision API: base64 payload is "
|
||||||
|
f"{len(image_data_url) / (1024 * 1024):.1f} MB (limit 5 MB). "
|
||||||
|
f"Resize or compress the image and try again."
|
||||||
|
)
|
||||||
|
|
||||||
debug_call_data["image_size_bytes"] = image_size_bytes
|
debug_call_data["image_size_bytes"] = image_size_bytes
|
||||||
|
|
||||||
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
||||||
@@ -455,14 +487,21 @@ async def vision_analyze_tool(
|
|||||||
f"API provider account and try again. Error: {e}"
|
f"API provider account and try again. Error: {e}"
|
||||||
)
|
)
|
||||||
elif any(hint in err_str for hint in (
|
elif any(hint in err_str for hint in (
|
||||||
"does not support", "not support image", "invalid_request",
|
"does not support", "not support image",
|
||||||
"content_policy", "image_url", "multimodal",
|
"content_policy", "multimodal",
|
||||||
"unrecognized request argument", "image input",
|
"unrecognized request argument", "image input",
|
||||||
)):
|
)):
|
||||||
analysis = (
|
analysis = (
|
||||||
f"{model} does not support vision or our request was not "
|
f"{model} does not support vision or our request was not "
|
||||||
f"accepted by the server. Error: {e}"
|
f"accepted by the server. Error: {e}"
|
||||||
)
|
)
|
||||||
|
elif "invalid_request" in err_str or "image_url" in err_str:
|
||||||
|
analysis = (
|
||||||
|
"The vision API rejected the image. This can happen when the "
|
||||||
|
"image is too large, in an unsupported format, or corrupted. "
|
||||||
|
"Try a smaller JPEG/PNG (under 3.5 MB) and retry. "
|
||||||
|
f"Error: {e}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
analysis = (
|
analysis = (
|
||||||
"There was a problem with the request and the image could not "
|
"There was a problem with the request and the image could not "
|
||||||
|
|||||||
Reference in New Issue
Block a user