mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-04 09:47:54 +08:00
fix(skills/comfyui): bug fixes, cloud parity, expanded coverage, examples, tests
The audit of v4.1 surfaced ~70 issues across the five scripts and three
reference docs — most user-visible (silent file overwrites, status-error
misclassified as success, X-API-Key leaked to S3 on /api/view redirect,
Cloud endpoints that 404 because they were renamed). v5.0.0 fixes those
and fills the gaps that previously forced users to write their own glue
(WebSocket monitoring, batch/sweep, img2img upload helper, dep auto-fix,
log fetch, health check, example workflows).
Critical fixes
- run_workflow.py: poll_status now checks status_str==error BEFORE
completed:true, so a failed run no longer reports success
- run_workflow.py: download_output streams to disk via safe_path_join,
preserves server subfolder structure (no silent overwrites), and
retries with exponential backoff
- run_workflow.py: refuses to overwrite a link with a literal in
inject_params (would silently break wiring)
- _common.py: _StripSensitiveOnRedirectSession (subclasses
requests.Session.rebuild_auth) drops X-API-Key/Cookie on cross-host
redirects — fixes a real key-leak path through Cloud's signed-URL
download flow; covered by the cross-host security tests
- Cloud routing (verified live): /history → /history_v2,
/models/<f> → /experiment/models/<f>, plus folder aliases for the
unet ↔ diffusion_models and clip ↔ text_encoders rename
- check_deps.py: distinguishes 200/empty vs 404 folder_not_found vs
403 free-tier; emits concrete fix_command per missing dep
- extract_schema.py: prompt vs negative_prompt determined by tracing
KSampler.{positive,negative} connections (incl. through Reroute /
Primitive nodes) instead of meta-title heuristic; symmetric
duplicate-name resolution; cycle-safe trace_to_node
- hardware_check.py: multi-GPU pick-best, Apple variant detection,
Rosetta detection, WSL2, ROCm --json, disk-space check, optional
PyTorch probe; powershell preferred over deprecated wmic
- comfyui_setup.sh: prefers pipx → uvx → pip --user (with PEP-668
fallback); idempotent — skips relaunch if server already up;
configurable port/workspace; persistent log; SIGINT trap
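The redirect-stripping fix above hinges on one comparison: drop auth-bearing headers whenever a redirect target's host differs from the original request's host. A stdlib-only sketch of that core check (function name and header tuple are illustrative; the in-tree fix implements this inside a `requests.Session.rebuild_auth` override):

```python
from urllib.parse import urlparse

# Headers that must never follow a redirect to another host.
SENSITIVE_HEADERS = ("X-API-Key", "Authorization", "Cookie")

def strip_sensitive_on_redirect(headers: dict, orig_url: str, redirect_url: str) -> dict:
    """Return a copy of `headers` with auth-bearing entries removed
    when the redirect leaves the original host."""
    if urlparse(orig_url).hostname == urlparse(redirect_url).hostname:
        return dict(headers)
    return {k: v for k, v in headers.items() if k not in SENSITIVE_HEADERS}
```

A cloud `/api/view` to S3 redirect crosses hosts, so the key is dropped; a same-host redirect keeps it.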
New scripts
- run_batch.py — count or sweep (cartesian product), parallel up to
cloud tier limit
- ws_monitor.py — real-time WebSocket viewer; saves preview frames
- auto_fix_deps.py — runs comfy node install / model download for
whatever check_deps reports missing (with --dry-run)
- health_check.py — single command that runs the verification checklist
(comfy-cli + server + checkpoints + optional smoke test that cancels
itself to avoid burning compute)
- fetch_logs.py — pull traceback / status messages for a prompt_id
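run_batch's sweep mode is a cartesian product over per-parameter value lists; the expansion itself is only a few lines (a generic sketch; the function name is illustrative, not run_batch's actual API):

```python
from itertools import product

def expand_sweep(params: dict[str, list]) -> list[dict]:
    """One run per combination: {'seed': [1, 2], 'cfg': [5.0, 7.0]} expands to 4 runs."""
    keys = list(params)
    return [dict(zip(keys, combo)) for combo in product(*(params[k] for k in keys))]
```

Each resulting dict is then injected into a copy of the workflow and submitted, up to the cloud tier's parallelism limit.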
Coverage expansion
- Param patterns now cover Flux (BasicScheduler, BasicGuider,
RandomNoise, ModelSamplingFlux), SD3, Wan/Hunyuan/LTX video,
IPAdapter, rgthree, easy-use, AnimateDiff
- Embedding refs in CLIPTextEncode strings extracted as model deps
- ckpt_name / vae_name / lora_name / unet_name now controllable so
workflows can be retargeted per run
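The embedding-reference extraction above can be sketched with a pattern like the following. This is a hypothetical reconstruction consistent with the unit-test expectations further down; the real `EMBEDDING_REGEX` in `_common.py` may differ:

```python
import re

# Matches "embedding:<name>", tolerating an optional ":<strength>" suffix and
# stripping a trailing .pt / .safetensors / .bin extension from the captured name.
EMBEDDING_REGEX = re.compile(
    r"embedding:([\w.\-]+?)(?:\.(?:pt|safetensors|bin))?(?::[\d.]+)?(?=[,)\s]|$)"
)

def embedding_names(text: str) -> list[str]:
    """All embedding names referenced in a prompt string."""
    return [m.group(1) for m in EMBEDDING_REGEX.finditer(text)]
```

The colon after the literal `embedding` is what keeps the plain word "embedding" in prose from matching.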
Examples
- workflows/{sd15,sdxl,flux_dev}_txt2img.json
- workflows/sdxl_{img2img,inpaint}.json
- workflows/upscale_4x.json
- workflows/{animatediff_video,wan_video_t2v}.json + README
Tests
- 117 tests (105 unit + 8 cloud integration + 4 cross-host security)
- Cloud tests auto-skip without COMFY_CLOUD_API_KEY; verified end-to-end
against live cloud API
Backwards compatibility
- All existing CLI flags continue to work; new behavior is opt-in
(--ws, --input-image, --randomize-seed, --flat-output, etc.)
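The Cloud routing described above reduces to a small rewrite table plus an `/api` prefix. A stdlib-only reconstruction consistent with what the unit tests assert (the in-tree `_common.py` implementation may differ in detail):

```python
from urllib.parse import urlparse

CLOUD_SUFFIX = "cloud.comfy.org"

def is_cloud_host(host: str) -> bool:
    """Treat cloud.comfy.org and its subdomains as Comfy Cloud."""
    if "://" not in host:
        host = "http://" + host  # default scheme so urlparse finds the hostname
    hostname = urlparse(host).hostname or ""
    return hostname == CLOUD_SUFFIX or hostname.endswith("." + CLOUD_SUFFIX)

def cloud_endpoint(path: str) -> str:
    """Apply the cloud-side endpoint renames."""
    if path == "/history" or path.startswith("/history/"):
        return "/history_v2" + path[len("/history"):]
    if path == "/models" or path.startswith("/models/"):
        return "/experiment" + path
    return path

def resolve_url(host: str, path: str) -> str:
    host = host.rstrip("/")
    if not is_cloud_host(host):
        return host + path
    if path.startswith("/api/"):
        return host + path  # caller already prefixed; don't double up
    return host + "/api" + cloud_endpoint(path)
```

Note that `/history_v2` itself passes through unchanged, so already-routed paths are not rewritten twice.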
50
skills/creative/comfyui/tests/README.md
Normal file
@@ -0,0 +1,50 @@
# ComfyUI Skill Tests

Pytest suite covering the skill's scripts. Pure-stdlib unit tests run
without any setup; cloud integration tests need a Comfy Cloud API key.

## Running

```bash
# Unit tests only (no network required) — runs in <1s
python3 -m pytest tests/ -c tests/pytest.ini -o addopts="-p no:xdist"

# Including cloud integration tests
COMFY_CLOUD_API_KEY="comfyui-..." python3 -m pytest tests/ \
    -c tests/pytest.ini -o addopts="-p no:xdist"

# Just cloud tests
COMFY_CLOUD_API_KEY="comfyui-..." python3 -m pytest tests/test_cloud_integration.py \
    -c tests/pytest.ini -o addopts="-p no:xdist" -v
```

The `-c` and `-o` overrides isolate this suite from any parent
`pyproject.toml` pytest config (e.g. the `-n auto` from a parent repo).

## Test files

| File | Coverage |
|------|----------|
| `test_common.py` | Cloud detection, URL routing, format validation, embeddings, paths, seeds, model-list parsing, folder aliases |
| `test_extract_schema.py` | Connection tracing, positive/negative prompt detection, dedup logic, embedding deps |
| `test_run_workflow.py` | Param injection (incl. -1 seed, link refusal), output download walk, runner construction |
| `test_check_deps.py` | Model-name fuzzy matching, install command suggestions |
| `test_cloud_integration.py` | Live cloud API contract tests (auto-skipped without API key) |

## Adding tests

When you change a script:

1. Add a unit test if the change is pure logic (cloud detection, parsing, etc.)
2. Add a cloud integration test if the change depends on cloud API behavior
   (use `pytestmark = pytest.mark.cloud` so it auto-skips without a key)
3. Workflow fixtures live in `conftest.py` (`sd15_workflow`, `flux_workflow`,
   `video_workflow`)

## Why the explicit `-c` / `-o`?

The parent hermes-agent repo's `pyproject.toml` enables `pytest-xdist` by
default (`-n auto`). This suite is small enough that parallelism isn't
worth the complexity, and pytest-xdist isn't always installed in the user's
environment. The `-c tests/pytest.ini -o addopts="-p no:xdist"` flags make
the suite run identically regardless of the parent project's config.
64
skills/creative/comfyui/tests/conftest.py
Normal file
@@ -0,0 +1,64 @@
"""Pytest configuration for the comfyui skill test suite.

Adds `scripts/` to sys.path so tests can `from _common import ...`, and
provides a few common fixtures.
"""

from __future__ import annotations

import json
import os
import sys
from pathlib import Path

import pytest

ROOT = Path(__file__).resolve().parent.parent
SCRIPTS = ROOT / "scripts"
WORKFLOWS = ROOT / "workflows"

sys.path.insert(0, str(SCRIPTS))


@pytest.fixture
def sd15_workflow() -> dict:
    return json.loads((WORKFLOWS / "sd15_txt2img.json").read_text())


@pytest.fixture
def flux_workflow() -> dict:
    return json.loads((WORKFLOWS / "flux_dev_txt2img.json").read_text())


@pytest.fixture
def video_workflow() -> dict:
    return json.loads((WORKFLOWS / "wan_video_t2v.json").read_text())


@pytest.fixture
def workflows_dir() -> Path:
    return WORKFLOWS


@pytest.fixture
def scripts_dir() -> Path:
    return SCRIPTS


@pytest.fixture
def cloud_key() -> str | None:
    """Cloud API key if set, otherwise None.

    Tests that need cloud connectivity should skip when this is None.
    """
    return os.environ.get("COMFY_CLOUD_API_KEY")


def pytest_collection_modifyitems(config, items):
    """Auto-skip cloud tests when no API key is set."""
    if os.environ.get("COMFY_CLOUD_API_KEY"):
        return
    skip_cloud = pytest.mark.skip(reason="Set COMFY_CLOUD_API_KEY to run cloud tests")
    for item in items:
        if "cloud" in item.keywords:
            item.add_marker(skip_cloud)
5
skills/creative/comfyui/tests/pytest.ini
Normal file
@@ -0,0 +1,5 @@
[pytest]
markers =
    cloud: tests that hit live Comfy Cloud API (require COMFY_CLOUD_API_KEY)
testpaths = .
addopts = -p no:xdist
65
skills/creative/comfyui/tests/test_check_deps.py
Normal file
@@ -0,0 +1,65 @@
"""Tests for check_deps.py — focuses on parsing logic that doesn't need a server."""

from __future__ import annotations

from check_deps import (
    NODE_TO_PACKAGE,
    model_present,
    normalize_for_match,
    suggest_install_command,
)


class TestNormalizeForMatch:
    def test_basic(self):
        s = normalize_for_match("model.safetensors")
        assert "model.safetensors" in s
        assert "model" in s

    def test_subfolder(self):
        s = normalize_for_match("subdir/model.pt")
        assert "subdir/model.pt" in s
        assert "model.pt" in s
        assert "model" in s


class TestModelPresent:
    def test_exact_match(self):
        assert model_present("a.safetensors", {"a.safetensors", "b.safetensors"}) is True

    def test_extension_difference(self):
        # User said "model" but installed is "model.safetensors"
        assert model_present("model", {"model.safetensors"}) is True
        # Reverse direction — also matches
        assert model_present("model.safetensors", {"model"}) is True

    def test_subfolder_match(self):
        # Installed list has "subdir/model.safetensors", workflow asks "model.safetensors"
        assert model_present("model.safetensors", {"subdir/model.safetensors"}) is True

    def test_missing(self):
        assert model_present("missing.safetensors", {"a.safetensors", "b.safetensors"}) is False

    def test_empty_installed(self):
        assert model_present("anything.safetensors", set()) is False


class TestSuggestInstallCommand:
    def test_known_node(self):
        cmd = suggest_install_command("VHS_VideoCombine")
        assert cmd == "comfy node install comfyui-videohelpersuite"

    def test_unknown_node(self):
        assert suggest_install_command("SomeRandomNodeName123") is None


class TestNodePackageMap:
    def test_no_duplicates(self):
        # Each node should map to exactly one package
        keys = list(NODE_TO_PACKAGE.keys())
        assert len(keys) == len(set(keys))

    def test_all_lowercase_packages(self):
        # Convention: package names are lowercase with hyphens/underscores
        for pkg in NODE_TO_PACKAGE.values():
            assert pkg.lower() == pkg, f"Package name should be lowercase: {pkg}"
95
skills/creative/comfyui/tests/test_cloud_integration.py
Normal file
@@ -0,0 +1,95 @@
"""Integration tests against the live Comfy Cloud API.

These tests are auto-skipped when COMFY_CLOUD_API_KEY is not set.
They never SUBMIT workflows (would need a paid subscription) — they only
verify the read-only endpoints we rely on.
"""

from __future__ import annotations

import pytest

from _common import http_get, parse_model_list, resolve_url

pytestmark = pytest.mark.cloud


class TestCloudEndpointsLive:
    def test_system_stats_reachable(self, cloud_key):
        url = resolve_url("https://cloud.comfy.org", "/system_stats")
        r = http_get(url, headers={"X-API-Key": cloud_key})
        assert r.status == 200
        data = r.json()
        assert "system" in data

    def test_models_endpoint_routed_to_experiment(self, cloud_key):
        # We expect the skill to route /models/checkpoints → /api/experiment/models/checkpoints
        url = resolve_url("https://cloud.comfy.org", "/models/checkpoints")
        assert "/api/experiment/models/checkpoints" in url
        r = http_get(url, headers={"X-API-Key": cloud_key})
        assert r.status == 200

    def test_models_endpoint_returns_dicts(self, cloud_key):
        url = resolve_url("https://cloud.comfy.org", "/models/checkpoints")
        r = http_get(url, headers={"X-API-Key": cloud_key})
        data = r.json()
        assert isinstance(data, list)
        if data:
            # Cloud format: list of dicts with `name`
            assert isinstance(data[0], dict)
            assert "name" in data[0]
            # Our parser normalizes both
            normalized = parse_model_list(data)
            assert len(normalized) == len(data)

    def test_history_renamed_to_v2(self, cloud_key):
        # /history → /api/history_v2 on cloud
        url = resolve_url("https://cloud.comfy.org", "/history/some-fake-id")
        assert "/api/history_v2/some-fake-id" in url

    def test_object_info_paid_tier(self, cloud_key):
        # On free tier, /object_info returns 403 with a recognizable message
        url = resolve_url("https://cloud.comfy.org", "/object_info")
        r = http_get(url, headers={"X-API-Key": cloud_key})
        # Should be either 200 (paid) or 403 (free) — not 404 / 500
        assert r.status in (200, 403)
        if r.status == 403:
            # Body should mention the limitation
            assert "free tier" in r.text().lower() or "subscription" in r.text().lower()


class TestCloudCheckDepsLive:
    def test_check_deps_against_cloud(self, cloud_key, sd15_workflow):
        from check_deps import check_deps

        report = check_deps(sd15_workflow, host="https://cloud.comfy.org", api_key=cloud_key)
        # Either node check passed OR was skipped (free tier)
        assert "missing_models" in report
        assert "is_cloud" in report and report["is_cloud"] is True

    def test_flux_workflow_models_resolved_via_aliases(self, cloud_key, flux_workflow):
        """Flux uses unet/clip folders; cloud has them in diffusion_models/text_encoders.
        With folder aliasing, the check should still find them."""
        from check_deps import check_deps

        report = check_deps(flux_workflow, host="https://cloud.comfy.org", api_key=cloud_key)
        # The exact required Flux files (flux1-dev.safetensors, t5xxl_fp16, clip_l, ae)
        # are present on cloud; with folder aliasing, none should be missing.
        # If this fails, either the cloud removed the model or the aliasing logic broke.
        missing_filenames = {m["value"] for m in report["missing_models"]}
        assert "ae.safetensors" not in missing_filenames, \
            "ae.safetensors should be in cloud's vae folder"
        # t5xxl_fp16 / clip_l should be reachable via the clip → text_encoders alias
        # flux1-dev.safetensors likewise via unet → diffusion_models


class TestHealthCheckLive:
    def test_health_check_passes(self, cloud_key, capsys):
        import json

        from health_check import main as health_main

        rc = health_main(["--host", "https://cloud.comfy.org", "--api-key", cloud_key])
        captured = capsys.readouterr()
        # Should produce JSON
        report = json.loads(captured.out)
        assert report["server"]["reachable"] is True
        assert report["checkpoints"]["queryable"] is True
        assert report["checkpoints"]["count"] > 0
447
skills/creative/comfyui/tests/test_common.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Unit tests for _common.py — pure logic only, no network."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from _common import (
|
||||
DEFAULT_LOCAL_HOST,
|
||||
EMBEDDING_REGEX,
|
||||
FOLDER_ALIASES,
|
||||
build_cloud_aware_url,
|
||||
cloud_endpoint,
|
||||
coerce_seed,
|
||||
folder_aliases_for,
|
||||
is_api_format,
|
||||
is_cloud_host,
|
||||
is_link,
|
||||
iter_embedding_refs,
|
||||
iter_model_deps,
|
||||
iter_nodes,
|
||||
looks_like_video_workflow,
|
||||
media_type_from_filename,
|
||||
parse_model_list,
|
||||
resolve_url,
|
||||
safe_path_join,
|
||||
unwrap_workflow,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cloud detection / URL routing
|
||||
# =============================================================================
|
||||
|
||||
class TestCloudDetection:
|
||||
def test_cloud_host_exact(self):
|
||||
assert is_cloud_host("https://cloud.comfy.org") is True
|
||||
assert is_cloud_host("https://cloud.comfy.org/foo/bar") is True
|
||||
|
||||
def test_cloud_host_subdomain(self):
|
||||
assert is_cloud_host("https://staging.cloud.comfy.org") is True
|
||||
assert is_cloud_host("https://api.cloud.comfy.org") is True
|
||||
|
||||
def test_local_not_cloud(self):
|
||||
assert is_cloud_host("http://127.0.0.1:8188") is False
|
||||
assert is_cloud_host("http://localhost:8188") is False
|
||||
assert is_cloud_host("http://my-server.local:8188") is False
|
||||
|
||||
def test_no_scheme(self):
|
||||
# Defaults to http://
|
||||
assert is_cloud_host("cloud.comfy.org") is True
|
||||
assert is_cloud_host("127.0.0.1:8188") is False
|
||||
|
||||
|
||||
class TestCloudEndpointRename:
|
||||
def test_history_renamed(self):
|
||||
assert cloud_endpoint("/history") == "/history_v2"
|
||||
assert cloud_endpoint("/history/abc-123") == "/history_v2/abc-123"
|
||||
|
||||
def test_history_v2_preserved(self):
|
||||
assert cloud_endpoint("/history_v2") == "/history_v2"
|
||||
|
||||
def test_models_renamed(self):
|
||||
assert cloud_endpoint("/models") == "/experiment/models"
|
||||
assert cloud_endpoint("/models/checkpoints") == "/experiment/models/checkpoints"
|
||||
assert cloud_endpoint("/models/loras") == "/experiment/models/loras"
|
||||
|
||||
def test_other_paths_unchanged(self):
|
||||
assert cloud_endpoint("/prompt") == "/prompt"
|
||||
assert cloud_endpoint("/queue") == "/queue"
|
||||
|
||||
|
||||
class TestResolveURL:
|
||||
def test_local_no_prefix(self):
|
||||
assert resolve_url("http://127.0.0.1:8188", "/prompt") == "http://127.0.0.1:8188/prompt"
|
||||
|
||||
def test_cloud_adds_api_prefix(self):
|
||||
assert resolve_url("https://cloud.comfy.org", "/prompt") == "https://cloud.comfy.org/api/prompt"
|
||||
|
||||
def test_cloud_history_renamed(self):
|
||||
assert resolve_url("https://cloud.comfy.org", "/history/abc") == "https://cloud.comfy.org/api/history_v2/abc"
|
||||
|
||||
def test_cloud_models_renamed(self):
|
||||
assert resolve_url("https://cloud.comfy.org", "/models/loras") == "https://cloud.comfy.org/api/experiment/models/loras"
|
||||
|
||||
def test_cloud_already_has_api(self):
|
||||
# Don't double-prefix
|
||||
assert resolve_url("https://cloud.comfy.org", "/api/prompt") == "https://cloud.comfy.org/api/prompt"
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
assert resolve_url("http://127.0.0.1:8188/", "/prompt") == "http://127.0.0.1:8188/prompt"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Workflow validation
|
||||
# =============================================================================
|
||||
|
||||
class TestAPIFormatDetection:
|
||||
def test_valid_api(self, sd15_workflow):
|
||||
assert is_api_format(sd15_workflow) is True
|
||||
|
||||
def test_editor_format_rejected(self):
|
||||
editor = {"nodes": [], "links": [], "version": 0.4}
|
||||
assert is_api_format(editor) is False
|
||||
|
||||
def test_empty_dict(self):
|
||||
assert is_api_format({}) is False
|
||||
|
||||
def test_non_dict(self):
|
||||
assert is_api_format([]) is False
|
||||
assert is_api_format(None) is False
|
||||
assert is_api_format("string") is False
|
||||
|
||||
def test_node_with_class_type(self):
|
||||
wf = {"3": {"class_type": "KSampler", "inputs": {}}}
|
||||
assert is_api_format(wf) is True
|
||||
|
||||
|
||||
class TestUnwrapWorkflow:
|
||||
def test_passthrough_api_format(self, sd15_workflow):
|
||||
result = unwrap_workflow(sd15_workflow)
|
||||
assert result is sd15_workflow
|
||||
|
||||
def test_unwrap_prompt_key(self, sd15_workflow):
|
||||
wrapped = {"prompt": sd15_workflow, "client_id": "abc"}
|
||||
result = unwrap_workflow(wrapped)
|
||||
assert result is sd15_workflow
|
||||
|
||||
def test_editor_format_raises(self):
|
||||
with pytest.raises(ValueError, match="editor format"):
|
||||
unwrap_workflow({"nodes": [], "links": []})
|
||||
|
||||
def test_garbage_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
unwrap_workflow({"foo": "bar"})
|
||||
|
||||
|
||||
class TestIsLink:
|
||||
def test_valid_link(self):
|
||||
assert is_link(["3", 0]) is True
|
||||
assert is_link(["10", 1]) is True
|
||||
|
||||
def test_non_link(self):
|
||||
assert is_link("string") is False
|
||||
assert is_link(42) is False
|
||||
assert is_link([]) is False
|
||||
assert is_link(["3"]) is False # missing slot
|
||||
assert is_link(["3", "0"]) is False # slot must be int
|
||||
assert is_link([3, 0]) is False # node_id must be string
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Workflow iterators
|
||||
# =============================================================================
|
||||
|
||||
class TestIterators:
|
||||
def test_iter_nodes(self, sd15_workflow):
|
||||
nodes = dict(iter_nodes(sd15_workflow))
|
||||
assert "3" in nodes
|
||||
assert nodes["3"]["class_type"] == "KSampler"
|
||||
|
||||
def test_iter_nodes_skips_comments(self, sd15_workflow):
|
||||
# _comment is not a node
|
||||
nodes = dict(iter_nodes(sd15_workflow))
|
||||
assert "_comment" not in nodes
|
||||
|
||||
def test_iter_model_deps(self, sd15_workflow):
|
||||
deps = list(iter_model_deps(sd15_workflow))
|
||||
names = [d["value"] for d in deps]
|
||||
assert "v1-5-pruned-emaonly.safetensors" in names
|
||||
|
||||
def test_iter_model_deps_flux(self, flux_workflow):
|
||||
deps = list(iter_model_deps(flux_workflow))
|
||||
names = {d["value"]: d["folder"] for d in deps}
|
||||
assert names["flux1-dev.safetensors"] == "unet"
|
||||
assert names["t5xxl_fp16.safetensors"] == "clip"
|
||||
assert names["clip_l.safetensors"] == "clip"
|
||||
assert names["ae.safetensors"] == "vae"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding extraction
|
||||
# =============================================================================
|
||||
|
||||
class TestEmbeddingRegex:
|
||||
def test_basic_embedding(self):
|
||||
m = EMBEDDING_REGEX.search("a cat, embedding:goodvibes, more text")
|
||||
assert m is not None
|
||||
assert m.group(1) == "goodvibes"
|
||||
|
||||
def test_embedding_with_strength(self):
|
||||
m = EMBEDDING_REGEX.search("embedding:bad-hands-5:1.2")
|
||||
assert m is not None
|
||||
assert m.group(1) == "bad-hands-5"
|
||||
|
||||
def test_embedding_with_extension(self):
|
||||
# Strips .pt / .safetensors / .bin
|
||||
m = EMBEDDING_REGEX.search("embedding:my-emb.pt")
|
||||
assert m is not None
|
||||
assert m.group(1) == "my-emb"
|
||||
|
||||
def test_embedding_in_parens(self):
|
||||
m = EMBEDDING_REGEX.search("(embedding:foo:0.8)")
|
||||
assert m is not None
|
||||
assert m.group(1) == "foo"
|
||||
|
||||
def test_multiple_in_one_string(self):
|
||||
text = "a cat, embedding:foo:1.2, and embedding:bar"
|
||||
matches = [m.group(1) for m in EMBEDDING_REGEX.finditer(text)]
|
||||
assert matches == ["foo", "bar"]
|
||||
|
||||
def test_no_false_positive_on_word_embedding(self):
|
||||
# "embedding " (with space, no colon) should not match
|
||||
m = EMBEDDING_REGEX.search("the embedding is great")
|
||||
assert m is None
|
||||
|
||||
|
||||
class TestIterEmbeddingRefs:
|
||||
def test_finds_in_clip_text_encode(self):
|
||||
wf = {
|
||||
"1": {"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "embedding:foo, embedding:bar:0.5", "clip": ["2", 0]}},
|
||||
"2": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
|
||||
}
|
||||
refs = list(iter_embedding_refs(wf))
|
||||
names = [name for _, name in refs]
|
||||
assert names == ["foo", "bar"]
|
||||
|
||||
def test_ignores_non_prompt_fields(self):
|
||||
wf = {
|
||||
"1": {"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": "embedding:foo.safetensors"}},
|
||||
}
|
||||
refs = list(iter_embedding_refs(wf))
|
||||
# ckpt_name is not a prompt field — ignored
|
||||
assert refs == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Path safety
|
||||
# =============================================================================
|
||||
|
||||
class TestSafePathJoin:
|
||||
def test_normal_join(self, tmp_path):
|
||||
p = safe_path_join(tmp_path, "subdir", "file.png")
|
||||
assert p.is_relative_to(tmp_path)
|
||||
|
||||
def test_blocks_traversal(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="path traversal"):
|
||||
safe_path_join(tmp_path, "..", "..", "etc", "passwd")
|
||||
|
||||
def test_blocks_absolute(self, tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
safe_path_join(tmp_path, "/etc/passwd")
|
||||
|
||||
def test_subfolder_with_filename(self, tmp_path):
|
||||
p = safe_path_join(tmp_path, "outputs", "img.png")
|
||||
assert p.name == "img.png"
|
||||
assert p.parent.name == "outputs"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Seed coercion
|
||||
# =============================================================================
|
||||
|
||||
class TestCoerceSeed:
|
||||
def test_explicit_int(self):
|
||||
assert coerce_seed(42) == 42
|
||||
assert coerce_seed(0) == 0
|
||||
|
||||
def test_minus_one_randomizes(self):
|
||||
s = coerce_seed(-1)
|
||||
assert isinstance(s, int)
|
||||
assert 0 <= s < 2**63
|
||||
|
||||
def test_none_randomizes(self):
|
||||
s = coerce_seed(None)
|
||||
assert isinstance(s, int)
|
||||
|
||||
def test_string_int(self):
|
||||
# str() that converts cleanly is allowed (relaxed)
|
||||
assert coerce_seed("12345") == 12345
|
||||
|
||||
def test_string_minus_one_randomizes(self):
|
||||
# CLI / JSON sometimes carries seed as a string.
|
||||
s = coerce_seed("-1")
|
||||
assert isinstance(s, int)
|
||||
assert 0 <= s < 2**63
|
||||
# And whitespace tolerated
|
||||
s2 = coerce_seed(" -1 ")
|
||||
assert isinstance(s2, int)
|
||||
assert 0 <= s2 < 2**63
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model list normalization (cloud format)
|
||||
# =============================================================================
|
||||
|
||||
class TestParseModelList:
|
||||
def test_local_format_strings(self):
|
||||
result = parse_model_list(["a.safetensors", "b.safetensors"])
|
||||
assert result == {"a.safetensors", "b.safetensors"}
|
||||
|
||||
def test_cloud_format_dicts(self):
|
||||
result = parse_model_list([
|
||||
{"name": "a.safetensors", "pathIndex": 0},
|
||||
{"name": "b.safetensors", "pathIndex": 1},
|
||||
])
|
||||
assert result == {"a.safetensors", "b.safetensors"}
|
||||
|
||||
def test_empty(self):
|
||||
assert parse_model_list([]) == set()
|
||||
|
||||
def test_garbage(self):
|
||||
assert parse_model_list("not a list") == set()
|
||||
assert parse_model_list(None) == set()
|
||||
|
||||
def test_mixed_format(self):
|
||||
result = parse_model_list([
|
||||
"string-form.safetensors",
|
||||
{"name": "dict-form.safetensors"},
|
||||
])
|
||||
assert result == {"string-form.safetensors", "dict-form.safetensors"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Folder aliases
|
||||
# =============================================================================
|
||||
|
||||
class TestFolderAliases:
|
||||
def test_unet_aliases_diffusion_models(self):
|
||||
aliases = folder_aliases_for("unet")
|
||||
assert "unet" in aliases
|
||||
assert "diffusion_models" in aliases
|
||||
|
||||
def test_clip_aliases_text_encoders(self):
|
||||
aliases = folder_aliases_for("clip")
|
||||
assert "clip" in aliases
|
||||
assert "text_encoders" in aliases
|
||||
|
||||
def test_unknown_folder_returns_self(self):
|
||||
assert folder_aliases_for("checkpoints") == ["checkpoints"]
|
||||
|
||||
def test_primary_first(self):
|
||||
# Order matters: primary should be first for human-friendly fix hints
|
||||
assert folder_aliases_for("unet")[0] == "unet"
|
||||
assert folder_aliases_for("diffusion_models")[0] == "diffusion_models"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Media-type detection
|
||||
# =============================================================================
|
||||
|
||||
class TestMediaType:
|
||||
def test_video_extensions(self):
|
||||
assert media_type_from_filename("vid.mp4") == "video"
|
||||
assert media_type_from_filename("foo.webm") == "video"
|
||||
assert media_type_from_filename("bar.gif") == "video"
|
||||
|
||||
def test_audio_extensions(self):
|
||||
assert media_type_from_filename("song.wav") == "audio"
|
||||
assert media_type_from_filename("music.mp3") == "audio"
|
||||
|
||||
def test_image_default(self):
|
||||
assert media_type_from_filename("pic.png") == "image"
|
||||
assert media_type_from_filename("image.jpg") == "image"
|
||||
assert media_type_from_filename("unknown.xyz") == "image"
|
||||
|
||||
def test_3d(self):
|
||||
assert media_type_from_filename("model.glb") == "3d"
|
||||
assert media_type_from_filename("scene.gltf") == "3d"
|
||||
|
||||
|
||||
# =============================================================================
# Cross-host header stripping (security)
# =============================================================================


class TestRedirectHeaderStripping:
    """Verify X-API-Key is dropped when a redirect crosses to a different host
    (e.g. cloud /api/view → S3 signed URL). Critical to prevent leaking auth
    tokens to the storage backend.
    """

    def _build_session(self):
        from _common import _StripSensitiveOnRedirectSession, HAS_REQUESTS
        if not HAS_REQUESTS:
            import pytest
            pytest.skip("requests not installed")
        return _StripSensitiveOnRedirectSession()

    def test_strips_x_api_key_cross_host(self):
        import requests
        s = self._build_session()
        prep = requests.PreparedRequest()
        prep.prepare(method="GET", url="https://other.example.com/file",
                     headers={"X-API-Key": "leak", "Authorization": "Bearer x"})
        resp = requests.Response()
        orig = requests.PreparedRequest()
        orig.prepare(method="GET", url="https://cloud.comfy.org/api/view", headers={})
        resp.request = orig
        s.rebuild_auth(prep, resp)
        assert "X-API-Key" not in prep.headers
        assert "Authorization" not in prep.headers

    def test_preserves_x_api_key_same_host(self):
        import requests
        s = self._build_session()
        prep = requests.PreparedRequest()
        prep.prepare(method="GET", url="https://cloud.comfy.org/foo",
                     headers={"X-API-Key": "keep"})
        resp = requests.Response()
        orig = requests.PreparedRequest()
        orig.prepare(method="GET", url="https://cloud.comfy.org/bar", headers={})
        resp.request = orig
        s.rebuild_auth(prep, resp)
        assert prep.headers.get("X-API-Key") == "keep"

    def test_strips_cookie_cross_host(self):
        import requests
        s = self._build_session()
        prep = requests.PreparedRequest()
        prep.prepare(method="GET", url="https://other.example.com/x",
                     headers={"Cookie": "session=secret"})
        resp = requests.Response()
        orig = requests.PreparedRequest()
        orig.prepare(method="GET", url="https://cloud.comfy.org/foo", headers={})
        resp.request = orig
        s.rebuild_auth(prep, resp)
        assert "Cookie" not in prep.headers
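The stripping rule these tests pin down can be sketched as a pure function (stdlib only; the function name and header tuple are illustrative, not the shipped `_common.py` API): auth material survives same-host redirects and is dropped as soon as the hostname changes.

```python
from urllib.parse import urlparse

# Headers that must never follow a redirect to a different host.
SENSITIVE_HEADERS = ("X-API-Key", "Authorization", "Cookie")


def strip_sensitive_on_cross_host(headers: dict, original_url: str, redirect_url: str) -> dict:
    """Return a copy of `headers`, minus auth material when the redirect
    target's hostname differs from the original request's hostname."""
    if urlparse(original_url).hostname == urlparse(redirect_url).hostname:
        return dict(headers)
    return {k: v for k, v in headers.items() if k not in SENSITIVE_HEADERS}
```

In the real session subclass this logic lives in `rebuild_auth`, which requests invokes on every redirect hop.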


# =============================================================================
# Video workflow detection
# =============================================================================


class TestVideoWorkflow:
    def test_image_workflow(self, sd15_workflow):
        assert looks_like_video_workflow(sd15_workflow) is False

    def test_animatediff_workflow(self, workflows_dir):
        import json
        wf = json.loads((workflows_dir / "animatediff_video.json").read_text())
        assert looks_like_video_workflow(wf) is True

    def test_wan_workflow(self, video_workflow):
        assert looks_like_video_workflow(video_workflow) is True

skills/creative/comfyui/tests/test_extract_schema.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Tests for extract_schema.py."""

from __future__ import annotations

import pytest

from extract_schema import (
    extract_schema,
    find_negative_prompt_node,
    find_positive_prompt_node,
    trace_to_node,
)


# =============================================================================
# Connection tracing
# =============================================================================


class TestConnectionTracing:
    def test_direct_link(self):
        wf = {
            "1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}},
            "2": {"class_type": "KSampler",
                  "inputs": {"positive": ["1", 0], "negative": ["1", 0]}},
        }
        assert trace_to_node(wf, ["1", 0]) == "1"

    def test_through_reroute(self):
        wf = {
            "1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}},
            "2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}},
        }
        assert trace_to_node(wf, ["3", 0]) == "1"

    def test_circular_safe(self):
        wf = {
            "1": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}},
            "2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}},
        }
        # Should hit max_hops without infinite loop
        result = trace_to_node(wf, ["1", 0], max_hops=5)
        assert result in ("1", "2")  # any node, just don't hang
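A hop-bounded trace like the one exercised above can be sketched as follows. This is a sketch under the assumption that passthrough nodes (Reroute, Primitive) expose exactly one link-valued input; it is not the shipped `extract_schema.trace_to_node`.

```python
PASSTHROUGH_TYPES = {"Reroute", "PrimitiveNode"}  # assumed passthrough classes


def trace_to_node_sketch(workflow: dict, link, max_hops: int = 32) -> str:
    """Follow a [node_id, output_index] link upstream through passthrough
    nodes, returning the first real node id, or the last one visited once
    the hop budget runs out (which is what keeps cycles from hanging)."""
    node_id = str(link[0])
    for _ in range(max_hops):
        node = workflow.get(node_id)
        if node is None or node.get("class_type") not in PASSTHROUGH_TYPES:
            return node_id
        upstream = next((v for v in node.get("inputs", {}).values()
                         if isinstance(v, list) and len(v) == 2), None)
        if upstream is None:
            return node_id
        node_id = str(upstream[0])
    return node_id  # hop budget exhausted (cycle)
```

The hop budget is the whole cycle-safety story: no visited-set is needed because the loop simply cannot run forever.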


class TestPositiveNegativeDetection:
    def test_basic(self, sd15_workflow):
        # In sd15_workflow.json node 6 is positive, node 7 is negative
        assert find_positive_prompt_node(sd15_workflow) == "6"
        assert find_negative_prompt_node(sd15_workflow) == "7"

    def test_swapped_order(self):
        wf = {
            "3": {"class_type": "KSampler",
                  "inputs": {
                      "positive": ["7", 0], "negative": ["6", 0],
                      "model": ["4", 0], "latent_image": ["5", 0],
                      "seed": 1, "steps": 20, "cfg": 7.5,
                      "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
                  }},
            "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
            "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
            "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "ugly", "clip": ["4", 1]}},
            "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "beautiful", "clip": ["4", 1]}},
        }
        # Now 7 is the positive (despite higher node ID)
        assert find_positive_prompt_node(wf) == "7"
        assert find_negative_prompt_node(wf) == "6"


# =============================================================================
# Schema extraction
# =============================================================================


class TestExtractSchema:
    def test_basic_sd15(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        params = schema["parameters"]
        assert "prompt" in params
        assert "negative_prompt" in params
        assert "seed" in params
        assert "steps" in params
        assert "cfg" in params
        assert "width" in params
        assert "height" in params

    def test_prompt_value_correct(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        # The positive prompt in the example is the landscape one
        assert "landscape" in schema["parameters"]["prompt"]["value"]
        assert "ugly" in schema["parameters"]["negative_prompt"]["value"]

    def test_model_dependencies(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        deps = schema["model_dependencies"]
        ckpts = [d["value"] for d in deps if d["folder"] == "checkpoints"]
        assert "v1-5-pruned-emaonly.safetensors" in ckpts

    def test_output_nodes(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        assert "9" in schema["output_nodes"]

    def test_summary(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        s = schema["summary"]
        assert s["has_negative_prompt"] is True
        assert s["has_seed"] is True
        assert s["is_video_workflow"] is False
        assert s["parameter_count"] > 5

    def test_flux_workflow(self, flux_workflow):
        schema = extract_schema(flux_workflow)
        # Flux uses RandomNoise for seed
        assert schema["summary"]["has_seed"] is True
        # Flux has only a positive prompt (no negative encoder)
        assert schema["summary"]["has_negative_prompt"] is False

    def test_video_detected(self, video_workflow):
        schema = extract_schema(video_workflow)
        assert schema["summary"]["is_video_workflow"] is True


class TestEmbeddingDeps:
    def test_extract_from_prompt(self):
        wf = {
            "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
            "5": {"class_type": "EmptyLatentImage",
                  "inputs": {"width": 512, "height": 512, "batch_size": 1}},
            "6": {"class_type": "CLIPTextEncode",
                  "inputs": {
                      "text": "a cat, embedding:goodvibes, embedding:art:1.2",
                      "clip": ["1", 1]
                  }},
            "7": {"class_type": "CLIPTextEncode",
                  "inputs": {
                      "text": "ugly, embedding:badhands",
                      "clip": ["1", 1]
                  }},
            "3": {"class_type": "KSampler",
                  "inputs": {
                      "positive": ["6", 0], "negative": ["7", 0],
                      "model": ["1", 0], "latent_image": ["5", 0],
                      "seed": 1, "steps": 20, "cfg": 7.5,
                      "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
                  }},
            "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
        }
        schema = extract_schema(wf)
        names = [d["embedding_name"] for d in schema["embedding_dependencies"]]
        assert sorted(names) == ["art", "badhands", "goodvibes"]
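The `embedding:name` and `embedding:name:weight` forms in those prompts can be matched with a single pattern. The regex below is illustrative, not necessarily the one `extract_schema.py` ships:

```python
import re

# Name runs until whitespace, comma, or the optional ":weight" suffix.
EMBEDDING_RE = re.compile(r"embedding:([^\s,:]+)(?::[\d.]+)?")


def find_embeddings(prompt: str) -> list:
    """Return embedding names referenced in a prompt string."""
    return EMBEDDING_RE.findall(prompt)
```

The optional non-capturing `(?::[\d.]+)?` is what keeps a weight like `:1.2` out of the captured name.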


class TestDuplicateDeduplication:
    def test_two_ksamplers_get_unique_names(self):
        wf = {
            "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
            "5": {"class_type": "EmptyLatentImage",
                  "inputs": {"width": 512, "height": 512, "batch_size": 1}},
            "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "a", "clip": ["1", 1]}},
            "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "b", "clip": ["1", 1]}},
            "3": {"class_type": "KSampler",
                  "inputs": {
                      "positive": ["6", 0], "negative": ["7", 0],
                      "model": ["1", 0], "latent_image": ["5", 0],
                      "seed": 42, "steps": 20, "cfg": 7.5,
                      "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
                  }},
            "4": {"class_type": "KSampler",
                  "inputs": {
                      "positive": ["6", 0], "negative": ["7", 0],
                      "model": ["1", 0], "latent_image": ["5", 0],
                      "seed": 99, "steps": 30, "cfg": 8.0,
                      "sampler_name": "euler", "scheduler": "normal", "denoise": 0.6,
                  }},
            "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
        }
        schema = extract_schema(wf)
        params = schema["parameters"]
        # Both seeds present with disambiguated names
        seed_keys = sorted(k for k in params if "seed" in k)
        # Symmetric: both renamed (no bare "seed")
        assert "seed" not in params
        assert seed_keys == ["seed_3", "seed_4"]
        assert params["seed_3"]["value"] == 42
        assert params["seed_4"]["value"] == 99

skills/creative/comfyui/tests/test_run_workflow.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""Tests for run_workflow.py — focuses on logic that doesn't require a server."""

from __future__ import annotations

import copy
import json

import pytest

from extract_schema import extract_schema
from run_workflow import (
    ComfyRunner,
    download_outputs,
    inject_params,
    parse_input_image_arg,
)


class TestParseInputImageArg:
    def test_with_name(self, tmp_path):
        f = tmp_path / "x.png"
        f.write_text("x")
        n, p = parse_input_image_arg(f"image={f}")
        assert n == "image"
        assert p == f

    def test_without_name_defaults(self, tmp_path):
        f = tmp_path / "x.png"
        f.write_text("x")
        n, p = parse_input_image_arg(str(f))
        assert n == "image"

    def test_custom_name(self, tmp_path):
        f = tmp_path / "x.png"
        f.write_text("x")
        n, p = parse_input_image_arg(f"mask_image={f}")
        assert n == "mask_image"


class TestInjectParams:
    def test_basic_injection(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        wf, warnings = inject_params(sd15_workflow, schema, {
            "prompt": "new prompt",
            "seed": 999,
            "steps": 25,
        })
        assert wf["6"]["inputs"]["text"] == "new prompt"
        assert wf["3"]["inputs"]["seed"] == 999
        assert wf["3"]["inputs"]["steps"] == 25
        assert warnings == []

    def test_unknown_param_warns(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        _, warnings = inject_params(sd15_workflow, schema, {"foobar": "x"})
        assert any("foobar" in w for w in warnings)

    def test_seed_minus_one_randomizes(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        wf, warnings = inject_params(sd15_workflow, schema, {"seed": -1})
        assert wf["3"]["inputs"]["seed"] != -1
        assert isinstance(wf["3"]["inputs"]["seed"], int)
        assert any("expanded" in w.lower() for w in warnings)

    def test_randomize_seed_when_unset(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        original = sd15_workflow["3"]["inputs"]["seed"]
        wf, warnings = inject_params(sd15_workflow, schema, {}, randomize_seed_if_unset=True)
        assert wf["3"]["inputs"]["seed"] != original
        assert isinstance(wf["3"]["inputs"]["seed"], int)

    def test_does_not_mutate_original(self, sd15_workflow):
        schema = extract_schema(sd15_workflow)
        original_text = sd15_workflow["6"]["inputs"]["text"]
        inject_params(sd15_workflow, schema, {"prompt": "MUTATED"})
        assert sd15_workflow["6"]["inputs"]["text"] == original_text

    def test_refuses_to_overwrite_link(self):
        wf = {
            "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
            "5": {"class_type": "EmptyLatentImage",
                  "inputs": {"width": 512, "height": 512, "batch_size": 1}},
            "6": {"class_type": "CLIPTextEncode",
                  "inputs": {"text": ["3", 0], "clip": ["1", 1]}},  # text is a link!
            "3": {"class_type": "KSampler",
                  "inputs": {"seed": 1, "steps": 20, "cfg": 7.5,
                             "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
                             "model": ["1", 0], "positive": ["6", 0], "negative": ["6", 0],
                             "latent_image": ["5", 0]}},
            "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
        }
        # Manually create a schema that has prompt pointing at 6.text
        schema = {
            "parameters": {
                "prompt": {"node_id": "6", "field": "text", "type": "string", "value": ""},
            }
        }
        wf2, warnings = inject_params(wf, schema, {"prompt": "literal value"})
        # The link should NOT have been overwritten
        assert wf2["6"]["inputs"]["text"] == ["3", 0]
        assert any("link" in w.lower() for w in warnings)
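The guard this test exercises boils down to recognising a link value before writing a literal over it. A minimal sketch follows; the helper names and the assumed shape of a ComfyUI link (a two-element `[node_id, output_index]` list) are illustrative, and the real `inject_params` may check more:

```python
def is_link(value) -> bool:
    """A ComfyUI link is a two-element [node_id, output_index] list."""
    return (isinstance(value, list) and len(value) == 2
            and isinstance(value[1], int))


def safe_set(node_inputs: dict, field: str, literal, warnings: list) -> None:
    """Write `literal` into `node_inputs[field]` only if the field is not
    wired to another node; otherwise record a warning and leave it alone."""
    if is_link(node_inputs.get(field)):
        warnings.append(f"refusing to overwrite link in field {field!r}")
        return
    node_inputs[field] = literal
```

Overwriting a link with a string would silently sever the upstream connection, which is exactly the failure mode the warning surfaces.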


# =============================================================================
# Output download walk
# =============================================================================


class TestDownloadOutputsWalk:
    """Test that download_outputs walks the structure correctly."""

    def test_handles_videos_plural(self, tmp_path, monkeypatch):
        """Local ComfyUI uses 'videos'/'gifs' (plural) keys."""
        downloads = []

        class FakeRunner:
            def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
                downloads.append((filename, subfolder, file_type))
                p = output_dir / filename
                p.parent.mkdir(parents=True, exist_ok=True)
                p.write_bytes(b"x")
                return p

        outputs = {
            "9": {"images": [{"filename": "img1.png", "subfolder": "", "type": "output"}]},
            "10": {"videos": [{"filename": "vid1.mp4", "subfolder": "", "type": "output"}]},
            "11": {"gifs": [{"filename": "anim1.gif", "subfolder": "", "type": "output"}]},
        }

        result = download_outputs(FakeRunner(), outputs, tmp_path)
        files = sorted(d["filename"] for d in result)
        assert files == ["anim1.gif", "img1.png", "vid1.mp4"]

    def test_handles_video_singular_cloud(self, tmp_path):
        """Cloud uses 'video' (singular)."""
        class FakeRunner:
            def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
                p = output_dir / filename
                p.parent.mkdir(parents=True, exist_ok=True)
                p.write_bytes(b"x")
                return p

        outputs = {
            "10": {"video": [{"filename": "cloud.mp4", "subfolder": "", "type": "output"}]},
        }
        result = download_outputs(FakeRunner(), outputs, tmp_path)
        assert len(result) == 1
        assert result[0]["filename"] == "cloud.mp4"

    def test_preserves_subfolder(self, tmp_path):
        """When preserve_subfolder=True, server subfolder becomes local subdir."""
        class FakeRunner:
            def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
                if preserve_subfolder and subfolder:
                    p = output_dir / subfolder / filename
                else:
                    p = output_dir / filename
                p.parent.mkdir(parents=True, exist_ok=True)
                p.write_bytes(b"x")
                return p

        outputs = {
            "9": {"images": [
                {"filename": "img.png", "subfolder": "myrun", "type": "output"},
                {"filename": "img.png", "subfolder": "otherrun", "type": "output"},
            ]},
        }
        result = download_outputs(FakeRunner(), outputs, tmp_path, preserve_subfolder=True)
        files = [d["file"] for d in result]
        assert any("myrun" in f for f in files)
        assert any("otherrun" in f for f in files)
        # Both must exist (no collision)
        assert len({str(f) for f in files}) == 2
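Preserving the server subfolder safely means joining it under the output directory while refusing path traversal. Below is a sketch along the lines of the `safe_path_join` helper named in the changelog; the exact signature is an assumption:

```python
from pathlib import Path


def safe_path_join(base: Path, *parts: str) -> Path:
    """Join `parts` under `base`, rejecting components that would escape it
    (e.g. a malicious server-supplied subfolder of "../../etc")."""
    base = base.resolve()
    candidate = base.joinpath(*parts).resolve()
    if candidate != base and base not in candidate.parents:
        raise ValueError(f"unsafe path components: {parts!r}")
    return candidate
```

Resolving both paths before the containment check is what defeats `..` tricks, since `resolve()` collapses them lexically even for paths that do not exist yet.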


# =============================================================================
# ComfyRunner construction
# =============================================================================


class TestRunnerConstruction:
    def test_local_default(self):
        r = ComfyRunner()
        assert r.is_cloud is False
        assert r.host == "http://127.0.0.1:8188"

    def test_cloud_detection(self):
        r = ComfyRunner(host="https://cloud.comfy.org", api_key="abc")
        assert r.is_cloud is True
        assert "X-API-Key" in r.headers

    def test_cloud_subdomain_detected(self):
        r = ComfyRunner(host="https://staging.cloud.comfy.org", api_key="abc")
        assert r.is_cloud is True

    def test_partner_key_does_not_pollute_extra_data(self):
        r = ComfyRunner(host="https://cloud.comfy.org", api_key="auth-key")
        # No partner-key set → no extra_data should appear in submitted prompt
        # (This is a static check; runtime check happens in submit())
        assert r.partner_key is None

    def test_url_routing_local(self):
        r = ComfyRunner()
        url = r._url("/prompt")
        assert url == "http://127.0.0.1:8188/prompt"

    def test_url_routing_cloud(self):
        r = ComfyRunner(host="https://cloud.comfy.org", api_key="x")
        url = r._url("/prompt")
        assert url == "https://cloud.comfy.org/api/prompt"

    def test_url_routing_cloud_history_renamed(self):
        r = ComfyRunner(host="https://cloud.comfy.org", api_key="x")
        url = r._url("/history/abc-123")
        assert url == "https://cloud.comfy.org/api/history_v2/abc-123"
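The rename behind the last test can be sketched as a small route table. The `/history → /history_v2` mapping comes from the changelog above; the prefix logic and names here are illustrative, not the shipped `_url`:

```python
# Cloud renamed a few endpoints relative to local ComfyUI.
CLOUD_ROUTE_RENAMES = {"/history": "/history_v2"}


def cloud_url(host: str, path: str) -> str:
    """Map a local-style API path onto the Cloud host: apply any endpoint
    rename, then mount the result under the /api prefix."""
    for old, new in CLOUD_ROUTE_RENAMES.items():
        if path == old or path.startswith(old + "/"):
            path = new + path[len(old):]
            break
    return host.rstrip("/") + "/api" + path
```

Checking `old + "/"` (rather than a bare prefix match) keeps a hypothetical `/historyx` route from being rewritten by accident.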