mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-04 01:37:34 +08:00
fix: add service domain blocklist and entity_id validation to HA tools
Block dangerous HA service domains (shell_command, command_line, python_script, pyscript, hassio, rest_command) that allow arbitrary code execution or SSRF. Add regex validation for entity_id to prevent path traversal attacks. 17 new tests covering both security features.
This commit is contained in:
@@ -16,6 +16,8 @@ from tools.homeassistant_tool import (
|
|||||||
_get_headers,
|
_get_headers,
|
||||||
_handle_get_state,
|
_handle_get_state,
|
||||||
_handle_call_service,
|
_handle_call_service,
|
||||||
|
_BLOCKED_DOMAINS,
|
||||||
|
_ENTITY_ID_RE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -211,6 +213,96 @@ class TestHandlerValidation:
|
|||||||
assert "error" in result
|
assert "error" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Security: domain blocklist
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDomainBlocklist:
|
||||||
|
"""Verify dangerous HA service domains are blocked."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS))
|
||||||
|
def test_blocked_domain_rejected(self, domain):
|
||||||
|
result = json.loads(_handle_call_service({
|
||||||
|
"domain": domain, "service": "any_service"
|
||||||
|
}))
|
||||||
|
assert "error" in result
|
||||||
|
assert "blocked" in result["error"].lower()
|
||||||
|
|
||||||
|
def test_safe_domain_not_blocked(self):
|
||||||
|
"""Safe domains like 'light' should not be blocked (will fail on network, not blocklist)."""
|
||||||
|
# This will try to make a real HTTP call and fail, but the important thing
|
||||||
|
# is it does NOT return a "blocked" error
|
||||||
|
result = json.loads(_handle_call_service({
|
||||||
|
"domain": "light", "service": "turn_on", "entity_id": "light.test"
|
||||||
|
}))
|
||||||
|
# Should fail with a network/connection error, not a "blocked" error
|
||||||
|
if "error" in result:
|
||||||
|
assert "blocked" not in result["error"].lower()
|
||||||
|
|
||||||
|
def test_blocked_domains_include_shell_command(self):
|
||||||
|
assert "shell_command" in _BLOCKED_DOMAINS
|
||||||
|
|
||||||
|
def test_blocked_domains_include_hassio(self):
|
||||||
|
assert "hassio" in _BLOCKED_DOMAINS
|
||||||
|
|
||||||
|
def test_blocked_domains_include_rest_command(self):
|
||||||
|
assert "rest_command" in _BLOCKED_DOMAINS
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Security: entity_id validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntityIdValidation:
|
||||||
|
"""Verify entity_id format validation prevents path traversal."""
|
||||||
|
|
||||||
|
def test_valid_entity_id_accepted(self):
|
||||||
|
assert _ENTITY_ID_RE.match("light.bedroom")
|
||||||
|
assert _ENTITY_ID_RE.match("sensor.temperature_1")
|
||||||
|
assert _ENTITY_ID_RE.match("binary_sensor.motion")
|
||||||
|
assert _ENTITY_ID_RE.match("climate.main_thermostat")
|
||||||
|
|
||||||
|
def test_path_traversal_rejected(self):
|
||||||
|
assert _ENTITY_ID_RE.match("../../config") is None
|
||||||
|
assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None
|
||||||
|
assert _ENTITY_ID_RE.match("../api/config") is None
|
||||||
|
|
||||||
|
def test_special_chars_rejected(self):
|
||||||
|
assert _ENTITY_ID_RE.match("light.bed room") is None # space
|
||||||
|
assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None # semicolon
|
||||||
|
assert _ENTITY_ID_RE.match("light.bed/room") is None # slash
|
||||||
|
assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None # uppercase
|
||||||
|
|
||||||
|
def test_missing_domain_rejected(self):
|
||||||
|
assert _ENTITY_ID_RE.match(".bedroom") is None
|
||||||
|
assert _ENTITY_ID_RE.match("bedroom") is None
|
||||||
|
|
||||||
|
def test_get_state_rejects_invalid_entity_id(self):
|
||||||
|
result = json.loads(_handle_get_state({"entity_id": "../../config"}))
|
||||||
|
assert "error" in result
|
||||||
|
assert "Invalid entity_id" in result["error"]
|
||||||
|
|
||||||
|
def test_call_service_rejects_invalid_entity_id(self):
|
||||||
|
result = json.loads(_handle_call_service({
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"entity_id": "../../../etc/passwd",
|
||||||
|
}))
|
||||||
|
assert "error" in result
|
||||||
|
assert "Invalid entity_id" in result["error"]
|
||||||
|
|
||||||
|
def test_call_service_allows_no_entity_id(self):
|
||||||
|
"""Some services (like scene.turn_on) don't need entity_id."""
|
||||||
|
# Will fail on network, but should NOT fail on entity_id validation
|
||||||
|
result = json.loads(_handle_call_service({
|
||||||
|
"domain": "scene", "service": "turn_on"
|
||||||
|
}))
|
||||||
|
if "error" in result:
|
||||||
|
assert "Invalid entity_id" not in result["error"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Availability check
|
# Availability check
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -24,6 +25,21 @@ logger = logging.getLogger(__name__)
|
|||||||
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
|
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
|
||||||
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
|
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
|
||||||
|
|
||||||
|
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||||
|
# Anchor the end with \Z rather than $: in Python "$" also matches just before
# a trailing newline, so "light.kitchen\n" would otherwise pass validation.
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+\Z")
|
||||||
|
|
||||||
|
# Service domains blocked for security -- these allow arbitrary code/command
|
||||||
|
# execution on the HA host or enable SSRF attacks on the local network.
|
||||||
|
# HA provides zero service-level access control; all safety must be in our layer.
|
||||||
|
_BLOCKED_DOMAINS = frozenset({
|
||||||
|
"shell_command", # arbitrary shell commands as root in HA container
|
||||||
|
"command_line", # sensors/switches that execute shell commands
|
||||||
|
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||||
|
"pyscript", # scripting integration with broader access
|
||||||
|
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||||
|
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def _get_headers() -> Dict[str, str]:
|
def _get_headers() -> Dict[str, str]:
|
||||||
"""Return authorization headers for HA REST API."""
|
"""Return authorization headers for HA REST API."""
|
||||||
@@ -198,6 +214,8 @@ def _handle_get_state(args: dict, **kw) -> str:
|
|||||||
entity_id = args.get("entity_id", "")
|
entity_id = args.get("entity_id", "")
|
||||||
if not entity_id:
|
if not entity_id:
|
||||||
return json.dumps({"error": "Missing required parameter: entity_id"})
|
return json.dumps({"error": "Missing required parameter: entity_id"})
|
||||||
|
if not _ENTITY_ID_RE.match(entity_id):
|
||||||
|
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||||
try:
|
try:
|
||||||
result = _run_async(_async_get_state(entity_id))
|
result = _run_async(_async_get_state(entity_id))
|
||||||
return json.dumps({"result": result})
|
return json.dumps({"result": result})
|
||||||
@@ -213,7 +231,16 @@ def _handle_call_service(args: dict, **kw) -> str:
|
|||||||
if not domain or not service:
|
if not domain or not service:
|
||||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||||
|
|
||||||
|
if domain in _BLOCKED_DOMAINS:
|
||||||
|
return json.dumps({
|
||||||
|
"error": f"Service domain '{domain}' is blocked for security. "
|
||||||
|
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||||
|
})
|
||||||
|
|
||||||
entity_id = args.get("entity_id")
|
entity_id = args.get("entity_id")
|
||||||
|
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||||
|
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||||
|
|
||||||
data = args.get("data")
|
data = args.get("data")
|
||||||
try:
|
try:
|
||||||
result = _run_async(_async_call_service(domain, service, entity_id, data))
|
result = _run_async(_async_call_service(domain, service, entity_id, data))
|
||||||
|
|||||||
Reference in New Issue
Block a user