From 00c65280c44d31f47a5597a8e89e706102f10e32 Mon Sep 17 00:00:00 2001 From: Teknium Date: Wed, 15 Apr 2026 22:44:40 -0700 Subject: [PATCH] feat(xai): add video generation, image editing, and X search tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-picked from PR #10600 by Jaaneek — the media/search tool additions, separated from the core provider upgrade (PR #10783). NOTE: Depends on PR #10783 being merged first (for xai_http.py, codex_responses transport, and XAI_API_KEY env var). - Add video generation tool (generate, edit, extend) with async polling - Add xAI image generation/editing backend alongside FAL - Add X search tool backed by xAI Responses API - Add x_search and video_gen toolset definitions - Add CONFIGURABLE_TOOLSETS entries for tools_config UI - Wire into safe and api-server toolsets - Add test coverage for all new tools Co-authored-by: Jaaneek --- hermes_cli/tools_config.py | 2 + tests/tools/test_x_search_tool.py | 207 ++++++++++ tests/tools/test_xai_media_tools.py | 611 ++++++++++++++++++++++++++++ tools/image_generation_tool.py | 560 ++++++++++++++++++++----- tools/video_generation_tool.py | 459 +++++++++++++++++++++ tools/x_search_tool.py | 351 ++++++++++++++++ tools/xai_http.py | 12 + toolsets.py | 22 +- 8 files changed, 2125 insertions(+), 99 deletions(-) create mode 100644 tests/tools/test_x_search_tool.py create mode 100644 tests/tools/test_xai_media_tools.py create mode 100644 tools/video_generation_tool.py create mode 100644 tools/x_search_tool.py create mode 100644 tools/xai_http.py diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 5fe8cdc79e..cbda0588db 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -48,12 +48,14 @@ from hermes_cli.cli_output import ( # noqa: E402 — late import block # These map to keys in toolsets.py TOOLSETS dict. 
CONFIGURABLE_TOOLSETS = [ ("web", "🔍 Web Search & Scraping", "web_search, web_extract"), + ("x_search", "🐦 X Search", "x_search"), ("browser", "🌐 Browser Automation", "navigate, click, type, scroll"), ("terminal", "💻 Terminal & Processes", "terminal, process"), ("file", "📁 File Operations", "read, write, patch, search"), ("code_execution", "⚡ Code Execution", "execute_code"), ("vision", "👁️ Vision / Image Analysis", "vision_analyze"), ("image_gen", "🎨 Image Generation", "image_generate"), + ("video_gen", "🎬 Video Generation", "video_generate"), ("moa", "🧠 Mixture of Agents", "mixture_of_agents"), ("tts", "🔊 Text-to-Speech", "text_to_speech"), ("skills", "📚 Skills", "list, view, manage"), diff --git a/tests/tools/test_x_search_tool.py b/tests/tools/test_x_search_tool.py new file mode 100644 index 0000000000..cc238134c6 --- /dev/null +++ b/tests/tools/test_x_search_tool.py @@ -0,0 +1,207 @@ +import json +import requests + + +class _FakeResponse: + def __init__(self, payload, *, status_code=200, text=None): + self._payload = payload + self.status_code = status_code + self.text = text if text is not None else json.dumps(payload) + + def raise_for_status(self): + if self.status_code >= 400: + err = requests.HTTPError(f"{self.status_code} Client Error") + err.response = self + raise err + + def json(self): + return self._payload + + +def test_x_search_posts_responses_request(monkeypatch): + from tools.x_search_tool import x_search_tool + from hermes_cli import __version__ + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["url"] = url + captured["headers"] = headers + captured["json"] = json + captured["timeout"] = timeout + return _FakeResponse( + { + "output_text": "People on X are discussing xAI's latest launch.", + "citations": [{"url": "https://x.com/example/status/1", "title": "Example post"}], + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("requests.post", _fake_post) + + result = 
json.loads( + x_search_tool( + query="What are people saying about xAI on X?", + allowed_x_handles=["xai", "@grok"], + from_date="2026-04-01", + to_date="2026-04-10", + enable_image_understanding=True, + ) + ) + + tool_def = captured["json"]["tools"][0] + assert captured["url"] == "https://api.x.ai/v1/responses" + assert captured["headers"]["User-Agent"] == f"Hermes-Agent/{__version__}" + assert captured["json"]["model"] == "grok-4.20-reasoning" + assert captured["json"]["store"] is False + assert tool_def["type"] == "x_search" + assert tool_def["allowed_x_handles"] == ["xai", "grok"] + assert tool_def["from_date"] == "2026-04-01" + assert tool_def["to_date"] == "2026-04-10" + assert tool_def["enable_image_understanding"] is True + assert result["success"] is True + assert result["answer"] == "People on X are discussing xAI's latest launch." + + +def test_x_search_rejects_conflicting_handle_filters(monkeypatch): + from tools.x_search_tool import x_search_tool + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + + result = json.loads( + x_search_tool( + query="latest xAI discussion", + allowed_x_handles=["xai"], + excluded_x_handles=["grok"], + ) + ) + + assert result["error"] == "allowed_x_handles and excluded_x_handles cannot be used together" + + +def test_x_search_extracts_inline_url_citations(monkeypatch): + from tools.x_search_tool import x_search_tool + + def _fake_post(url, headers=None, json=None, timeout=None): + return _FakeResponse( + { + "output": [ + { + "type": "message", + "content": [ + { + "type": "output_text", + "text": "xAI posted an update on X.", + "annotations": [ + { + "type": "url_citation", + "url": "https://x.com/xai/status/123", + "title": "xAI update", + "start_index": 0, + "end_index": 3, + } + ], + } + ], + } + ] + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("requests.post", _fake_post) + + result = json.loads(x_search_tool(query="latest post from xai")) + + assert result["success"] is True + 
assert result["answer"] == "xAI posted an update on X." + assert result["inline_citations"] == [ + { + "url": "https://x.com/xai/status/123", + "title": "xAI update", + "start_index": 0, + "end_index": 3, + } + ] + + +def test_x_search_returns_structured_http_error(monkeypatch): + from tools.x_search_tool import x_search_tool + + class _FailingResponse: + status_code = 403 + text = '{"code":"forbidden","error":"x_search is not enabled for this model"}' + + def json(self): + return { + "code": "forbidden", + "error": "x_search is not enabled for this model", + } + + def raise_for_status(self): + err = requests.HTTPError("403 Client Error: Forbidden") + err.response = self + raise err + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("requests.post", lambda *a, **k: _FailingResponse()) + + result = json.loads(x_search_tool(query="latest xai discussion")) + + assert result["success"] is False + assert result["provider"] == "xai" + assert result["tool"] == "x_search" + assert result["error_type"] == "HTTPError" + assert result["error"] == "forbidden: x_search is not enabled for this model" + + +def test_x_search_retries_read_timeout_then_succeeds(monkeypatch): + from tools.x_search_tool import x_search_tool + + calls = {"count": 0} + + def _fake_post(url, headers=None, json=None, timeout=None): + calls["count"] += 1 + if calls["count"] == 1: + raise requests.ReadTimeout("timed out") + return _FakeResponse( + { + "output_text": "Recovered after retry.", + "citations": [], + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("requests.post", _fake_post) + monkeypatch.setattr("tools.x_search_tool.time.sleep", lambda *_: None) + + result = json.loads(x_search_tool(query="grok xai")) + + assert calls["count"] == 2 + assert result["success"] is True + assert result["answer"] == "Recovered after retry." 
+ + +def test_x_search_retries_5xx_then_succeeds(monkeypatch): + from tools.x_search_tool import x_search_tool + + calls = {"count": 0} + + def _fake_post(url, headers=None, json=None, timeout=None): + calls["count"] += 1 + if calls["count"] == 1: + return _FakeResponse( + {"code": "Internal error", "error": "Service temporarily unavailable."}, + status_code=500, + ) + return _FakeResponse({"output_text": "Recovered after 5xx retry."}) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("requests.post", _fake_post) + monkeypatch.setattr("tools.x_search_tool.time.sleep", lambda *_: None) + + result = json.loads(x_search_tool(query="grok xai")) + + assert calls["count"] == 2 + assert result["success"] is True + assert result["answer"] == "Recovered after 5xx retry." diff --git a/tests/tools/test_xai_media_tools.py b/tests/tools/test_xai_media_tools.py new file mode 100644 index 0000000000..5e324bd66d --- /dev/null +++ b/tests/tools/test_xai_media_tools.py @@ -0,0 +1,611 @@ +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + + +def test_video_generate_schema_guides_prompt_without_requiring_it(): + from tools.video_generation_tool import VIDEO_GENERATE_SCHEMA + + parameters = VIDEO_GENERATE_SCHEMA["parameters"] + properties = parameters["properties"] + + assert "prompt" not in parameters.get("required", []) + assert "Usually pass this" in properties["prompt"]["description"] + assert "Optional only for image-to-video" in properties["prompt"]["description"] + assert "output" not in properties + assert "output_upload_url" not in properties + + +class _FakeResponse: + def __init__(self, *, json_payload=None, content=b""): + self._json_payload = json_payload or {} + self.content = content + + def raise_for_status(self): + return None + + def json(self): + return self._json_payload + + +def _fake_httpx_client(*, post_fn, get_fn=None): + """Build a mock httpx.AsyncClient that delegates to sync test helpers.""" + client = 
AsyncMock() + + async def _post(url, *, headers=None, json=None, timeout=None): + return post_fn(url, headers=headers, json=json, timeout=timeout) + + async def _get(url, *, headers=None, timeout=None): + return get_fn(url, headers=headers, timeout=timeout) + + client.post = _post + if get_fn is not None: + client.get = _get + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=False) + return client + + +def test_image_generate_tool_supports_xai_provider(monkeypatch): + from tools.image_generation_tool import image_generate_tool + from hermes_cli import __version__ + + def _fake_post(url, headers=None, json=None, timeout=None): + assert url == "https://api.x.ai/v1/images/generations" + assert headers["User-Agent"] == f"Hermes-Agent/{__version__}" + assert json["model"] == "grok-imagine-image" + assert json["aspect_ratio"] == "16:9" + return _FakeResponse( + json_payload={ + "data": [ + { + "url": "https://cdn.example.com/generated.png", + "width": 1280, + "height": 720, + } + ] + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("tools.image_generation_tool.requests.post", _fake_post) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: False) + + result = json.loads( + image_generate_tool( + prompt="a cinematic skyline at sunset", + provider="xai", + aspect_ratio="landscape", + ) + ) + + assert result["success"] is True + assert result["provider"] == "xai" + assert result["image"] == "https://cdn.example.com/generated.png" + + +def test_image_generate_tool_supports_xai_reference_images_for_generate(monkeypatch): + from tools.image_generation_tool import image_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse( + json_payload={ + "data": [ + { + "url": "https://cdn.example.com/reference-guided.png", + "width": 1280, + "height": 720, + } + ] + } + ) + 
+ monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("tools.image_generation_tool.requests.post", _fake_post) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: True) + + result = json.loads( + image_generate_tool( + prompt="A campaign portrait in xAI style.", + provider="auto", + aspect_ratio="16:9", + reference_image_urls=[ + "https://cdn.example.com/reference-a.png", + "https://cdn.example.com/reference-b.png", + ], + ) + ) + + assert captured["url"] == "https://api.x.ai/v1/images/generations" + assert captured["json"]["reference_images"] == [ + {"type": "image_url", "url": "https://cdn.example.com/reference-a.png"}, + {"type": "image_url", "url": "https://cdn.example.com/reference-b.png"}, + ] + assert result["success"] is True + assert result["provider"] == "xai" + assert result["image"] == "https://cdn.example.com/reference-guided.png" + + +def test_image_generate_tool_supports_xai_edit_with_multiple_source_images(monkeypatch): + from tools.image_generation_tool import image_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse( + json_payload={ + "data": [ + { + "url": "https://cdn.example.com/edited.png", + "width": 1536, + "height": 1024, + } + ] + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("tools.image_generation_tool.requests.post", _fake_post) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: True) + + result = json.loads( + image_generate_tool( + prompt="Put the two people together in one cinematic rooftop portrait.", + operation="edit", + provider="auto", + aspect_ratio="3:2", + resolution="2k", + source_image_urls=[ + "https://cdn.example.com/person-a.png", + "https://cdn.example.com/person-b.png", + ], + ) + ) + + assert captured["url"] == "https://api.x.ai/v1/images/edits" + assert captured["json"]["images"][0]["url"] == 
"https://cdn.example.com/person-a.png" + assert captured["json"]["images"][1]["url"] == "https://cdn.example.com/person-b.png" + assert captured["json"]["aspect_ratio"] == "3:2" + assert captured["json"]["resolution"] == "2k" + assert result["success"] is True + assert result["provider"] == "xai" + assert result["operation"] == "edit" + assert result["image"] == "https://cdn.example.com/edited.png" + + +def test_image_generate_tool_uses_configured_xai_provider_by_default(monkeypatch): + from tools.image_generation_tool import image_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse( + json_payload={ + "data": [ + { + "url": "https://cdn.example.com/configured-xai.png", + "width": 1024, + "height": 1024, + } + ] + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + monkeypatch.setattr("tools.image_generation_tool.requests.post", _fake_post) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: True) + monkeypatch.setattr( + "hermes_cli.config.load_config", + lambda: {"image_generation": {"provider": "xai"}}, + ) + + result = json.loads( + image_generate_tool( + prompt="an xAI-first image backend test", + aspect_ratio="square", + ) + ) + + assert captured["url"] == "https://api.x.ai/v1/images/generations" + assert result["success"] is True + assert result["provider"] == "xai" + + +def test_image_generate_tool_prefers_xai_only_features_over_saved_fal_default(monkeypatch): + from tools.image_generation_tool import image_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse( + json_payload={ + "data": [ + { + "url": "https://cdn.example.com/edited-with-xai.png", + "width": 1024, + "height": 1024, + } + ] + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + 
monkeypatch.setattr("tools.image_generation_tool.requests.post", _fake_post) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: True) + monkeypatch.setattr( + "hermes_cli.config.load_config", + lambda: {"image_generation": {"provider": "fal"}}, + ) + + result = json.loads( + image_generate_tool( + prompt="edit this image", + provider="auto", + operation="edit", + source_image_url="https://cdn.example.com/source.png", + ) + ) + + assert captured["url"] == "https://api.x.ai/v1/images/edits" + assert result["success"] is True + assert result["provider"] == "xai" + + +def test_image_generate_tool_errors_clearly_when_xai_only_features_need_xai(monkeypatch): + from tools.image_generation_tool import image_generate_tool + + monkeypatch.delenv("XAI_API_KEY", raising=False) + monkeypatch.setattr("tools.image_generation_tool._has_fal_backend", lambda: True) + monkeypatch.setattr("hermes_cli.config.load_config", lambda: {}) + + result = json.loads( + image_generate_tool( + prompt="edit this image", + provider="auto", + operation="edit", + source_image_url="https://cdn.example.com/source.png", + ) + ) + + assert result["success"] is False + assert "requires xAI image support" in result["error"] + + +def test_video_generate_tool_polls_until_done(monkeypatch): + from tools.video_generation_tool import video_generate_tool + from hermes_cli import __version__ + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_url"] = url + captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-123"}) + + def _fake_get(url, headers=None, timeout=None): + captured.setdefault("poll_urls", []).append(url) + return _FakeResponse( + json_payload={ + "status": "done", + "video": {"url": "https://cdn.example.com/generated.mp4"}, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + 
monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + prompt="slow drone shot over a neon city", + duration=8, + aspect_ratio="16:9", + resolution="720p", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert captured["submit_url"] == "https://api.x.ai/v1/videos/generations" + assert captured["submit_json"]["prompt"] == "slow drone shot over a neon city" + assert captured["poll_urls"] == ["https://api.x.ai/v1/videos/vid-123"] + assert result["success"] is True + assert result["provider"] == "xai" + assert result["video"] == "https://cdn.example.com/generated.mp4" + + +def test_video_generate_tool_sends_hermes_user_agent(monkeypatch): + from tools.video_generation_tool import video_generate_tool + from hermes_cli import __version__ + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_headers"] = headers + return _FakeResponse(json_payload={"request_id": "vid-ua"}) + + def _fake_get(url, headers=None, timeout=None): + captured["poll_headers"] = headers + return _FakeResponse( + json_payload={ + "status": "done", + "video": {"url": "https://cdn.example.com/generated.mp4"}, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + asyncio.run( + video_generate_tool( + prompt="slow drone shot over a neon city", + duration=8, + aspect_ratio="16:9", + resolution="720p", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + + assert captured["submit_headers"]["User-Agent"] == f"Hermes-Agent/{__version__}" + assert captured["poll_headers"]["User-Agent"] == f"Hermes-Agent/{__version__}" + + +def test_video_generate_tool_supports_native_extend(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + captured = {} + 
+ def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_url"] = url + captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-456"}) + + def _fake_get(url, headers=None, timeout=None): + return _FakeResponse( + json_payload={ + "status": "done", + "video": {"url": "https://cdn.example.com/extended.mp4"}, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + prompt="Continue the shot as the camera drifts behind the subject.", + operation="extend", + duration=6, + video_url="https://cdn.example.com/source.mp4", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert captured["submit_url"] == "https://api.x.ai/v1/videos/extensions" + assert captured["submit_json"]["video"]["url"] == "https://cdn.example.com/source.mp4" + assert captured["submit_json"]["duration"] == 6 + assert result["success"] is True + assert result["operation"] == "extend" + assert result["video"] == "https://cdn.example.com/extended.mp4" + + +def test_video_generate_tool_recovers_promptless_extend_from_source_video_url(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + captured = {} + source_url = "https://cdn.example.com/source.mp4" + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_url"] = url + captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-extend-auto"}) + + def _fake_get(url, headers=None, timeout=None): + return _FakeResponse( + json_payload={ + "status": "done", + "video": {"url": "https://cdn.example.com/extended-auto.mp4"}, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + 
monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + duration=8, + video_url=source_url, + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert captured["submit_url"] == "https://api.x.ai/v1/videos/extensions" + assert captured["submit_json"]["video"]["url"] == source_url + assert captured["submit_json"]["prompt"] == "Continue the existing video naturally." + assert result["success"] is True + assert result["operation"] == "extend" + assert any("default continuation prompt" in note for note in result["notes"]) + + +def test_video_generate_tool_edit_without_prompt_still_errors(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + + result = json.loads( + asyncio.run( + video_generate_tool( + operation="edit", + video_url="https://cdn.example.com/source.mp4", + ) + ) + ) + + assert result["error"] == "prompt is required for xAI video edit" + + +def test_video_generate_tool_uses_video_object_for_edit(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_url"] = url + captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-edit"}) + + def _fake_get(url, headers=None, timeout=None): + return _FakeResponse( + json_payload={ + "status": "done", + "model": "grok-imagine-video", + "video": { + "url": "https://cdn.example.com/edited.mp4", + "duration": 8, + "respect_moderation": True, + }, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + prompt="Give the subject a silver necklace.", + 
operation="edit", + video_url="https://cdn.example.com/source.mp4", + user="jaaneek", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert captured["submit_url"] == "https://api.x.ai/v1/videos/edits" + assert captured["submit_json"]["video"]["url"] == "https://cdn.example.com/source.mp4" + assert captured["submit_json"]["user"] == "jaaneek" + assert "output" not in captured["submit_json"] + assert result["success"] is True + assert result["operation"] == "edit" + assert result["video"] == "https://cdn.example.com/edited.mp4" + assert result["respect_moderation"] is True + + +def test_video_generate_tool_ignores_duration_for_edit(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-edit-duration"}) + + def _fake_get(url, headers=None, timeout=None): + return _FakeResponse( + json_payload={ + "status": "done", + "video": { + "url": "https://cdn.example.com/edited.mp4", + "duration": 8, + }, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + prompt="Give the subject a silver necklace.", + operation="edit", + duration=20, + video_url="https://cdn.example.com/source.mp4", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert result["success"] is True + assert "duration" not in captured["submit_json"] + assert result["duration"] == 8 + + +def test_video_generate_tool_supports_promptless_image_to_video(monkeypatch): + from tools.video_generation_tool import video_generate_tool + + captured = {} + + def _fake_post(url, headers=None, json=None, timeout=None): + captured["submit_url"] = url + 
captured["submit_json"] = json + return _FakeResponse(json_payload={"request_id": "vid-i2v"}) + + def _fake_get(url, headers=None, timeout=None): + return _FakeResponse( + json_payload={ + "status": "done", + "video": { + "url": "https://cdn.example.com/i2v.mp4", + "duration": 8, + "respect_moderation": True, + }, + } + ) + + monkeypatch.setenv("XAI_API_KEY", "xai-test-key") + mock_client = _fake_httpx_client(post_fn=_fake_post, get_fn=_fake_get) + monkeypatch.setattr("tools.video_generation_tool.httpx.AsyncClient", lambda: mock_client) + + result = json.loads( + asyncio.run( + video_generate_tool( + prompt="", + operation="generate", + image_url="https://cdn.example.com/still.png", + seconds=8, + aspect_ratio="4:3", + resolution="480p", + size="848x480", + poll_interval_seconds=0, + timeout_seconds=30, + ) + ) + ) + + assert captured["submit_url"] == "https://api.x.ai/v1/videos/generations" + assert "prompt" not in captured["submit_json"] + assert captured["submit_json"]["image"]["url"] == "https://cdn.example.com/still.png" + assert captured["submit_json"]["duration"] == 8 + assert captured["submit_json"]["aspect_ratio"] == "4:3" + assert captured["submit_json"]["resolution"] == "480p" + assert captured["submit_json"]["size"] == "848x480" + assert result["success"] is True + assert result["video"] == "https://cdn.example.com/i2v.mp4" diff --git a/tools/image_generation_tool.py b/tools/image_generation_tool.py index 487b9b8db8..2ad0fff193 100644 --- a/tools/image_generation_tool.py +++ b/tools/image_generation_tool.py @@ -2,8 +2,9 @@ """ Image Generation Tools Module -This module provides image generation tools using FAL.ai's FLUX 2 Pro model with -automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality. 
+This module provides image generation tools using either: +- FAL.ai FLUX 2 Pro with automatic Clarity upscaling +- xAI grok-imagine-image Available tools: - image_generate_tool: Generate images from text prompts with automatic upscaling @@ -34,17 +35,22 @@ import os import datetime import threading import uuid +import requests from typing import Dict, Any, Optional, Union from urllib.parse import urlencode -import fal_client from tools.debug_helpers import DebugSession from tools.managed_tool_gateway import resolve_managed_tool_gateway from tools.tool_backend_helpers import managed_nous_tools_enabled +from tools.xai_http import hermes_xai_user_agent logger = logging.getLogger(__name__) # Configuration for image generation +DEFAULT_PROVIDER = "auto" +DEFAULT_OPERATION = "generate" DEFAULT_MODEL = "fal-ai/flux-2-pro" +DEFAULT_XAI_MODEL = "grok-imagine-image" +DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1" DEFAULT_ASPECT_RATIO = "landscape" DEFAULT_NUM_INFERENCE_STEPS = 50 DEFAULT_GUIDANCE_SCALE = 4.5 @@ -79,6 +85,30 @@ VALID_IMAGE_SIZES = [ ] VALID_OUTPUT_FORMATS = ["jpeg", "png"] VALID_ACCELERATION_MODES = ["none", "regular", "high"] +XAI_ASPECT_RATIO_MAP = { + "landscape": "16:9", + "square": "1:1", + "portrait": "9:16", +} +VALID_XAI_ASPECT_RATIOS = { + "auto", + "1:1", + "16:9", + "9:16", + "4:3", + "3:4", + "3:2", + "2:3", + "2:1", + "1:2", + "19.5:9", + "9:19.5", + "20:9", + "9:20", +} +VALID_XAI_RESOLUTIONS = {"1k", "2k"} +VALID_XAI_RESPONSE_FORMATS = {"url", "b64_json"} +VALID_XAI_OPERATIONS = {"generate", "edit"} _debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG") _managed_fal_client = None @@ -86,6 +116,13 @@ _managed_fal_client_config = None _managed_fal_client_lock = threading.Lock() +def _import_fal_client(): + """Lazy import fal_client so xAI-only users can still use image generation.""" + import fal_client + + return fal_client + + def _resolve_managed_fal_gateway(): """Return managed fal-queue gateway config when direct FAL credentials 
are absent.""" if os.getenv("FAL_KEY"): @@ -104,6 +141,7 @@ class _ManagedFalSyncClient: """Small per-instance wrapper around fal_client.SyncClient for managed queue hosts.""" def __init__(self, *, key: str, queue_run_origin: str): + fal_client = _import_fal_client() sync_client_class = getattr(fal_client, "SyncClient", None) if sync_client_class is None: raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode") @@ -204,6 +242,7 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]): request_headers = {"x-idempotency-key": str(uuid.uuid4())} managed_gateway = _resolve_managed_fal_gateway() if managed_gateway is None: + fal_client = _import_fal_client() return fal_client.submit(model, arguments=arguments, headers=request_headers) managed_client = _get_managed_fal_client(managed_gateway) @@ -214,6 +253,246 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]): ) +def _has_fal_backend() -> bool: + """Return True when FAL image generation can run with direct or managed auth.""" + if not (os.getenv("FAL_KEY") or _resolve_managed_fal_gateway()): + return False + try: + _import_fal_client() + return True + except ImportError: + return False + + +def _has_xai_image_backend() -> bool: + return bool(os.getenv("XAI_API_KEY", "").strip()) + + +def _normalize_provider(provider: Optional[str]) -> str: + normalized = (provider or DEFAULT_PROVIDER).lower().strip() + aliases = { + "grok": "xai", + "x-ai": "xai", + "x.ai": "xai", + } + normalized = aliases.get(normalized, normalized) + if normalized not in {"auto", "fal", "xai"}: + raise ValueError("provider must be one of: auto, fal, xai") + return normalized + + +def _resolve_image_provider( + provider: Optional[str], + *, + prefer_xai: bool = False, +) -> str: + requested = _normalize_provider(provider) + if requested == "auto" and not prefer_xai: + try: + from hermes_cli.config import load_config + + configured_provider = _normalize_provider( + 
(load_config().get("image_generation", {}) or {}).get("provider") + ) + if configured_provider != "auto": + requested = configured_provider + except Exception: + pass + if requested != "auto": + return requested + if prefer_xai and _has_xai_image_backend(): + return "xai" + if prefer_xai: + raise ValueError( + "This image request requires xAI image support. Configure XAI_API_KEY or call image_generate with provider='fal' only for basic generation." + ) + if _has_fal_backend(): + return "fal" + if _has_xai_image_backend(): + return "xai" + return "fal" + + +def _data_uri_from_b64(encoded: str, output_format: str) -> str: + mime = "image/png" if output_format == "png" else "image/jpeg" + return f"data:{mime};base64,{encoded}" + + +def _normalize_xai_aspect_ratio(aspect_ratio: Optional[str]) -> str: + normalized = (aspect_ratio or DEFAULT_ASPECT_RATIO).strip().lower() + return XAI_ASPECT_RATIO_MAP.get(normalized, normalized) + + +def _normalize_xai_operation( + operation: Optional[str], + source_image_url: Optional[str], + source_image_urls: Optional[list[str]], +) -> str: + normalized = (operation or "").strip().lower() + if not normalized: + return "edit" if ((source_image_url or "").strip() or (source_image_urls or [])) else DEFAULT_OPERATION + aliases = { + "generate_image": "generate", + "edit_image": "edit", + } + normalized = aliases.get(normalized, normalized) + if normalized not in VALID_XAI_OPERATIONS: + raise ValueError(f"operation must be one of {sorted(VALID_XAI_OPERATIONS)}") + return normalized + + +def _normalize_xai_source_images( + source_image_url: Optional[str], + source_image_urls: Optional[list[str]], +) -> list[dict[str, str]]: + merged: list[str] = [] + if source_image_url and source_image_url.strip(): + merged.append(source_image_url.strip()) + for value in source_image_urls or []: + normalized = (value or "").strip() + if normalized: + merged.append(normalized) + + deduped: list[str] = [] + seen = set() + for value in merged: + if value not 
in seen: + seen.add(value) + deduped.append(value) + return [{"type": "image_url", "url": value} for value in deduped] + + +def _normalize_xai_reference_images( + reference_image_urls: Optional[list[str]], +) -> list[dict[str, str]]: + deduped: list[str] = [] + seen = set() + for value in reference_image_urls or []: + normalized = (value or "").strip() + if normalized and normalized not in seen: + seen.add(normalized) + deduped.append(normalized) + return [{"type": "image_url", "url": value} for value in deduped] + + +def _generate_image_with_xai( + prompt: str, + operation: str, + aspect_ratio: Optional[str], + num_images: int, + output_format: str, + resolution: Optional[str] = None, + response_format: str = "url", + source_image_url: Optional[str] = None, + source_image_urls: Optional[list[str]] = None, + reference_image_urls: Optional[list[str]] = None, +) -> list[Dict[str, Any]]: + api_key = os.getenv("XAI_API_KEY", "").strip() + if not api_key: + raise ValueError("XAI_API_KEY environment variable not set") + + base_url = (os.getenv("XAI_BASE_URL") or DEFAULT_XAI_BASE_URL).strip().rstrip("/") + normalized_operation = _normalize_xai_operation( + operation, + source_image_url, + source_image_urls, + ) + normalized_aspect_ratio = _normalize_xai_aspect_ratio(aspect_ratio) + normalized_response_format = (response_format or "url").strip().lower() + if normalized_response_format not in VALID_XAI_RESPONSE_FORMATS: + raise ValueError( + f"response_format must be one of {sorted(VALID_XAI_RESPONSE_FORMATS)}" + ) + + normalized_resolution = None + if resolution: + normalized_resolution = (resolution or "").strip().lower() + if normalized_resolution not in VALID_XAI_RESOLUTIONS: + raise ValueError( + f"resolution must be one of {sorted(VALID_XAI_RESOLUTIONS)}" + ) + + payload: Dict[str, Any] = { + "model": DEFAULT_XAI_MODEL, + "prompt": prompt.strip(), + "n": num_images, + } + source_images = _normalize_xai_source_images( + source_image_url, + source_image_urls, + ) + 
reference_images = _normalize_xai_reference_images(reference_image_urls) + + if normalized_operation == "generate": + if source_images: + raise ValueError("source images are only supported for xAI image edit") + if len(reference_images) > 5: + raise ValueError("xAI image generation supports at most 5 reference images") + if normalized_aspect_ratio not in VALID_XAI_ASPECT_RATIOS: + raise ValueError( + f"aspect_ratio must be one of {sorted(VALID_XAI_ASPECT_RATIOS)} or landscape/square/portrait" + ) + payload["aspect_ratio"] = normalized_aspect_ratio + if reference_images: + payload["reference_images"] = reference_images + endpoint = "images/generations" + else: + if not source_images: + raise ValueError("source_image_url or source_image_urls is required for xAI image edit") + if len(source_images) + len(reference_images) > 5: + raise ValueError("xAI image edit supports at most 5 combined source and reference images") + if len(source_images) == 1: + payload["image"] = source_images[0] + else: + if normalized_aspect_ratio not in VALID_XAI_ASPECT_RATIOS: + raise ValueError( + f"aspect_ratio must be one of {sorted(VALID_XAI_ASPECT_RATIOS)} or landscape/square/portrait" + ) + payload["images"] = source_images + payload["aspect_ratio"] = normalized_aspect_ratio + if reference_images: + payload["reference_images"] = reference_images + endpoint = "images/edits" + + if normalized_resolution: + payload["resolution"] = normalized_resolution + if normalized_response_format == "b64_json": + payload["response_format"] = "b64_json" + + response = requests.post( + f"{base_url}/{endpoint}", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": hermes_xai_user_agent(), + "x-idempotency-key": str(uuid.uuid4()), + }, + json=payload, + timeout=120, + ) + response.raise_for_status() + + result = response.json() + images = [] + for item in result.get("data", []): + image_url = item.get("url") + if not image_url and item.get("b64_json"): + 
image_url = _data_uri_from_b64(item["b64_json"], output_format) + if not image_url: + continue + images.append( + { + "url": image_url, + "width": item.get("width", 0), + "height": item.get("height", 0), + "upscaled": False, + "provider": "xai", + "operation": normalized_operation, + } + ) + return images + + def _validate_parameters( image_size: Union[str, Dict[str, int]], num_inference_steps: int, @@ -351,11 +630,18 @@ def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]: def image_generate_tool( prompt: str, aspect_ratio: str = DEFAULT_ASPECT_RATIO, + operation: str = DEFAULT_OPERATION, num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS, guidance_scale: float = DEFAULT_GUIDANCE_SCALE, num_images: int = DEFAULT_NUM_IMAGES, output_format: str = DEFAULT_OUTPUT_FORMAT, - seed: Optional[int] = None + seed: Optional[int] = None, + provider: str = DEFAULT_PROVIDER, + resolution: Optional[str] = None, + response_format: str = "url", + source_image_url: Optional[str] = None, + source_image_urls: Optional[list[str]] = None, + reference_image_urls: Optional[list[str]] = None, ) -> str: """ Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling. 
@@ -397,7 +683,11 @@ def image_generate_tool( "guidance_scale": guidance_scale, "num_images": num_images, "output_format": output_format, - "seed": seed + "seed": seed, + "provider": provider, + "operation": operation, + "resolution": resolution, + "response_format": response_format, }, "error": None, "success": False, @@ -408,98 +698,133 @@ def image_generate_tool( start_time = datetime.datetime.now() try: - logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80]) + normalized_operation = _normalize_xai_operation( + operation, + source_image_url, + source_image_urls, + ) + prefer_xai = ( + normalized_operation == "edit" + or bool((source_image_url or "").strip()) + or bool(source_image_urls) + or bool(reference_image_urls) + or bool(resolution) + or (response_format or "").strip().lower() == "b64_json" + or _normalize_xai_aspect_ratio(aspect_ratio) not in {"16:9", "1:1", "9:16"} + ) + resolved_provider = _resolve_image_provider(provider, prefer_xai=prefer_xai) + debug_call_data["parameters"]["resolved_provider"] = resolved_provider + + logger.info( + "Generating %s image(s) with %s image backend: %s", + num_images, + resolved_provider, + prompt[:80], + ) # Validate prompt if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0: raise ValueError("Prompt is required and must be a non-empty string") - # Check API key availability - if not (os.getenv("FAL_KEY") or _resolve_managed_fal_gateway()): - message = "FAL_KEY environment variable not set" - if managed_nous_tools_enabled(): - message += " and managed FAL gateway is unavailable" - raise ValueError(message) - # Validate other parameters validated_params = _validate_parameters( image_size, num_inference_steps, guidance_scale, num_images, output_format, "none" ) - - # Prepare arguments for FAL.ai FLUX 2 Pro API - arguments = { - "prompt": prompt.strip(), - "image_size": validated_params["image_size"], - "num_inference_steps": validated_params["num_inference_steps"], - 
"guidance_scale": validated_params["guidance_scale"], - "num_images": validated_params["num_images"], - "output_format": validated_params["output_format"], - "enable_safety_checker": ENABLE_SAFETY_CHECKER, - "safety_tolerance": SAFETY_TOLERANCE, - "sync_mode": True # Use sync mode for immediate results - } - - # Add seed if provided - if seed is not None and isinstance(seed, int): - arguments["seed"] = seed - - logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...") - logger.info(" Model: %s", DEFAULT_MODEL) - logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size) - logger.info(" Steps: %s", validated_params['num_inference_steps']) - logger.info(" Guidance: %s", validated_params['guidance_scale']) - - # Submit request to FAL.ai using sync API (avoids cached event loop issues) - handler = _submit_fal_request( - DEFAULT_MODEL, - arguments=arguments, - ) - - # Get the result (sync — blocks until done) - result = handler.get() - - generation_time = (datetime.datetime.now() - start_time).total_seconds() - - # Process the response - if not result or "images" not in result: - raise ValueError("Invalid response from FAL.ai API - no images returned") - - images = result.get("images", []) - if not images: - raise ValueError("No images were generated") - - # Format image data and upscale images - formatted_images = [] - for img in images: - if isinstance(img, dict) and "url" in img: - original_image = { - "url": img["url"], - "width": img.get("width", 0), - "height": img.get("height", 0) - } - - # Attempt to upscale the image - upscaled_image = _upscale_image(img["url"], prompt.strip()) - - if upscaled_image: - # Use upscaled image if successful - formatted_images.append(upscaled_image) - else: - # Fall back to original image if upscaling fails - logger.warning("Using original image as fallback") - original_image["upscaled"] = False - formatted_images.append(original_image) - + + if resolved_provider == "fal": + if source_image_url or 
source_image_urls or reference_image_urls or normalized_operation == "edit": + raise ValueError("FAL image backend only supports generation. Use provider='xai' for image edit/reference workflows.") + if not (os.getenv("FAL_KEY") or _resolve_managed_fal_gateway()): + message = "FAL_KEY environment variable not set" + if managed_nous_tools_enabled(): + message += " and managed FAL gateway is unavailable" + raise ValueError(message) + + arguments = { + "prompt": prompt.strip(), + "image_size": validated_params["image_size"], + "num_inference_steps": validated_params["num_inference_steps"], + "guidance_scale": validated_params["guidance_scale"], + "num_images": validated_params["num_images"], + "output_format": validated_params["output_format"], + "enable_safety_checker": ENABLE_SAFETY_CHECKER, + "safety_tolerance": SAFETY_TOLERANCE, + "sync_mode": True, + } + + if seed is not None and isinstance(seed, int): + arguments["seed"] = seed + + logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...") + logger.info(" Model: %s", DEFAULT_MODEL) + logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size) + logger.info(" Steps: %s", validated_params["num_inference_steps"]) + logger.info(" Guidance: %s", validated_params["guidance_scale"]) + + handler = _submit_fal_request( + DEFAULT_MODEL, + arguments=arguments, + ) + result = handler.get() + + if not result or "images" not in result: + raise ValueError("Invalid response from FAL.ai API - no images returned") + + images = result.get("images", []) + if not images: + raise ValueError("No images were generated") + + formatted_images = [] + for img in images: + if isinstance(img, dict) and "url" in img: + original_image = { + "url": img["url"], + "width": img.get("width", 0), + "height": img.get("height", 0), + "provider": "fal", + } + + upscaled_image = _upscale_image(img["url"], prompt.strip()) + + if upscaled_image: + upscaled_image["provider"] = "fal" + formatted_images.append(upscaled_image) + else: + 
logger.warning("Using original image as fallback") + original_image["upscaled"] = False + formatted_images.append(original_image) + else: + logger.info("Submitting generation request to xAI image API...") + logger.info(" Model: %s", DEFAULT_XAI_MODEL) + logger.info(" Operation: %s", normalized_operation) + formatted_images = _generate_image_with_xai( + prompt=prompt, + operation=normalized_operation, + aspect_ratio=aspect_ratio, + num_images=validated_params["num_images"], + output_format=validated_params["output_format"], + resolution=resolution, + response_format=response_format, + source_image_url=source_image_url, + source_image_urls=source_image_urls, + reference_image_urls=reference_image_urls, + ) + if not formatted_images: - raise ValueError("No valid image URLs returned from API") - + raise ValueError(f"No valid image URLs returned from {resolved_provider} API") + + generation_time = (datetime.datetime.now() - start_time).total_seconds() + upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False)) logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count) # Prepare successful response - minimal format response_data = { "success": True, - "image": formatted_images[0]["url"] if formatted_images else None + "image": formatted_images[0]["url"] if formatted_images else None, + "provider": resolved_provider, + "operation": formatted_images[0].get("operation", normalized_operation), + "images": formatted_images, } debug_call_data["success"] = True @@ -551,15 +876,11 @@ def check_image_generation_requirements() -> bool: bool: True if requirements are met, False otherwise """ try: - # Check API key - if not check_fal_api_key(): - return False - - # Check if fal_client is available - import fal_client # noqa: F401 — SDK presence check - return True - - except ImportError: + if _has_fal_backend() or _has_xai_image_backend(): + return True + return False + + except Exception: return False 
@@ -646,7 +967,7 @@ from tools.registry import registry, tool_error IMAGE_GENERATE_SCHEMA = { "name": "image_generate", - "description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ![description](URL)", + "description": "Generate or edit images. FAL supports text-to-image generation; xAI grok-imagine-image supports generation, single-image edits, multi-image edits, source/reference images, extra aspect ratios, 1k/2k resolution, and optional base64 output. Returns a primary image URL plus an images list.", "parameters": { "type": "object", "properties": { @@ -654,11 +975,55 @@ IMAGE_GENERATE_SCHEMA = { "type": "string", "description": "The text prompt describing the desired image. Be detailed and descriptive." }, + "operation": { + "type": "string", + "enum": sorted(VALID_XAI_OPERATIONS), + "description": "Use 'generate' for a new image or 'edit' to transform one or more source images. If source_image_url/source_image_urls are provided, xAI edit mode is used automatically.", + "default": DEFAULT_OPERATION + }, + "provider": { + "type": "string", + "enum": ["auto", "fal", "xai"], + "description": "Image backend to use. 'auto' prefers xAI when you request xAI-only features such as edit, source images, extra aspect ratios, 1k/2k resolution, or b64_json output; otherwise it prefers FAL when available.", + "default": "auto" + }, "aspect_ratio": { "type": "string", - "enum": ["landscape", "square", "portrait"], - "description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.", + "enum": ["landscape", "square", "portrait", "auto", "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "2:1", "1:2", "19.5:9", "9:19.5", "20:9", "9:20"], + "description": "Aspect ratio. FAL supports landscape/square/portrait. 
xAI also supports direct ratios like 3:2, 4:3, 2:1, 20:9, and auto.", "default": "landscape" + }, + "num_images": { + "type": "integer", + "description": "Number of images to generate. Best used with xAI generate mode.", + "default": DEFAULT_NUM_IMAGES, + "minimum": 1, + "maximum": 4 + }, + "resolution": { + "type": "string", + "enum": ["1k", "2k"], + "description": "xAI-only image resolution." + }, + "response_format": { + "type": "string", + "enum": sorted(VALID_XAI_RESPONSE_FORMATS), + "description": "xAI-only response format. Use b64_json to force inline base64 output.", + "default": "url" + }, + "source_image_url": { + "type": "string", + "description": "Optional source image URL or data URI for xAI image editing." + }, + "source_image_urls": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of source image URLs or data URIs for xAI multi-image editing. Up to 5." + }, + "reference_image_urls": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional xAI reference images. For generate mode they guide style/content; for edit mode they are combined with source images. Up to 5 combined images total." 
} }, "required": ["prompt"] @@ -672,10 +1037,17 @@ def _handle_image_generate(args, **kw): return tool_error("prompt is required for image generation") return image_generate_tool( prompt=prompt, + operation=args.get("operation", DEFAULT_OPERATION), + provider=args.get("provider", "auto"), aspect_ratio=args.get("aspect_ratio", "landscape"), + resolution=args.get("resolution"), + response_format=args.get("response_format", "url"), + source_image_url=args.get("source_image_url"), + source_image_urls=args.get("source_image_urls"), + reference_image_urls=args.get("reference_image_urls"), num_inference_steps=50, guidance_scale=4.5, - num_images=1, + num_images=args.get("num_images", 1), output_format="png", seed=None, ) diff --git a/tools/video_generation_tool.py b/tools/video_generation_tool.py new file mode 100644 index 0000000000..6a383bb2d3 --- /dev/null +++ b/tools/video_generation_tool.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +""" +Video generation tool using xAI's async video API. +""" + +import asyncio +import json +import logging +import os +import uuid +from typing import Any, Dict, List, Optional + +import httpx + +from tools.registry import registry, tool_error +from tools.xai_http import hermes_xai_user_agent + +logger = logging.getLogger(__name__) + +DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1" +DEFAULT_XAI_VIDEO_MODEL = "grok-imagine-video" +DEFAULT_OPERATION = "generate" +DEFAULT_DURATION = 8 +DEFAULT_ASPECT_RATIO = "16:9" +DEFAULT_RESOLUTION = "720p" +DEFAULT_TIMEOUT_SECONDS = 240 +DEFAULT_POLL_INTERVAL_SECONDS = 5 +VALID_ASPECT_RATIOS = {"1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"} +VALID_RESOLUTIONS = {"480p", "720p"} +VALID_SIZES = {"848x480", "1696x960", "1280x720", "1920x1080"} +VALID_OPERATIONS = {"generate", "edit", "extend"} + + +def _get_xai_base_url() -> str: + return (os.getenv("XAI_BASE_URL") or DEFAULT_XAI_BASE_URL).strip().rstrip("/") + + +def check_video_generation_requirements() -> bool: + return bool(os.getenv("XAI_API_KEY", 
"").strip()) + + +def _xai_headers() -> Dict[str, str]: + api_key = os.getenv("XAI_API_KEY", "").strip() + if not api_key: + raise ValueError("XAI_API_KEY not set. Get one at https://console.x.ai/") + return { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": hermes_xai_user_agent(), + } + + +def _normalize_reference_images( + image_url: Optional[str], + reference_image_urls: Optional[List[str]], +) -> tuple[Optional[Dict[str, str]], Optional[List[Dict[str, str]]]]: + primary_image = None + if image_url and image_url.strip(): + primary_image = {"url": image_url.strip()} + + refs = [] + for url in reference_image_urls or []: + normalized = (url or "").strip() + if normalized: + refs.append({"url": normalized}) + return primary_image, refs or None + + +def _normalize_operation( + operation: Optional[str], + video_url: Optional[str], + prompt: Optional[str], +) -> str: + normalized = (operation or "").strip().lower() + if not normalized: + if (video_url or "").strip(): + prompt_lower = (prompt or "").strip().lower() + if not prompt_lower: + return "extend" + extend_cues = ( + "extend", + "continue", + "continuation", + "longer", + "further", + "keep going", + "carry on", + "more of", + ) + return "extend" if any(cue in prompt_lower for cue in extend_cues) else "edit" + return DEFAULT_OPERATION + aliases = { + "generate_video": "generate", + "edit_video": "edit", + "extend_video": "extend", + } + normalized = aliases.get(normalized, normalized) + if normalized not in VALID_OPERATIONS: + raise ValueError(f"operation must be one of {sorted(VALID_OPERATIONS)}") + return normalized + + +def _normalize_duration( + *, + operation: str, + duration: Optional[int], + seconds: Optional[int], + reference_images_present: bool, +) -> int: + if operation == "edit": + # xAI video edits inherit duration from the source video. Ignore any + # caller-provided duration/seconds instead of rejecting the request. 
+ return DEFAULT_DURATION + + value = seconds if seconds is not None else duration + if value is None: + value = 6 if operation == "extend" else DEFAULT_DURATION + + if value < 1: + raise ValueError("duration must be at least 1 second") + + if operation == "extend": + if value > 10: + raise ValueError("xAI video extension supports a maximum duration of 10 seconds") + else: + if value > 15: + raise ValueError("xAI video generation supports a maximum duration of 15 seconds") + if reference_images_present and value > 10: + raise ValueError( + "xAI video generation supports a maximum duration of 10 seconds when using reference_image_urls" + ) + return value + + +async def _submit_video_request( + client: httpx.AsyncClient, + operation: str, + payload: Dict[str, Any], +) -> str: + endpoint_map = { + "generate": "videos/generations", + "edit": "videos/edits", + "extend": "videos/extensions", + } + submit_response = await client.post( + f"{_get_xai_base_url()}/{endpoint_map[operation]}", + headers={**_xai_headers(), "x-idempotency-key": str(uuid.uuid4())}, + json=payload, + timeout=60, + ) + submit_response.raise_for_status() + submit_payload = submit_response.json() + request_id = submit_payload.get("request_id") + if not request_id: + raise RuntimeError("xAI video response did not include request_id") + return request_id + + +async def video_generate_tool( + prompt: Optional[str] = None, + operation: Optional[str] = None, + duration: Optional[int] = DEFAULT_DURATION, + seconds: Optional[int] = None, + aspect_ratio: str = DEFAULT_ASPECT_RATIO, + resolution: str = DEFAULT_RESOLUTION, + size: Optional[str] = None, + video_url: Optional[str] = None, + image_url: Optional[str] = None, + reference_image_urls: Optional[List[str]] = None, + user: Optional[str] = None, + timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, + poll_interval_seconds: int = DEFAULT_POLL_INTERVAL_SECONDS, + prompt_source: Optional[str] = None, +) -> str: + normalized_prompt = (prompt or "").strip() + 
normalized_video_url = (video_url or "").strip() or None + notes: List[str] = [] + + try: + normalized_operation = _normalize_operation(operation, normalized_video_url, normalized_prompt) + except ValueError as e: + return tool_error(str(e)) + + normalized_aspect_ratio = (aspect_ratio or DEFAULT_ASPECT_RATIO).strip() + normalized_resolution = (resolution or DEFAULT_RESOLUTION).strip().lower() + normalized_size = (size or "").strip() + normalized_user = (user or "").strip() or None + + if normalized_operation == "extend" and not normalized_prompt: + normalized_prompt = "Continue the existing video naturally." + notes.append("used a default continuation prompt because extend was requested without a prompt") + elif prompt_source == "user_task_fallback" and normalized_prompt: + notes.append("used the current user message as prompt because the model omitted prompt") + if normalized_operation == "edit" and not normalized_prompt: + return tool_error(f"prompt is required for xAI video {normalized_operation}") + if normalized_operation == "generate" and not normalized_prompt and not (image_url or "").strip(): + return tool_error("prompt is required for text-to-video generation unless image_url is provided") + + if timeout_seconds < 10: + return tool_error("timeout_seconds must be at least 10") + if poll_interval_seconds < 1: + return tool_error("poll_interval_seconds must be at least 1") + + primary_image, refs = _normalize_reference_images(image_url, reference_image_urls) + if refs and len(refs) > 7: + return tool_error("reference_image_urls supports at most 7 images with xAI") + + try: + normalized_duration = _normalize_duration( + operation=normalized_operation, + duration=duration, + seconds=seconds, + reference_images_present=bool(refs), + ) + except ValueError as e: + return tool_error(str(e)) + + payload: Dict[str, Any] = { + "model": DEFAULT_XAI_VIDEO_MODEL, + } + if normalized_prompt: + payload["prompt"] = normalized_prompt + if normalized_user: + payload["user"] = 
normalized_user + + if normalized_operation == "generate": + if normalized_aspect_ratio not in VALID_ASPECT_RATIOS: + return tool_error( + f"aspect_ratio must be one of {sorted(VALID_ASPECT_RATIOS)}" + ) + if normalized_resolution not in VALID_RESOLUTIONS: + return tool_error( + f"resolution must be one of {sorted(VALID_RESOLUTIONS)}" + ) + if normalized_size and normalized_size not in VALID_SIZES: + return tool_error( + f"size must be one of {sorted(VALID_SIZES)}" + ) + if primary_image and refs: + return tool_error( + "image_url and reference_image_urls cannot be combined for xAI video generation" + ) + payload.update( + { + "duration": normalized_duration, + "aspect_ratio": normalized_aspect_ratio, + "resolution": normalized_resolution, + } + ) + if normalized_size: + payload["size"] = normalized_size + if primary_image: + payload["image"] = primary_image + if refs: + payload["reference_images"] = refs + + elif normalized_operation == "edit": + if not normalized_video_url: + return tool_error("video_url is required for xAI video edit") + if primary_image or refs: + return tool_error("image_url and reference_image_urls are not supported for xAI video edit") + payload["video"] = {"url": normalized_video_url} + notes.append("duration, aspect_ratio, and resolution are inherited from the source video for xAI video edit") + + else: + if not normalized_video_url: + return tool_error("video_url is required for xAI video extension") + if primary_image or refs: + return tool_error("image_url and reference_image_urls are not supported for xAI video extension") + payload["duration"] = normalized_duration + payload["video"] = {"url": normalized_video_url} + + try: + async with httpx.AsyncClient() as client: + request_id = await _submit_video_request(client, normalized_operation, payload) + + elapsed = 0.0 + last_status = "queued" + while elapsed < timeout_seconds: + status_response = await client.get( + f"{_get_xai_base_url()}/videos/{request_id}", + headers=_xai_headers(), 
+ timeout=30, + ) + status_response.raise_for_status() + status_payload = status_response.json() + last_status = (status_payload.get("status") or "").lower() + + if last_status == "done": + video = status_payload.get("video") or {} + video_url = video.get("url") + if not video_url: + raise RuntimeError("xAI video generation completed without a video URL") + return json.dumps( + { + "success": True, + "provider": "xai", + "operation": normalized_operation, + "request_id": request_id, + "status": "done", + "video": video_url, + "duration": video.get("duration", normalized_duration), + "aspect_ratio": normalized_aspect_ratio if normalized_operation == "generate" else None, + "resolution": normalized_resolution if normalized_operation == "generate" else None, + "size": normalized_size if normalized_operation == "generate" else None, + "respect_moderation": video.get("respect_moderation"), + "model": status_payload.get("model"), + "usage": status_payload.get("usage"), + "notes": notes, + }, + ensure_ascii=False, + ) + + if last_status in {"failed", "error", "expired", "cancelled"}: + error_message = ( + status_payload.get("error", {}).get("message") + or status_payload.get("message") + or f"Video generation ended with status '{last_status}'" + ) + return json.dumps( + { + "success": False, + "provider": "xai", + "operation": normalized_operation, + "request_id": request_id, + "status": last_status, + "error": error_message, + }, + ensure_ascii=False, + ) + + await asyncio.sleep(poll_interval_seconds) + elapsed += poll_interval_seconds + + return json.dumps( + { + "success": False, + "provider": "xai", + "operation": normalized_operation, + "request_id": request_id, + "status": last_status, + "error": f"Timed out waiting for video generation after {timeout_seconds} seconds", + }, + ensure_ascii=False, + ) + except Exception as e: + logger.error("Video generation failed: %s", e, exc_info=True) + return json.dumps( + { + "success": False, + "provider": "xai", + 
"operation": normalized_operation, + "error": str(e), + "error_type": type(e).__name__, + }, + ensure_ascii=False, + ) + + +VIDEO_GENERATE_SCHEMA = { + "name": "video_generate", + "description": "Generate, edit, or extend short videos with xAI grok-imagine-video. Supports text-to-video, image-to-video, reference-image-guided generation, native video edits, and native video extensions.", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Describe the video to generate, edit, or extend. Usually pass this whenever the user provides motion, scene, style, edit, or continuation instructions. Optional only for image-to-video calls where the image alone is the complete instruction.", + }, + "operation": { + "type": "string", + "enum": sorted(VALID_OPERATIONS), + "description": "Video mode. Use 'generate' for new videos, 'edit' to modify an existing video, and 'extend' to continue an existing video.", + "default": DEFAULT_OPERATION, + }, + "duration": { + "type": "integer", + "description": "Requested duration in seconds. Generate supports 1-15 seconds. Extend supports 1-10 seconds. For xAI video edit, the source video duration is retained.", + "default": DEFAULT_DURATION, + }, + "seconds": { + "type": "integer", + "description": "Alias for duration for OpenAI-compatible callers.", + }, + "aspect_ratio": { + "type": "string", + "enum": sorted(VALID_ASPECT_RATIOS), + "description": "Output aspect ratio for generate mode.", + "default": DEFAULT_ASPECT_RATIO, + }, + "resolution": { + "type": "string", + "enum": sorted(VALID_RESOLUTIONS), + "description": "Output resolution for generate mode.", + "default": DEFAULT_RESOLUTION, + }, + "size": { + "type": "string", + "enum": sorted(VALID_SIZES), + "description": "Optional explicit output size for generate mode.", + }, + "video_url": { + "type": "string", + "description": "Required for edit and extend modes. 
Source video URL to modify or continue.", + }, + "image_url": { + "type": "string", + "description": "Optional source image URL for image-to-video generation.", + }, + "reference_image_urls": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional reference image URLs for generate mode. Use these to carry people, objects, or clothing into a new video without fixing the first frame.", + }, + "user": { + "type": "string", + "description": "Optional end-user identifier forwarded to xAI.", + }, + }, + "required": ["prompt"], + }, +} + + +async def _handle_video_generate(args, **kw): + prompt = args.get("prompt", "") + prompt_source = None + if not (prompt or "").strip(): + user_task = kw.get("user_task") + if user_task and isinstance(user_task, str) and user_task.strip(): + prompt = user_task.strip() + prompt_source = "user_task_fallback" + logger.info("video_generate: prompt was empty, falling back to user_task=%r", prompt[:100]) + return await video_generate_tool( + prompt=prompt, + operation=args.get("operation"), + duration=args.get("duration", DEFAULT_DURATION), + seconds=args.get("seconds"), + aspect_ratio=args.get("aspect_ratio", DEFAULT_ASPECT_RATIO), + resolution=args.get("resolution", DEFAULT_RESOLUTION), + size=args.get("size"), + video_url=args.get("video_url"), + image_url=args.get("image_url"), + reference_image_urls=args.get("reference_image_urls"), + user=args.get("user"), + prompt_source=prompt_source, + ) + + +registry.register( + name="video_generate", + toolset="video_gen", + schema=VIDEO_GENERATE_SCHEMA, + handler=_handle_video_generate, + check_fn=check_video_generation_requirements, + requires_env=["XAI_API_KEY"], + is_async=True, + emoji="🎬", +) diff --git a/tools/x_search_tool.py b/tools/x_search_tool.py new file mode 100644 index 0000000000..0af052790f --- /dev/null +++ b/tools/x_search_tool.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +""" +X Search tool backed by xAI's built-in x_search Responses API tool. 
+""" + +import json +import logging +import os +import time +from typing import Any, Dict, List, Optional + +import requests + +from tools.registry import registry, tool_error +from tools.xai_http import hermes_xai_user_agent + +logger = logging.getLogger(__name__) + +DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1" +DEFAULT_X_SEARCH_MODEL = "grok-4.20-reasoning" +DEFAULT_X_SEARCH_TIMEOUT_SECONDS = 180 +DEFAULT_X_SEARCH_RETRIES = 2 +MAX_HANDLES = 10 + + +def _get_xai_base_url() -> str: + return (os.getenv("XAI_BASE_URL") or DEFAULT_XAI_BASE_URL).strip().rstrip("/") + + +def _load_x_search_config() -> Dict[str, Any]: + try: + from hermes_cli.config import load_config + + return load_config().get("x_search", {}) + except Exception: + return {} + + +def _get_x_search_model() -> str: + cfg = _load_x_search_config() + return (cfg.get("model") or DEFAULT_X_SEARCH_MODEL).strip() + + +def _get_x_search_timeout_seconds() -> int: + cfg = _load_x_search_config() + raw_value = cfg.get("timeout_seconds", DEFAULT_X_SEARCH_TIMEOUT_SECONDS) + try: + return max(30, int(raw_value)) + except Exception: + return DEFAULT_X_SEARCH_TIMEOUT_SECONDS + + +def _get_x_search_retries() -> int: + cfg = _load_x_search_config() + raw_value = cfg.get("retries", DEFAULT_X_SEARCH_RETRIES) + try: + return max(0, int(raw_value)) + except Exception: + return DEFAULT_X_SEARCH_RETRIES + + +def check_x_search_requirements() -> bool: + return bool(os.getenv("XAI_API_KEY", "").strip()) + + +def _normalize_handles(handles: Optional[List[str]], field_name: str) -> List[str]: + cleaned = [] + for handle in handles or []: + normalized = str(handle or "").strip().lstrip("@") + if normalized: + cleaned.append(normalized) + if len(cleaned) > MAX_HANDLES: + raise ValueError(f"{field_name} supports at most {MAX_HANDLES} handles") + return cleaned + + +def _extract_response_text(payload: Dict[str, Any]) -> str: + output_text = str(payload.get("output_text") or "").strip() + if output_text: + return output_text + + parts: 
def _extract_response_text(payload: Dict[str, Any]) -> str:
    """Pull the assistant's final text out of an xAI Responses API payload.

    Prefers the convenience `output_text` field; otherwise concatenates the
    text parts of every message item in `output`.
    """
    direct = str(payload.get("output_text") or "").strip()
    if direct:
        return direct
    collected: List[str] = []
    for item in payload.get("output", []) or []:
        if item.get("type") != "message":
            continue
        for part in item.get("content", []) or []:
            if part.get("type") in ("output_text", "text"):
                chunk = str(part.get("text") or "").strip()
                if chunk:
                    collected.append(chunk)
    return "\n\n".join(collected).strip()


def _extract_inline_citations(payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Collect url_citation annotations from message content parts."""
    found: List[Dict[str, Any]] = []
    for item in payload.get("output", []) or []:
        if item.get("type") != "message":
            continue
        for part in item.get("content", []) or []:
            for note in part.get("annotations", []) or []:
                if note.get("type") == "url_citation":
                    found.append(
                        {
                            "url": note.get("url", ""),
                            "title": note.get("title", ""),
                            "start_index": note.get("start_index"),
                            "end_index": note.get("end_index"),
                        }
                    )
    return found


def _http_error_message(exc: requests.HTTPError) -> str:
    """Produce a compact, human-readable message from an HTTPError.

    Prefers the xAI JSON error body (`code`/`error` keys); falls back to
    the raw response text (truncated), then to str(exc).
    """
    response = getattr(exc, "response", None)
    if response is None:
        return str(exc)

    try:
        body = response.json()
    except Exception:
        body = None

    if isinstance(body, dict):
        code = str(body.get("code") or "").strip()
        error = str(body.get("error") or "").strip()
        message = error or str(body)
        if code and code not in message:
            message = f"{code}: {message}"
        return message or str(exc)

    raw_text = str(getattr(response, "text", "") or "").strip()
    return raw_text[:500] if raw_text else str(exc)


def x_search_tool(
    query: str,
    allowed_x_handles: Optional[List[str]] = None,
    excluded_x_handles: Optional[List[str]] = None,
    from_date: str = "",
    to_date: str = "",
    enable_image_understanding: bool = False,
    enable_video_understanding: bool = False,
) -> str:
    """Search X via xAI's built-in x_search Responses tool.

    Returns a JSON string: on success the answer text plus top-level and
    inline citations; on failure a structured error payload. Retries 5xx
    responses and transient network errors up to the configured budget.
    """
    if not query or not query.strip():
        return tool_error("query is required for x_search")

    api_key = os.getenv("XAI_API_KEY", "").strip()
    if not api_key:
        return tool_error("XAI_API_KEY is not set")

    try:
        only_handles = _normalize_handles(allowed_x_handles, "allowed_x_handles")
        blocked_handles = _normalize_handles(excluded_x_handles, "excluded_x_handles")
        # The upstream tool treats these as mutually exclusive filters.
        if only_handles and blocked_handles:
            return tool_error("allowed_x_handles and excluded_x_handles cannot be used together")

        x_tool: Dict[str, Any] = {"type": "x_search"}
        if only_handles:
            x_tool["allowed_x_handles"] = only_handles
        if blocked_handles:
            x_tool["excluded_x_handles"] = blocked_handles
        start_date = from_date.strip()
        if start_date:
            x_tool["from_date"] = start_date
        end_date = to_date.strip()
        if end_date:
            x_tool["to_date"] = end_date
        if enable_image_understanding:
            x_tool["enable_image_understanding"] = True
        if enable_video_understanding:
            x_tool["enable_video_understanding"] = True

        model = _get_x_search_model()
        request_body = {
            "model": model,
            "input": [
                {
                    "role": "user",
                    "content": query.strip(),
                }
            ],
            "tools": [x_tool],
            # Ephemeral request: do not persist the conversation server-side.
            "store": False,
        }

        timeout = _get_x_search_timeout_seconds()
        retries = _get_x_search_retries()
        resp = None
        # 1-based attempt counter over the initial try plus `retries` retries.
        for attempt in range(1, retries + 2):
            try:
                resp = requests.post(
                    f"{_get_xai_base_url()}/responses",
                    headers={
                        "Authorization": f"Bearer {api_key}",
                        "Content-Type": "application/json",
                        "User-Agent": hermes_xai_user_agent(),
                    },
                    json=request_body,
                    timeout=timeout,
                )
                resp.raise_for_status()
                break
            except requests.HTTPError as e:
                status = getattr(getattr(e, "response", None), "status_code", None)
                # Only server-side (5xx) failures are retried; 4xx bubble up.
                if status is None or status < 500 or attempt > retries:
                    raise
                logger.warning(
                    "x_search upstream failure on attempt %s/%s: %s",
                    attempt,
                    retries + 1,
                    _http_error_message(e),
                )
                time.sleep(min(5.0, 1.5 * attempt))
            except (requests.ReadTimeout, requests.ConnectionError) as e:
                if attempt > retries:
                    raise
                logger.warning(
                    "x_search transient failure on attempt %s/%s: %s",
                    attempt,
                    retries + 1,
                    e,
                )
                time.sleep(min(5.0, 1.5 * attempt))

        if resp is None:
            # Defensive: the loop either broke with a response or raised.
            raise RuntimeError("x_search request did not return a response")

        data = resp.json()
        return json.dumps(
            {
                "success": True,
                "provider": "xai",
                "tool": "x_search",
                "model": model,
                "query": query.strip(),
                "answer": _extract_response_text(data),
                "citations": list(data.get("citations") or []),
                "inline_citations": _extract_inline_citations(data),
            },
            ensure_ascii=False,
        )
    except requests.HTTPError as e:
        logger.error("x_search failed: %s", e, exc_info=True)
        return json.dumps(
            {
                "success": False,
                "provider": "xai",
                "tool": "x_search",
                "error": _http_error_message(e),
                "error_type": type(e).__name__,
            },
            ensure_ascii=False,
        )
    except requests.ReadTimeout as e:
        logger.error("x_search timed out: %s", e, exc_info=True)
        return json.dumps(
            {
                "success": False,
                "provider": "xai",
                "tool": "x_search",
                "error": f"xAI x_search timed out after {_get_x_search_timeout_seconds()} seconds",
                "error_type": type(e).__name__,
            },
            ensure_ascii=False,
        )
    except Exception as e:
        logger.error("x_search failed: %s", e, exc_info=True)
        return json.dumps(
            {
                "success": False,
                "provider": "xai",
                "tool": "x_search",
                "error": str(e),
                "error_type": type(e).__name__,
            },
            ensure_ascii=False,
        )
# JSON schema advertised to the model for the `x_search` tool. Every
# property maps 1:1 onto a keyword parameter of x_search_tool().
X_SEARCH_SCHEMA = {
    "name": "x_search",
    "description": "Search X (Twitter) posts, profiles, and threads using xAI's built-in X Search tool. Use this for current discussion, reactions, or claims on X rather than general web pages.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "What to look up on X.",
            },
            # allowed/excluded are mutually exclusive; x_search_tool rejects
            # calls that set both.
            "allowed_x_handles": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional list of X handles to include exclusively (max 10).",
            },
            "excluded_x_handles": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional list of X handles to exclude (max 10).",
            },
            "from_date": {
                "type": "string",
                "description": "Optional start date in YYYY-MM-DD format.",
            },
            "to_date": {
                "type": "string",
                "description": "Optional end date in YYYY-MM-DD format.",
            },
            "enable_image_understanding": {
                "type": "boolean",
                "description": "Whether xAI should analyze images attached to matching X posts.",
                "default": False,
            },
            "enable_video_understanding": {
                "type": "boolean",
                "description": "Whether xAI should analyze videos attached to matching X posts.",
                "default": False,
            },
        },
        "required": ["query"],
    },
}


def _handle_x_search(args, **kw):
    """Registry adapter: forward a tool call's `args` to x_search_tool."""
    return x_search_tool(
        query=args.get("query", ""),
        allowed_x_handles=args.get("allowed_x_handles"),
        excluded_x_handles=args.get("excluded_x_handles"),
        from_date=args.get("from_date", ""),
        to_date=args.get("to_date", ""),
        enable_image_understanding=bool(args.get("enable_image_understanding", False)),
        enable_video_understanding=bool(args.get("enable_video_understanding", False)),
    )


# Register in the dedicated "x_search" toolset; availability is gated on
# XAI_API_KEY via check_x_search_requirements. The generous result cap
# accommodates long answers plus citation lists.
registry.register(
    name="x_search",
    toolset="x_search",
    schema=X_SEARCH_SCHEMA,
    handler=_handle_x_search,
    check_fn=check_x_search_requirements,
    requires_env=["XAI_API_KEY"],
    emoji="🐦",
    max_result_size_chars=100_000,
)
def hermes_xai_user_agent() -> str:
    """Return a stable Hermes-specific User-Agent for xAI HTTP calls."""
    # Alias the package version on import so we don't shadow the module's
    # own `__version__` slot; fall back when hermes_cli isn't importable.
    try:
        from hermes_cli import __version__ as hermes_version
    except Exception:
        hermes_version = "unknown"
    return f"Hermes-Agent/{hermes_version}"
# Web - "web_search", "web_extract", + "web_search", "web_extract", "x_search", # Terminal + process management "terminal", "process", # File manipulation "read_file", "write_file", "patch", "search_files", # Vision + image generation - "vision_analyze", "image_generate", + "vision_analyze", "image_generate", "video_generate", # Skills "skills_list", "skill_view", "skill_manage", # Browser automation