fix(skills/comfyui): bug fixes, cloud parity, expanded coverage, examples, tests

The audit of v4.1 surfaced ~70 issues across the five scripts and three
reference docs — most user-visible (silent file overwrites, status-error
misclassified as success, X-API-Key leaked to S3 on /api/view redirect,
Cloud endpoints that 404 because they were renamed). v5.0.0 fixes those
and fills the gaps that previously forced users to write their own glue
(WebSocket monitoring, batch/sweep, img2img upload helper, dep auto-fix,
log fetch, health check, example workflows).

Critical fixes
- run_workflow.py: poll_status now checks status_str==error BEFORE
  completed:true, so a failed run no longer reports success
- run_workflow.py: download_output streams to disk via safe_path_join,
  preserves server subfolder structure (no silent overwrites), and
  retries with exponential backoff
- run_workflow.py: refuses to overwrite a link with a literal in
  inject_params (would silently break wiring)
- _common.py: _StripSensitiveOnRedirectSession (subclasses
  requests.Session.rebuild_auth) drops X-API-Key/Cookie on cross-host
  redirects — fixes a real key-leak path through Cloud's signed-URL
  download flow. Tested
- Cloud routing (verified live): /history → /history_v2,
  /models/<f> → /experiment/models/<f>, plus folder aliases for the
  unet ↔ diffusion_models and clip ↔ text_encoders rename
- check_deps.py: distinguishes 200/empty vs 404 folder_not_found vs
  403 free-tier; emits concrete fix_command per missing dep
- extract_schema.py: prompt vs negative_prompt determined by tracing
  KSampler.{positive,negative} connections (incl. through Reroute /
  Primitive nodes) instead of meta-title heuristic; symmetric
  duplicate-name resolution; cycle-safe trace_to_node
- hardware_check.py: multi-GPU pick-best, Apple variant detection,
  Rosetta detection, WSL2, ROCm --json, disk-space check, optional
  PyTorch probe; powershell preferred over deprecated wmic
- comfyui_setup.sh: prefers pipx → uvx → pip --user (with PEP-668
  fallback); idempotent — skips relaunch if server already up;
  configurable port/workspace; persistent log; SIGINT trap

New scripts
- run_batch.py — count or sweep (cartesian product), parallel up to
  cloud tier limit
- ws_monitor.py — real-time WebSocket viewer; saves preview frames
- auto_fix_deps.py — runs comfy node install / model download for
  whatever check_deps reports missing (with --dry-run)
- health_check.py — single command that runs the verification checklist
  (comfy-cli + server + checkpoints + optional smoke test that cancels
  itself to avoid burning compute)
- fetch_logs.py — pull traceback / status messages for a prompt_id

Coverage expansion
- Param patterns now cover Flux (BasicScheduler, BasicGuider,
  RandomNoise, ModelSamplingFlux), SD3, Wan/Hunyuan/LTX video,
  IPAdapter, rgthree, easy-use, AnimateDiff
- Embedding refs in CLIPTextEncode strings extracted as model deps
- ckpt_name / vae_name / lora_name / unet_name now controllable so
  workflows can be retargeted per run

Examples
- workflows/{sd15,sdxl,flux_dev}_txt2img.json
- workflows/sdxl_{img2img,inpaint}.json
- workflows/upscale_4x.json
- workflows/{animatediff_video,wan_video_t2v}.json + README

Tests
- 117 tests (105 unit + 8 cloud integration + 4 cross-host security)
- Cloud tests auto-skip without COMFY_CLOUD_API_KEY; verified end-to-end
  against live cloud API

Backwards compatibility
- All existing CLI flags continue to work; new behavior is opt-in
  (--ws, --input-image, --randomize-seed, --flat-output, etc.)
This commit is contained in:
SHL0MS
2026-04-29 20:50:52 -04:00
committed by Teknium
parent 7d48a16f14
commit a7780fe05f
32 changed files with 6117 additions and 1372 deletions

View File

@@ -0,0 +1,50 @@
# ComfyUI Skill Tests
Pytest suite covering the skill's scripts. Pure-stdlib unit tests run
without any setup; cloud integration tests need a Comfy Cloud API key.
## Running
```bash
# Unit tests only (no network required) — runs in <1s
python3 -m pytest tests/ -c tests/pytest.ini -o addopts="-p no:xdist"
# Including cloud integration tests
COMFY_CLOUD_API_KEY="comfyui-..." python3 -m pytest tests/ \
-c tests/pytest.ini -o addopts="-p no:xdist"
# Just cloud tests
COMFY_CLOUD_API_KEY="comfyui-..." python3 -m pytest tests/test_cloud_integration.py \
-c tests/pytest.ini -o addopts="-p no:xdist" -v
```
The `-c` and `-o` overrides isolate this suite from any parent
`pyproject.toml` pytest config (e.g. the `-n auto` from a parent repo).
## Test files
| File | Coverage |
|------|----------|
| `test_common.py` | Cloud detection, URL routing, format validation, embeddings, paths, seeds, model-list parsing, folder aliases |
| `test_extract_schema.py` | Connection tracing, positive/negative prompt detection, dedup logic, embedding deps |
| `test_run_workflow.py` | Param injection (incl. -1 seed, link refusal), output download walk, runner construction |
| `test_check_deps.py` | Model-name fuzzy matching, install command suggestions |
| `test_cloud_integration.py` | Live cloud API contract tests (auto-skipped without API key) |
## Adding tests
When you change a script:
1. Add a unit test if the change is pure logic (cloud detection, parsing, etc.)
2. Add a cloud integration test if the change depends on cloud API behavior
(use `pytestmark = pytest.mark.cloud` so it auto-skips without a key)
3. Workflow fixtures live in `conftest.py` (`sd15_workflow`, `flux_workflow`,
`video_workflow`)
## Why the explicit `-c` / `-o`?
The parent hermes-agent repo's `pyproject.toml` enables `pytest-xdist` by
default (`-n auto`). This suite is small enough that parallelism isn't
worth the complexity, and pytest-xdist isn't always installed in the user's
environment. The `-c tests/pytest.ini -o addopts="-p no:xdist"` flags make
the suite run identically regardless of the parent project's config.

View File

@@ -0,0 +1,64 @@
"""Pytest configuration for the comfyui skill test suite.
Adds `scripts/` to sys.path so tests can `from _common import ...`, and
provides a few common fixtures.
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parent.parent
SCRIPTS = ROOT / "scripts"
WORKFLOWS = ROOT / "workflows"
sys.path.insert(0, str(SCRIPTS))
@pytest.fixture
def sd15_workflow() -> dict:
return json.loads((WORKFLOWS / "sd15_txt2img.json").read_text())
@pytest.fixture
def flux_workflow() -> dict:
return json.loads((WORKFLOWS / "flux_dev_txt2img.json").read_text())
@pytest.fixture
def video_workflow() -> dict:
return json.loads((WORKFLOWS / "wan_video_t2v.json").read_text())
@pytest.fixture
def workflows_dir() -> Path:
return WORKFLOWS
@pytest.fixture
def scripts_dir() -> Path:
return SCRIPTS
@pytest.fixture
def cloud_key() -> str | None:
"""Cloud API key if set, otherwise None.
Tests that need cloud connectivity should skip when this is None.
"""
return os.environ.get("COMFY_CLOUD_API_KEY")
def pytest_collection_modifyitems(config, items):
"""Auto-skip cloud tests when no API key is set."""
if os.environ.get("COMFY_CLOUD_API_KEY"):
return
skip_cloud = pytest.mark.skip(reason="Set COMFY_CLOUD_API_KEY to run cloud tests")
for item in items:
if "cloud" in item.keywords:
item.add_marker(skip_cloud)

View File

@@ -0,0 +1,5 @@
[pytest]
markers =
cloud: tests that hit live Comfy Cloud API (require COMFY_CLOUD_API_KEY)
testpaths = .
addopts = -p no:xdist

View File

@@ -0,0 +1,65 @@
"""Tests for check_deps.py — focuses on parsing logic that doesn't need a server."""
from __future__ import annotations
from check_deps import (
NODE_TO_PACKAGE,
model_present,
normalize_for_match,
suggest_install_command,
)
class TestNormalizeForMatch:
def test_basic(self):
s = normalize_for_match("model.safetensors")
assert "model.safetensors" in s
assert "model" in s
def test_subfolder(self):
s = normalize_for_match("subdir/model.pt")
assert "subdir/model.pt" in s
assert "model.pt" in s
assert "model" in s
class TestModelPresent:
def test_exact_match(self):
assert model_present("a.safetensors", {"a.safetensors", "b.safetensors"}) is True
def test_extension_difference(self):
# User said "model" but installed is "model.safetensors"
assert model_present("model", {"model.safetensors"}) is True
# Reverse direction — also matches
assert model_present("model.safetensors", {"model"}) is True
def test_subfolder_match(self):
# Installed list has "subdir/model.safetensors", workflow asks "model.safetensors"
assert model_present("model.safetensors", {"subdir/model.safetensors"}) is True
def test_missing(self):
assert model_present("missing.safetensors", {"a.safetensors", "b.safetensors"}) is False
def test_empty_installed(self):
assert model_present("anything.safetensors", set()) is False
class TestSuggestInstallCommand:
def test_known_node(self):
cmd = suggest_install_command("VHS_VideoCombine")
assert cmd == "comfy node install comfyui-videohelpersuite"
def test_unknown_node(self):
assert suggest_install_command("SomeRandomNodeName123") is None
class TestNodePackageMap:
def test_no_duplicates(self):
# Each node should map to exactly one package
keys = list(NODE_TO_PACKAGE.keys())
assert len(keys) == len(set(keys))
def test_all_lowercase_packages(self):
# Convention: package names are lowercase with hyphens/underscores
for pkg in NODE_TO_PACKAGE.values():
assert pkg.lower() == pkg, f"Package name should be lowercase: {pkg}"

View File

@@ -0,0 +1,95 @@
"""Integration tests against the live Comfy Cloud API.
These tests are auto-skipped when COMFY_CLOUD_API_KEY is not set.
They never SUBMIT workflows (would need a paid subscription) — they only
verify the read-only endpoints we rely on.
"""
from __future__ import annotations
import pytest
from _common import http_get, parse_model_list, resolve_url
pytestmark = pytest.mark.cloud
class TestCloudEndpointsLive:
def test_system_stats_reachable(self, cloud_key):
url = resolve_url("https://cloud.comfy.org", "/system_stats")
r = http_get(url, headers={"X-API-Key": cloud_key})
assert r.status == 200
data = r.json()
assert "system" in data
def test_models_endpoint_routed_to_experiment(self, cloud_key):
# We expect the skill to route /models/checkpoints → /api/experiment/models/checkpoints
url = resolve_url("https://cloud.comfy.org", "/models/checkpoints")
assert "/api/experiment/models/checkpoints" in url
r = http_get(url, headers={"X-API-Key": cloud_key})
assert r.status == 200
def test_models_endpoint_returns_dicts(self, cloud_key):
url = resolve_url("https://cloud.comfy.org", "/models/checkpoints")
r = http_get(url, headers={"X-API-Key": cloud_key})
data = r.json()
assert isinstance(data, list)
if data:
# Cloud format: list of dicts with `name`
assert isinstance(data[0], dict)
assert "name" in data[0]
# Our parser normalizes both
normalized = parse_model_list(data)
assert len(normalized) == len(data)
def test_history_renamed_to_v2(self, cloud_key):
# /history → /api/history_v2 on cloud
url = resolve_url("https://cloud.comfy.org", "/history/some-fake-id")
assert "/api/history_v2/some-fake-id" in url
def test_object_info_paid_tier(self, cloud_key):
# On free tier, /object_info returns 403 with a recognizable message
url = resolve_url("https://cloud.comfy.org", "/object_info")
r = http_get(url, headers={"X-API-Key": cloud_key})
# Should be either 200 (paid) or 403 (free) — not 404 / 500
assert r.status in (200, 403)
if r.status == 403:
# Body should mention the limitation
assert "free tier" in r.text().lower() or "subscription" in r.text().lower()
class TestCloudCheckDepsLive:
def test_check_deps_against_cloud(self, cloud_key, sd15_workflow):
from check_deps import check_deps
report = check_deps(sd15_workflow, host="https://cloud.comfy.org", api_key=cloud_key)
# Either node check passed OR was skipped (free tier)
assert "missing_models" in report
assert "is_cloud" in report and report["is_cloud"] is True
def test_flux_workflow_models_resolved_via_aliases(self, cloud_key, flux_workflow):
"""Flux uses unet/clip folders; cloud has them in diffusion_models/text_encoders.
With folder aliasing, the check should still find them."""
from check_deps import check_deps
report = check_deps(flux_workflow, host="https://cloud.comfy.org", api_key=cloud_key)
# The exact required Flux files (flux1-dev.safetensors, t5xxl_fp16, clip_l, ae)
# are present on cloud; with folder aliasing, none should be missing.
# If this fails, either the cloud removed the model or the aliasing logic broke.
missing_filenames = {m["value"] for m in report["missing_models"]}
assert "ae.safetensors" not in missing_filenames, \
"ae.safetensors should be on cloud's vae folder"
# t5xxl_fp16 / clip_l should be reachable via the clip → text_encoders alias
# flux1-dev.safetensors likewise via unet → diffusion_models
class TestHealthCheckLive:
def test_health_check_passes(self, cloud_key, capsys):
from health_check import main as health_main
rc = health_main(["--host", "https://cloud.comfy.org", "--api-key", cloud_key])
captured = capsys.readouterr()
# Should produce JSON
import json
report = json.loads(captured.out)
assert report["server"]["reachable"] is True
assert report["checkpoints"]["queryable"] is True
assert report["checkpoints"]["count"] > 0

View File

@@ -0,0 +1,447 @@
"""Unit tests for _common.py — pure logic only, no network."""
from __future__ import annotations
from pathlib import Path
import pytest
from _common import (
DEFAULT_LOCAL_HOST,
EMBEDDING_REGEX,
FOLDER_ALIASES,
build_cloud_aware_url,
cloud_endpoint,
coerce_seed,
folder_aliases_for,
is_api_format,
is_cloud_host,
is_link,
iter_embedding_refs,
iter_model_deps,
iter_nodes,
looks_like_video_workflow,
media_type_from_filename,
parse_model_list,
resolve_url,
safe_path_join,
unwrap_workflow,
)
# =============================================================================
# Cloud detection / URL routing
# =============================================================================
class TestCloudDetection:
def test_cloud_host_exact(self):
assert is_cloud_host("https://cloud.comfy.org") is True
assert is_cloud_host("https://cloud.comfy.org/foo/bar") is True
def test_cloud_host_subdomain(self):
assert is_cloud_host("https://staging.cloud.comfy.org") is True
assert is_cloud_host("https://api.cloud.comfy.org") is True
def test_local_not_cloud(self):
assert is_cloud_host("http://127.0.0.1:8188") is False
assert is_cloud_host("http://localhost:8188") is False
assert is_cloud_host("http://my-server.local:8188") is False
def test_no_scheme(self):
# Defaults to http://
assert is_cloud_host("cloud.comfy.org") is True
assert is_cloud_host("127.0.0.1:8188") is False
class TestCloudEndpointRename:
def test_history_renamed(self):
assert cloud_endpoint("/history") == "/history_v2"
assert cloud_endpoint("/history/abc-123") == "/history_v2/abc-123"
def test_history_v2_preserved(self):
assert cloud_endpoint("/history_v2") == "/history_v2"
def test_models_renamed(self):
assert cloud_endpoint("/models") == "/experiment/models"
assert cloud_endpoint("/models/checkpoints") == "/experiment/models/checkpoints"
assert cloud_endpoint("/models/loras") == "/experiment/models/loras"
def test_other_paths_unchanged(self):
assert cloud_endpoint("/prompt") == "/prompt"
assert cloud_endpoint("/queue") == "/queue"
class TestResolveURL:
def test_local_no_prefix(self):
assert resolve_url("http://127.0.0.1:8188", "/prompt") == "http://127.0.0.1:8188/prompt"
def test_cloud_adds_api_prefix(self):
assert resolve_url("https://cloud.comfy.org", "/prompt") == "https://cloud.comfy.org/api/prompt"
def test_cloud_history_renamed(self):
assert resolve_url("https://cloud.comfy.org", "/history/abc") == "https://cloud.comfy.org/api/history_v2/abc"
def test_cloud_models_renamed(self):
assert resolve_url("https://cloud.comfy.org", "/models/loras") == "https://cloud.comfy.org/api/experiment/models/loras"
def test_cloud_already_has_api(self):
# Don't double-prefix
assert resolve_url("https://cloud.comfy.org", "/api/prompt") == "https://cloud.comfy.org/api/prompt"
def test_trailing_slash_stripped(self):
assert resolve_url("http://127.0.0.1:8188/", "/prompt") == "http://127.0.0.1:8188/prompt"
# =============================================================================
# Workflow validation
# =============================================================================
class TestAPIFormatDetection:
def test_valid_api(self, sd15_workflow):
assert is_api_format(sd15_workflow) is True
def test_editor_format_rejected(self):
editor = {"nodes": [], "links": [], "version": 0.4}
assert is_api_format(editor) is False
def test_empty_dict(self):
assert is_api_format({}) is False
def test_non_dict(self):
assert is_api_format([]) is False
assert is_api_format(None) is False
assert is_api_format("string") is False
def test_node_with_class_type(self):
wf = {"3": {"class_type": "KSampler", "inputs": {}}}
assert is_api_format(wf) is True
class TestUnwrapWorkflow:
def test_passthrough_api_format(self, sd15_workflow):
result = unwrap_workflow(sd15_workflow)
assert result is sd15_workflow
def test_unwrap_prompt_key(self, sd15_workflow):
wrapped = {"prompt": sd15_workflow, "client_id": "abc"}
result = unwrap_workflow(wrapped)
assert result is sd15_workflow
def test_editor_format_raises(self):
with pytest.raises(ValueError, match="editor format"):
unwrap_workflow({"nodes": [], "links": []})
def test_garbage_raises(self):
with pytest.raises(ValueError):
unwrap_workflow({"foo": "bar"})
class TestIsLink:
def test_valid_link(self):
assert is_link(["3", 0]) is True
assert is_link(["10", 1]) is True
def test_non_link(self):
assert is_link("string") is False
assert is_link(42) is False
assert is_link([]) is False
assert is_link(["3"]) is False # missing slot
assert is_link(["3", "0"]) is False # slot must be int
assert is_link([3, 0]) is False # node_id must be string
# =============================================================================
# Workflow iterators
# =============================================================================
class TestIterators:
def test_iter_nodes(self, sd15_workflow):
nodes = dict(iter_nodes(sd15_workflow))
assert "3" in nodes
assert nodes["3"]["class_type"] == "KSampler"
def test_iter_nodes_skips_comments(self, sd15_workflow):
# _comment is not a node
nodes = dict(iter_nodes(sd15_workflow))
assert "_comment" not in nodes
def test_iter_model_deps(self, sd15_workflow):
deps = list(iter_model_deps(sd15_workflow))
names = [d["value"] for d in deps]
assert "v1-5-pruned-emaonly.safetensors" in names
def test_iter_model_deps_flux(self, flux_workflow):
deps = list(iter_model_deps(flux_workflow))
names = {d["value"]: d["folder"] for d in deps}
assert names["flux1-dev.safetensors"] == "unet"
assert names["t5xxl_fp16.safetensors"] == "clip"
assert names["clip_l.safetensors"] == "clip"
assert names["ae.safetensors"] == "vae"
# =============================================================================
# Embedding extraction
# =============================================================================
class TestEmbeddingRegex:
def test_basic_embedding(self):
m = EMBEDDING_REGEX.search("a cat, embedding:goodvibes, more text")
assert m is not None
assert m.group(1) == "goodvibes"
def test_embedding_with_strength(self):
m = EMBEDDING_REGEX.search("embedding:bad-hands-5:1.2")
assert m is not None
assert m.group(1) == "bad-hands-5"
def test_embedding_with_extension(self):
# Strips .pt / .safetensors / .bin
m = EMBEDDING_REGEX.search("embedding:my-emb.pt")
assert m is not None
assert m.group(1) == "my-emb"
def test_embedding_in_parens(self):
m = EMBEDDING_REGEX.search("(embedding:foo:0.8)")
assert m is not None
assert m.group(1) == "foo"
def test_multiple_in_one_string(self):
text = "a cat, embedding:foo:1.2, and embedding:bar"
matches = [m.group(1) for m in EMBEDDING_REGEX.finditer(text)]
assert matches == ["foo", "bar"]
def test_no_false_positive_on_word_embedding(self):
# "embedding " (with space, no colon) should not match
m = EMBEDDING_REGEX.search("the embedding is great")
assert m is None
class TestIterEmbeddingRefs:
def test_finds_in_clip_text_encode(self):
wf = {
"1": {"class_type": "CLIPTextEncode",
"inputs": {"text": "embedding:foo, embedding:bar:0.5", "clip": ["2", 0]}},
"2": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
}
refs = list(iter_embedding_refs(wf))
names = [name for _, name in refs]
assert names == ["foo", "bar"]
def test_ignores_non_prompt_fields(self):
wf = {
"1": {"class_type": "CheckpointLoaderSimple",
"inputs": {"ckpt_name": "embedding:foo.safetensors"}},
}
refs = list(iter_embedding_refs(wf))
# ckpt_name is not a prompt field — ignored
assert refs == []
# =============================================================================
# Path safety
# =============================================================================
class TestSafePathJoin:
def test_normal_join(self, tmp_path):
p = safe_path_join(tmp_path, "subdir", "file.png")
assert p.is_relative_to(tmp_path)
def test_blocks_traversal(self, tmp_path):
with pytest.raises(ValueError, match="path traversal"):
safe_path_join(tmp_path, "..", "..", "etc", "passwd")
def test_blocks_absolute(self, tmp_path):
with pytest.raises(ValueError):
safe_path_join(tmp_path, "/etc/passwd")
def test_subfolder_with_filename(self, tmp_path):
p = safe_path_join(tmp_path, "outputs", "img.png")
assert p.name == "img.png"
assert p.parent.name == "outputs"
# =============================================================================
# Seed coercion
# =============================================================================
class TestCoerceSeed:
def test_explicit_int(self):
assert coerce_seed(42) == 42
assert coerce_seed(0) == 0
def test_minus_one_randomizes(self):
s = coerce_seed(-1)
assert isinstance(s, int)
assert 0 <= s < 2**63
def test_none_randomizes(self):
s = coerce_seed(None)
assert isinstance(s, int)
def test_string_int(self):
# str() that converts cleanly is allowed (relaxed)
assert coerce_seed("12345") == 12345
def test_string_minus_one_randomizes(self):
# CLI / JSON sometimes carries seed as a string.
s = coerce_seed("-1")
assert isinstance(s, int)
assert 0 <= s < 2**63
# And whitespace tolerated
s2 = coerce_seed(" -1 ")
assert isinstance(s2, int)
assert 0 <= s2 < 2**63
# =============================================================================
# Model list normalization (cloud format)
# =============================================================================
class TestParseModelList:
def test_local_format_strings(self):
result = parse_model_list(["a.safetensors", "b.safetensors"])
assert result == {"a.safetensors", "b.safetensors"}
def test_cloud_format_dicts(self):
result = parse_model_list([
{"name": "a.safetensors", "pathIndex": 0},
{"name": "b.safetensors", "pathIndex": 1},
])
assert result == {"a.safetensors", "b.safetensors"}
def test_empty(self):
assert parse_model_list([]) == set()
def test_garbage(self):
assert parse_model_list("not a list") == set()
assert parse_model_list(None) == set()
def test_mixed_format(self):
result = parse_model_list([
"string-form.safetensors",
{"name": "dict-form.safetensors"},
])
assert result == {"string-form.safetensors", "dict-form.safetensors"}
# =============================================================================
# Folder aliases
# =============================================================================
class TestFolderAliases:
def test_unet_aliases_diffusion_models(self):
aliases = folder_aliases_for("unet")
assert "unet" in aliases
assert "diffusion_models" in aliases
def test_clip_aliases_text_encoders(self):
aliases = folder_aliases_for("clip")
assert "clip" in aliases
assert "text_encoders" in aliases
def test_unknown_folder_returns_self(self):
assert folder_aliases_for("checkpoints") == ["checkpoints"]
def test_primary_first(self):
# Order matters: primary should be first for human-friendly fix hints
assert folder_aliases_for("unet")[0] == "unet"
assert folder_aliases_for("diffusion_models")[0] == "diffusion_models"
# =============================================================================
# Media-type detection
# =============================================================================
class TestMediaType:
def test_video_extensions(self):
assert media_type_from_filename("vid.mp4") == "video"
assert media_type_from_filename("foo.webm") == "video"
assert media_type_from_filename("bar.gif") == "video"
def test_audio_extensions(self):
assert media_type_from_filename("song.wav") == "audio"
assert media_type_from_filename("music.mp3") == "audio"
def test_image_default(self):
assert media_type_from_filename("pic.png") == "image"
assert media_type_from_filename("image.jpg") == "image"
assert media_type_from_filename("unknown.xyz") == "image"
def test_3d(self):
assert media_type_from_filename("model.glb") == "3d"
assert media_type_from_filename("scene.gltf") == "3d"
# =============================================================================
# Cross-host header stripping (security)
# =============================================================================
class TestRedirectHeaderStripping:
"""Verify X-API-Key is dropped when redirect crosses to a different host
(e.g. cloud /api/view → S3 signed URL). Critical to prevent leaking auth
tokens to the storage backend.
"""
def _build_session(self):
from _common import _StripSensitiveOnRedirectSession, HAS_REQUESTS
if not HAS_REQUESTS:
import pytest
pytest.skip("requests not installed")
return _StripSensitiveOnRedirectSession()
def test_strips_x_api_key_cross_host(self):
import requests
s = self._build_session()
prep = requests.PreparedRequest()
prep.prepare(method="GET", url="https://other.example.com/file",
headers={"X-API-Key": "leak", "Authorization": "Bearer x"})
resp = requests.Response()
orig = requests.PreparedRequest()
orig.prepare(method="GET", url="https://cloud.comfy.org/api/view", headers={})
resp.request = orig
s.rebuild_auth(prep, resp)
assert "X-API-Key" not in prep.headers
assert "Authorization" not in prep.headers
def test_preserves_x_api_key_same_host(self):
import requests
s = self._build_session()
prep = requests.PreparedRequest()
prep.prepare(method="GET", url="https://cloud.comfy.org/foo",
headers={"X-API-Key": "keep"})
resp = requests.Response()
orig = requests.PreparedRequest()
orig.prepare(method="GET", url="https://cloud.comfy.org/bar", headers={})
resp.request = orig
s.rebuild_auth(prep, resp)
assert prep.headers.get("X-API-Key") == "keep"
def test_strips_cookie_cross_host(self):
import requests
s = self._build_session()
prep = requests.PreparedRequest()
prep.prepare(method="GET", url="https://other.example.com/x",
headers={"Cookie": "session=secret"})
resp = requests.Response()
orig = requests.PreparedRequest()
orig.prepare(method="GET", url="https://cloud.comfy.org/foo", headers={})
resp.request = orig
s.rebuild_auth(prep, resp)
assert "Cookie" not in prep.headers
# =============================================================================
# Video workflow detection
# =============================================================================
class TestVideoWorkflow:
def test_image_workflow(self, sd15_workflow):
assert looks_like_video_workflow(sd15_workflow) is False
def test_animatediff_workflow(self, workflows_dir):
import json
wf = json.loads((workflows_dir / "animatediff_video.json").read_text())
assert looks_like_video_workflow(wf) is True
def test_wan_workflow(self, video_workflow):
assert looks_like_video_workflow(video_workflow) is True

View File

@@ -0,0 +1,185 @@
"""Tests for extract_schema.py."""
from __future__ import annotations
import pytest
from extract_schema import (
extract_schema,
find_negative_prompt_node,
find_positive_prompt_node,
trace_to_node,
)
# =============================================================================
# Connection tracing
# =============================================================================
class TestConnectionTracing:
def test_direct_link(self):
wf = {
"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}},
"2": {"class_type": "KSampler",
"inputs": {"positive": ["1", 0], "negative": ["1", 0]}},
}
assert trace_to_node(wf, ["1", 0]) == "1"
def test_through_reroute(self):
wf = {
"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}},
"2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}},
}
assert trace_to_node(wf, ["3", 0]) == "1"
def test_circular_safe(self):
wf = {
"1": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}},
"2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}},
}
# Should hit max_hops without infinite loop
result = trace_to_node(wf, ["1", 0], max_hops=5)
assert result in ("1", "2") # any node, just don't hang
class TestPositiveNegativeDetection:
def test_basic(self, sd15_workflow):
# In sd15_workflow.json node 6 is positive, node 7 is negative
assert find_positive_prompt_node(sd15_workflow) == "6"
assert find_negative_prompt_node(sd15_workflow) == "7"
def test_swapped_order(self):
wf = {
"3": {"class_type": "KSampler",
"inputs": {
"positive": ["7", 0], "negative": ["6", 0],
"model": ["4", 0], "latent_image": ["5", 0],
"seed": 1, "steps": 20, "cfg": 7.5,
"sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
}},
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": "ugly", "clip": ["4", 1]}},
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "beautiful", "clip": ["4", 1]}},
}
# Now 7 is the positive (despite higher node ID)
assert find_positive_prompt_node(wf) == "7"
assert find_negative_prompt_node(wf) == "6"
# =============================================================================
# Schema extraction
# =============================================================================
class TestExtractSchema:
def test_basic_sd15(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
params = schema["parameters"]
assert "prompt" in params
assert "negative_prompt" in params
assert "seed" in params
assert "steps" in params
assert "cfg" in params
assert "width" in params
assert "height" in params
def test_prompt_value_correct(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
# The positive prompt in the example is the landscape one
assert "landscape" in schema["parameters"]["prompt"]["value"]
assert "ugly" in schema["parameters"]["negative_prompt"]["value"]
def test_model_dependencies(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
deps = schema["model_dependencies"]
ckpts = [d["value"] for d in deps if d["folder"] == "checkpoints"]
assert "v1-5-pruned-emaonly.safetensors" in ckpts
def test_output_nodes(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
assert "9" in schema["output_nodes"]
def test_summary(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
s = schema["summary"]
assert s["has_negative_prompt"] is True
assert s["has_seed"] is True
assert s["is_video_workflow"] is False
assert s["parameter_count"] > 5
def test_flux_workflow(self, flux_workflow):
schema = extract_schema(flux_workflow)
# Flux uses RandomNoise for seed
assert schema["summary"]["has_seed"] is True
# Flux has only positive prompt (no negative encoder)
assert schema["summary"]["has_negative_prompt"] is False
def test_video_detected(self, video_workflow):
schema = extract_schema(video_workflow)
assert schema["summary"]["is_video_workflow"] is True
class TestEmbeddingDeps:
def test_extract_from_prompt(self):
wf = {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
"5": {"class_type": "EmptyLatentImage",
"inputs": {"width": 512, "height": 512, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode",
"inputs": {
"text": "a cat, embedding:goodvibes, embedding:art:1.2",
"clip": ["1", 1]
}},
"7": {"class_type": "CLIPTextEncode",
"inputs": {
"text": "ugly, embedding:badhands",
"clip": ["1", 1]
}},
"3": {"class_type": "KSampler",
"inputs": {
"positive": ["6", 0], "negative": ["7", 0],
"model": ["1", 0], "latent_image": ["5", 0],
"seed": 1, "steps": 20, "cfg": 7.5,
"sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
}},
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
}
schema = extract_schema(wf)
names = [d["embedding_name"] for d in schema["embedding_dependencies"]]
assert sorted(names) == ["art", "badhands", "goodvibes"]
class TestDuplicateDeduplication:
def test_two_ksamplers_get_unique_names(self):
wf = {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
"5": {"class_type": "EmptyLatentImage",
"inputs": {"width": 512, "height": 512, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": "a", "clip": ["1", 1]}},
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "b", "clip": ["1", 1]}},
"3": {"class_type": "KSampler",
"inputs": {
"positive": ["6", 0], "negative": ["7", 0],
"model": ["1", 0], "latent_image": ["5", 0],
"seed": 42, "steps": 20, "cfg": 7.5,
"sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
}},
"4": {"class_type": "KSampler",
"inputs": {
"positive": ["6", 0], "negative": ["7", 0],
"model": ["1", 0], "latent_image": ["5", 0],
"seed": 99, "steps": 30, "cfg": 8.0,
"sampler_name": "euler", "scheduler": "normal", "denoise": 0.6,
}},
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
}
schema = extract_schema(wf)
params = schema["parameters"]
# Both seeds present with disambiguated names
seed_keys = [k for k in params if "seed" in k]
# Symmetric: both renamed (no bare "seed")
assert "seed" not in params
assert "seed_3" in params and "seed_4" in params
assert params["seed_3"]["value"] == 42
assert params["seed_4"]["value"] == 99

View File

@@ -0,0 +1,213 @@
"""Tests for run_workflow.py — focuses on logic that doesn't require a server."""
from __future__ import annotations
import copy
import json
import pytest
from extract_schema import extract_schema
from run_workflow import (
ComfyRunner,
download_outputs,
inject_params,
parse_input_image_arg,
)
class TestParseInputImageArg:
def test_with_name(self, tmp_path):
f = tmp_path / "x.png"
f.write_text("x")
n, p = parse_input_image_arg(f"image={f}")
assert n == "image"
assert p == f
def test_without_name_defaults(self, tmp_path):
f = tmp_path / "x.png"
f.write_text("x")
n, p = parse_input_image_arg(str(f))
assert n == "image"
def test_custom_name(self, tmp_path):
f = tmp_path / "x.png"
f.write_text("x")
n, p = parse_input_image_arg(f"mask_image={f}")
assert n == "mask_image"
class TestInjectParams:
def test_basic_injection(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
wf, warnings = inject_params(sd15_workflow, schema, {
"prompt": "new prompt",
"seed": 999,
"steps": 25,
})
assert wf["6"]["inputs"]["text"] == "new prompt"
assert wf["3"]["inputs"]["seed"] == 999
assert wf["3"]["inputs"]["steps"] == 25
assert warnings == []
def test_unknown_param_warns(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
_, warnings = inject_params(sd15_workflow, schema, {"foobar": "x"})
assert any("foobar" in w for w in warnings)
def test_seed_minus_one_randomizes(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
wf, warnings = inject_params(sd15_workflow, schema, {"seed": -1})
assert wf["3"]["inputs"]["seed"] != -1
assert isinstance(wf["3"]["inputs"]["seed"], int)
assert any("expanded" in w.lower() for w in warnings)
def test_randomize_seed_when_unset(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
original = sd15_workflow["3"]["inputs"]["seed"]
wf, warnings = inject_params(sd15_workflow, schema, {}, randomize_seed_if_unset=True)
assert wf["3"]["inputs"]["seed"] != original
assert isinstance(wf["3"]["inputs"]["seed"], int)
def test_does_not_mutate_original(self, sd15_workflow):
schema = extract_schema(sd15_workflow)
original_text = sd15_workflow["6"]["inputs"]["text"]
inject_params(sd15_workflow, schema, {"prompt": "MUTATED"})
assert sd15_workflow["6"]["inputs"]["text"] == original_text
def test_refuses_to_overwrite_link(self):
wf = {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}},
"5": {"class_type": "EmptyLatentImage",
"inputs": {"width": 512, "height": 512, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode",
"inputs": {"text": ["3", 0], "clip": ["1", 1]}}, # text is a link!
"3": {"class_type": "KSampler",
"inputs": {"seed": 1, "steps": 20, "cfg": 7.5,
"sampler_name": "euler", "scheduler": "normal", "denoise": 1.0,
"model": ["1", 0], "positive": ["6", 0], "negative": ["6", 0],
"latent_image": ["5", 0]}},
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}},
}
# Manually create a schema that has prompt pointing at 6.text
schema = {
"parameters": {
"prompt": {"node_id": "6", "field": "text", "type": "string", "value": ""},
}
}
wf2, warnings = inject_params(wf, schema, {"prompt": "literal value"})
# The link should NOT have been overwritten
assert wf2["6"]["inputs"]["text"] == ["3", 0]
assert any("link" in w.lower() for w in warnings)
# =============================================================================
# Output download walk
# =============================================================================
class TestDownloadOutputsWalk:
"""Test that download_outputs walks the structure correctly."""
def test_handles_videos_plural(self, tmp_path, monkeypatch):
"""Local ComfyUI uses 'videos'/'gifs' (plural) keys."""
downloads = []
class FakeRunner:
def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
downloads.append((filename, subfolder, file_type))
p = output_dir / filename
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(b"x")
return p
outputs = {
"9": {"images": [{"filename": "img1.png", "subfolder": "", "type": "output"}]},
"10": {"videos": [{"filename": "vid1.mp4", "subfolder": "", "type": "output"}]},
"11": {"gifs": [{"filename": "anim1.gif", "subfolder": "", "type": "output"}]},
}
result = download_outputs(FakeRunner(), outputs, tmp_path)
files = sorted(d["filename"] for d in result)
assert files == ["anim1.gif", "img1.png", "vid1.mp4"]
def test_handles_video_singular_cloud(self, tmp_path):
"""Cloud uses 'video' (singular)."""
class FakeRunner:
def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
p = output_dir / filename
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(b"x")
return p
outputs = {
"10": {"video": [{"filename": "cloud.mp4", "subfolder": "", "type": "output"}]},
}
result = download_outputs(FakeRunner(), outputs, tmp_path)
assert len(result) == 1
assert result[0]["filename"] == "cloud.mp4"
def test_preserves_subfolder(self, tmp_path):
"""When preserve_subfolder=True, server subfolder becomes local subdir."""
class FakeRunner:
def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite):
if preserve_subfolder and subfolder:
p = output_dir / subfolder / filename
else:
p = output_dir / filename
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(b"x")
return p
outputs = {
"9": {"images": [
{"filename": "img.png", "subfolder": "myrun", "type": "output"},
{"filename": "img.png", "subfolder": "otherrun", "type": "output"},
]},
}
result = download_outputs(FakeRunner(), outputs, tmp_path, preserve_subfolder=True)
files = [d["file"] for d in result]
assert any("myrun" in f for f in files)
assert any("otherrun" in f for f in files)
# Both must exist (no collision)
assert len({str(f) for f in files}) == 2
# =============================================================================
# ComfyRunner construction
# =============================================================================
class TestRunnerConstruction:
def test_local_default(self):
r = ComfyRunner()
assert r.is_cloud is False
assert r.host == "http://127.0.0.1:8188"
def test_cloud_detection(self):
r = ComfyRunner(host="https://cloud.comfy.org", api_key="abc")
assert r.is_cloud is True
assert "X-API-Key" in r.headers
def test_cloud_subdomain_detected(self):
r = ComfyRunner(host="https://staging.cloud.comfy.org", api_key="abc")
assert r.is_cloud is True
def test_partner_key_does_not_pollute_extra_data(self):
r = ComfyRunner(host="https://cloud.comfy.org", api_key="auth-key")
# No partner-key set → no extra_data should appear in submitted prompt
# (This is a static check; runtime check happens in submit())
assert r.partner_key is None
def test_url_routing_local(self):
r = ComfyRunner()
url = r._url("/prompt")
assert url == "http://127.0.0.1:8188/prompt"
def test_url_routing_cloud(self):
r = ComfyRunner(host="https://cloud.comfy.org", api_key="x")
url = r._url("/prompt")
assert url == "https://cloud.comfy.org/api/prompt"
def test_url_routing_cloud_history_renamed(self):
r = ComfyRunner(host="https://cloud.comfy.org", api_key="x")
url = r._url("/history/abc-123")
assert url == "https://cloud.comfy.org/api/history_v2/abc-123"