mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
- Add ha_list_entities, ha_get_state, ha_call_service tools via REST API - Add WebSocket gateway adapter for real-time state_changed event monitoring - Support domain/entity filtering, cooldown, and auto-reconnect with backoff - Use REST API for outbound notifications to avoid WS race condition - Gate tool availability on HASS_TOKEN env var - Add 82 unit tests covering real logic (filtering, payload building, event pipeline)
365 lines
12 KiB
Python
365 lines
12 KiB
Python
"""Home Assistant tool for controlling smart home devices via REST API.
|
|
|
|
Registers three LLM-callable tools:
|
|
- ``ha_list_entities`` -- list/filter entities by domain or area
|
|
- ``ha_get_state`` -- get detailed state of a single entity
|
|
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
|
|
|
|
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
|
|
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Dict, Optional
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Both values are read from the environment when this module is imported.
# ``or`` (instead of a getenv default) makes an exported-but-empty HASS_URL
# fall back to the default rather than yielding relative request URLs such as
# "/api/states". The trailing slash is stripped so paths can be appended.
_HASS_URL: str = (os.getenv("HASS_URL") or "http://homeassistant.local:8123").rstrip("/")

_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
|
|
|
|
|
|
def _get_headers() -> Dict[str, str]:
    """Return authorization headers for HA REST API requests.

    The bearer token is the module-level ``_HASS_TOKEN`` (captured from
    ``HASS_TOKEN`` at import time). If that value was empty, ``HASS_TOKEN`` is
    re-read from the environment now. A token exported after this module was
    imported is then still used, which matches ``_check_ha_available``: that
    function also reads the environment at call time and would otherwise report
    the tool as available while requests go out with an empty token.

    Returns:
        A dict with ``Authorization`` and ``Content-Type`` headers.
    """
    token = _HASS_TOKEN or os.getenv("HASS_TOKEN", "")
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async helpers (driven from sync handlers via _run_async, which uses asyncio.run)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _filter_and_summarize(
    states: list,
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Narrow raw HA state objects and reduce each one to three fields.

    Args:
        states: State dicts, e.g. the body of ``GET /api/states``.
        domain: If truthy, keep only entities whose ``entity_id`` starts
            with ``"<domain>."``.
        area: If truthy, keep only entities whose ``friendly_name`` or
            ``area`` attribute contains it (case-insensitive substring).

    Returns:
        ``{"count": n, "entities": [...]}``, where each entity is a dict with
        ``entity_id``, ``state`` and ``friendly_name`` (``""`` if absent).
    """
    selected = states

    if domain:
        prefix = f"{domain}."
        selected = [st for st in selected if st.get("entity_id", "").startswith(prefix)]

    if area:
        needle = area.lower()

        def _in_area(st: Dict[str, Any]) -> bool:
            # friendly_name is checked first; the "area" attribute is only
            # consulted when the name does not match.
            attrs = st.get("attributes", {})
            if needle in (attrs.get("friendly_name", "") or "").lower():
                return True
            return needle in (attrs.get("area", "") or "").lower()

        selected = [st for st in selected if _in_area(st)]

    entities = [
        {
            "entity_id": st["entity_id"],
            "state": st["state"],
            "friendly_name": st.get("attributes", {}).get("friendly_name", ""),
        }
        for st in selected
    ]
    return {"count": len(entities), "entities": entities}
|
|
|
|
|
|
async def _async_list_entities(
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Fetch every entity state from HA, then narrow it by domain and/or area.

    Args:
        domain: Optional domain filter, forwarded to ``_filter_and_summarize``.
        area: Optional area filter, forwarded to ``_filter_and_summarize``.

    Returns:
        The compact summary produced by ``_filter_and_summarize``.

    Raises:
        aiohttp.ClientResponseError: if HA answers with a non-2xx status.
    """
    import aiohttp

    request_timeout = aiohttp.ClientTimeout(total=15)
    async with aiohttp.ClientSession() as http:
        async with http.get(
            f"{_HASS_URL}/api/states",
            headers=_get_headers(),
            timeout=request_timeout,
        ) as response:
            response.raise_for_status()
            raw_states = await response.json()

    return _filter_and_summarize(raw_states, domain, area)
|
|
|
|
|
|
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
    """Fetch the detailed state of a single entity via ``GET /api/states/<id>``.

    Args:
        entity_id: A ``<domain>.<object_id>`` slug such as ``light.kitchen``.

    Returns:
        Dict with ``entity_id``, ``state``, ``attributes`` (``{}`` if absent),
        ``last_changed`` and ``last_updated`` (``None`` if absent).

    Raises:
        ValueError: if *entity_id* is not a plain lowercase slug.
        aiohttp.ClientResponseError: if HA answers with a non-2xx status.
    """
    import re

    # entity_id comes from the model and is interpolated into the URL path.
    # Anything beyond [a-z0-9_] on either side of the dot (e.g. "../config",
    # "light.x?y=1") could redirect this authenticated request elsewhere.
    if not isinstance(entity_id, str) or not re.fullmatch(r"[a-z0-9_]+\.[a-z0-9_]+", entity_id):
        raise ValueError(f"Invalid entity_id: {entity_id!r}")

    import aiohttp

    url = f"{_HASS_URL}/api/states/{entity_id}"
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp:
            resp.raise_for_status()
            data = await resp.json()

    return {
        "entity_id": data["entity_id"],
        "state": data["state"],
        "attributes": data.get("attributes", {}),
        "last_changed": data.get("last_changed"),
        "last_updated": data.get("last_updated"),
    }
|
|
|
|
|
|
def _build_service_payload(
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the JSON body for a HA service call.

    *entity_id*, when truthy, is placed under ``"entity_id"``. The entries of
    *data*, when truthy, are merged in afterwards, so an ``"entity_id"`` key
    inside *data* takes precedence over the *entity_id* argument.
    """
    body: Dict[str, Any] = {"entity_id": entity_id} if entity_id else {}
    if data:
        body.update(data)
    return body
|
|
|
|
|
|
def _parse_service_response(
    domain: str,
    service: str,
    result: Any,
) -> Dict[str, Any]:
    """Convert the body of a service-call response into the tool's result shape.

    When *result* is a list (of state dicts), each element is reduced to its
    ``entity_id`` and ``state`` (``""`` when a key is missing). Any other
    *result* yields an empty ``affected_entities`` list. ``success`` is always
    ``True``; this function does not itself detect failures.
    """
    affected = (
        [
            {"entity_id": item.get("entity_id", ""), "state": item.get("state", "")}
            for item in result
        ]
        if isinstance(result, list)
        else []
    )
    return {
        "success": True,
        "service": f"{domain}.{service}",
        "affected_entities": affected,
    }
|
|
|
|
|
|
async def _async_call_service(
    domain: str,
    service: str,
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Call a Home Assistant service via ``POST /api/services/<domain>/<service>``.

    Args:
        domain: Service domain, e.g. ``"light"``; must match ``[a-z0-9_]+``.
        service: Service name, e.g. ``"turn_on"``; must match ``[a-z0-9_]+``.
        entity_id: Optional target entity, sent in the JSON body.
        data: Optional extra service data merged into the JSON body.

    Returns:
        The structured result produced by ``_parse_service_response``.

    Raises:
        ValueError: if *domain* or *service* is not a plain lowercase slug.
        aiohttp.ClientResponseError: if HA answers with a non-2xx status.
    """
    import re

    # domain and service come from the model and are interpolated into the URL
    # path. Characters like "/", "..", "?" could redirect this authenticated
    # POST to a different HA API endpoint, so only plain slugs are accepted.
    for label, value in (("domain", domain), ("service", service)):
        if not isinstance(value, str) or not re.fullmatch(r"[a-z0-9_]+", value):
            raise ValueError(f"Invalid service {label}: {value!r}")

    import aiohttp

    url = f"{_HASS_URL}/api/services/{domain}/{service}"
    payload = _build_service_payload(entity_id, data)

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url,
            headers=_get_headers(),
            json=payload,
            timeout=aiohttp.ClientTimeout(total=15),
        ) as resp:
            resp.raise_for_status()
            result = await resp.json()

    return _parse_service_response(domain, service, result)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sync wrappers (handler signature: (args, **kw) -> str)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _run_async(coro, timeout: float = 30):
    """Run *coro* to completion from synchronous code and return its result.

    If no event loop is running in the current thread, ``asyncio.run`` is used
    directly. If called from inside a running loop (where ``asyncio.run`` would
    raise), the coroutine runs on a fresh loop in a worker thread, and this
    call blocks for at most *timeout* seconds.

    Args:
        coro: The coroutine object to execute.
        timeout: Maximum seconds to wait in the running-loop path (default 30).

    Returns:
        Whatever *coro* returns.

    Raises:
        concurrent.futures.TimeoutError: the worker did not finish within
            *timeout* (running-loop path only). The coroutine keeps running
            in its worker thread until it completes.
        Any exception raised by *coro*.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        in_running_loop = False
    else:
        in_running_loop = True

    if not in_running_loop:
        # Called outside the try/except so errors from coro are not chained
        # onto the RuntimeError from get_running_loop().
        return asyncio.run(coro)

    import concurrent.futures

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        # Deliberately not a ``with`` block: its exit calls shutdown(wait=True),
        # which would block until the coroutine finished and so defeat the
        # timeout above.
        pool.shutdown(wait=False)
|
|
|
|
|
|
def _handle_list_entities(args: dict, **kw) -> str:
    """Tool handler for ``ha_list_entities``.

    Reads the optional ``domain`` and ``area`` filters from *args* and fetches
    the matching entities. Returns a JSON string: ``{"result": ...}`` on
    success, or ``{"error": "..."}`` if anything raises (the error is logged,
    not propagated).
    """
    filters = {"domain": args.get("domain"), "area": args.get("area")}
    try:
        summary = _run_async(_async_list_entities(**filters))
        return json.dumps({"result": summary})
    except Exception as exc:
        logger.error("ha_list_entities error: %s", exc)
        return json.dumps({"error": f"Failed to list entities: {exc}"})
|
|
|
|
|
|
def _handle_get_state(args: dict, **kw) -> str:
    """Tool handler for ``ha_get_state``.

    Requires a truthy ``entity_id`` in *args*. Returns a JSON string:
    ``{"result": ...}`` with the entity's state on success, or
    ``{"error": "..."}`` when the parameter is missing or the lookup raises
    (the error is logged, not propagated).
    """
    entity_id = args.get("entity_id", "")
    if entity_id:
        try:
            state = _run_async(_async_get_state(entity_id))
            return json.dumps({"result": state})
        except Exception as exc:
            logger.error("ha_get_state error: %s", exc)
            return json.dumps({"error": f"Failed to get state for {entity_id}: {exc}"})
    return json.dumps({"error": "Missing required parameter: entity_id"})
|
|
|
|
|
|
def _handle_call_service(args: dict, **kw) -> str:
    """Tool handler for ``ha_call_service``.

    Requires ``domain`` and ``service`` in *args*; ``entity_id`` and ``data``
    are optional. ``data`` may be given as an object or as a JSON-encoded
    string of an object.

    Returns:
        A JSON string: ``{"result": ...}`` on success, or ``{"error": "..."}``
        for missing/invalid parameters or when the call raises (the error is
        logged, not propagated).
    """
    domain = args.get("domain", "")
    service = args.get("service", "")
    if not domain or not service:
        return json.dumps({"error": "Missing required parameters: domain and service"})

    entity_id = args.get("entity_id")
    data = args.get("data")

    # Models sometimes send ``data`` as a JSON string rather than an object.
    # Decode it here, because passing a str (or any non-mapping) on to
    # dict.update() produces a confusing error.
    if isinstance(data, str):
        if data.strip():
            try:
                data = json.loads(data)
            except json.JSONDecodeError as e:
                return json.dumps({"error": f"Invalid 'data' parameter: not valid JSON ({e})"})
        else:
            data = None
    if data and not isinstance(data, dict):
        return json.dumps({"error": "Invalid 'data' parameter: expected a JSON object"})

    try:
        result = _run_async(_async_call_service(domain, service, entity_id, data))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_call_service error: %s", e)
        return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Availability check
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _check_ha_available() -> bool:
    """Report whether the Home Assistant tools should be offered.

    True only when the ``HASS_TOKEN`` environment variable is set to a
    non-empty value. The environment is consulted on every call.
    """
    token = os.environ.get("HASS_TOKEN")
    return token is not None and token != ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# JSON schema shown to the model for ha_list_entities. The "area" description
# mirrors the matching rule in _filter_and_summarize: a case-insensitive
# substring test against friendly_name and, if present, an "area" attribute.
HA_LIST_ENTITIES_SCHEMA = {
    "name": "ha_list_entities",
    "description": (
        "List Home Assistant entities. Optionally filter by domain "
        "(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
        "or by area name (living room, kitchen, bedroom, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
                    "'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
                    "Omit to list all entities."
                ),
            },
            "area": {
                "type": "string",
                "description": (
                    "Area/room name to filter by (e.g. 'living room', 'kitchen'). "
                    "Case-insensitive substring match against entity friendly "
                    "names (and an 'area' attribute, if the entity has one). "
                    "Omit to list all."
                ),
            },
        },
        "required": [],
    },
}
|
|
|
|
# JSON schema shown to the model for ha_get_state: one required entity_id.
HA_GET_STATE_SCHEMA = {
    "name": "ha_get_state",
    "description": (
        "Get the detailed state of a single Home Assistant entity, "
        "including all attributes (brightness, color, "
        "temperature setpoint, sensor readings, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "entity_id": {
                "type": "string",
                "description": (
                    "The entity ID to query "
                    "(e.g. 'light.living_room', 'climate.thermostat', "
                    "'sensor.temperature')."
                ),
            },
        },
        "required": ["entity_id"],
    },
}
|
|
|
|
# JSON schema shown to the model for ha_call_service. domain and service are
# required; entity_id and the free-form data object are optional.
HA_CALL_SERVICE_SCHEMA = {
    "name": "ha_call_service",
    "description": (
        "Call a Home Assistant service to control a device. "
        "Common examples: turn_on/turn_off lights and switches, "
        "set_temperature for climate, open_cover/close_cover for blinds, "
        "set_volume_level for media players."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Service domain (e.g. 'light', 'switch', 'climate', 'cover', "
                    "'media_player', 'fan', 'scene', 'script')."
                ),
            },
            "service": {
                "type": "string",
                "description": (
                    "Service name (e.g. 'turn_on', 'turn_off', 'toggle', 'set_temperature', "
                    "'set_hvac_mode', 'open_cover', 'close_cover', 'set_volume_level')."
                ),
            },
            "entity_id": {
                "type": "string",
                "description": (
                    "Target entity ID (e.g. 'light.living_room'). Some services "
                    "(like scene.turn_on) may not need this."
                ),
            },
            "data": {
                "type": "object",
                "description": (
                    'Additional service data. Examples: {"brightness": 255, '
                    '"color_name": "blue"} for lights, {"temperature": 22, '
                    '"hvac_mode": "heat"} for climate, {"volume_level": 0.5} '
                    "for media players."
                ),
            },
        },
        "required": ["domain", "service"],
    },
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Registration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
from tools.registry import registry
|
|
|
|
# Module-level registration: these calls run when this module is imported.
# All three tools share the "homeassistant" toolset and the same check_fn,
# _check_ha_available (true only when HASS_TOKEN is set). Presumably the
# registry evaluates check_fn to decide whether to expose each tool -- confirm
# in tools/registry.py.

# Read-only: list entity states, optionally filtered by domain/area (GET).
registry.register(
    name="ha_list_entities",
    toolset="homeassistant",
    schema=HA_LIST_ENTITIES_SCHEMA,
    handler=_handle_list_entities,
    check_fn=_check_ha_available,
)

# Read-only: fetch one entity's state and attributes (GET).
registry.register(
    name="ha_get_state",
    toolset="homeassistant",
    schema=HA_GET_STATE_SCHEMA,
    handler=_handle_get_state,
    check_fn=_check_ha_available,
)

# Invokes a HA service via POST; this can change device state.
registry.register(
    name="ha_call_service",
    toolset="homeassistant",
    schema=HA_CALL_SERVICE_SCHEMA,
    handler=_handle_call_service,
    check_fn=_check_ha_available,
)
|