hermes-agent/skills/creative/comfyui/scripts/run_workflow.py
SHL0MS a7780fe05f fix(skills/comfyui): bug fixes, cloud parity, expanded coverage, examples, tests
The audit of v4.1 surfaced ~70 issues across the five scripts and three
reference docs — most user-visible (silent file overwrites, status-error
misclassified as success, X-API-Key leaked to S3 on /api/view redirect,
Cloud endpoints that 404 because they were renamed). v5.0.0 fixes those
and fills the gaps that previously forced users to write their own glue
(WebSocket monitoring, batch/sweep, img2img upload helper, dep auto-fix,
log fetch, health check, example workflows).
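
The key-leak path mentioned above comes down to one rule: when a redirect moves to a different host, credentials for the first host must not follow it. A minimal sketch of that rule as a pure function, assuming only the two header names this commit mentions; `headers_for_redirect` is a hypothetical helper for illustration, not the session subclass in _common.py:

```python
from urllib.parse import urlsplit

# Header names taken from the commit message; compared case-insensitively.
SENSITIVE_HEADERS = {"x-api-key", "cookie"}


def headers_for_redirect(headers: dict, old_url: str, new_url: str) -> dict:
    """Return the headers that are safe to send to the redirect target."""
    same_host = urlsplit(old_url).hostname == urlsplit(new_url).hostname
    if same_host:
        return dict(headers)
    # Cross-host hop (e.g. API host to a signed storage URL): drop credentials.
    return {k: v for k, v in headers.items() if k.lower() not in SENSITIVE_HEADERS}


original = {"X-API-Key": "comfyui-xxxxxxx", "Accept": "*/*"}
forwarded = headers_for_redirect(
    original,
    "https://cloud.comfy.org/api/view?filename=a.png",
    "https://storage.example.com/signed/a.png",  # hypothetical storage host
)
print(forwarded)  # {'Accept': '*/*'}
```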

Critical fixes
- run_workflow.py: poll_status now checks status_str==error BEFORE
  completed:true, so a failed run no longer reports success
- run_workflow.py: download_output streams to disk via safe_path_join,
  preserves server subfolder structure (no silent overwrites), and
  retries with exponential backoff
- run_workflow.py: refuses to overwrite a link with a literal in
  inject_params (would silently break wiring)
- _common.py: _StripSensitiveOnRedirectSession (subclasses
  requests.Session.rebuild_auth) drops X-API-Key/Cookie on cross-host
  redirects — fixes a real key-leak path through Cloud's signed-URL
  download flow. Tested
- Cloud routing (verified live): /history → /history_v2,
  /models/<f> → /experiment/models/<f>, plus folder aliases for the
  unet ↔ diffusion_models and clip ↔ text_encoders rename
- check_deps.py: distinguishes 200/empty vs 404 folder_not_found vs
  403 free-tier; emits concrete fix_command per missing dep
- extract_schema.py: prompt vs negative_prompt determined by tracing
  KSampler.{positive,negative} connections (incl. through Reroute /
  Primitive nodes) instead of meta-title heuristic; symmetric
  duplicate-name resolution; cycle-safe trace_to_node
- hardware_check.py: multi-GPU pick-best, Apple variant detection,
  Rosetta detection, WSL2, ROCm --json, disk-space check, optional
  PyTorch probe; powershell preferred over deprecated wmic
- comfyui_setup.sh: prefers pipx → uvx → pip --user (with PEP-668
  fallback); idempotent — skips relaunch if server already up;
  configurable port/workspace; persistent log; SIGINT trap

New scripts
- run_batch.py — count or sweep (cartesian product), parallel up to
  cloud tier limit
- ws_monitor.py — real-time WebSocket viewer; saves preview frames
- auto_fix_deps.py — runs comfy node install / model download for
  whatever check_deps reports missing (with --dry-run)
- health_check.py — single command that runs the verification checklist
  (comfy-cli + server + checkpoints + optional smoke test that cancels
  itself to avoid burning compute)
- fetch_logs.py — pull traceback / status messages for a prompt_id

Coverage expansion
- Param patterns now cover Flux (BasicScheduler, BasicGuider,
  RandomNoise, ModelSamplingFlux), SD3, Wan/Hunyuan/LTX video,
  IPAdapter, rgthree, easy-use, AnimateDiff
- Embedding refs in CLIPTextEncode strings extracted as model deps
- ckpt_name / vae_name / lora_name / unet_name now controllable so
  workflows can be retargeted per run

Examples
- workflows/{sd15,sdxl,flux_dev}_txt2img.json
- workflows/sdxl_{img2img,inpaint}.json
- workflows/upscale_4x.json
- workflows/{animatediff_video,wan_video_t2v}.json + README

Tests
- 117 tests (105 unit + 8 cloud integration + 4 cross-host security)
- Cloud tests auto-skip without COMFY_CLOUD_API_KEY; verified end-to-end
  against live cloud API
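
The auto-skip behavior can be sketched with the standard library. Which test framework the suite actually uses is not stated here, so `unittest` is an assumption, and `cloud_tests_enabled` and `CloudSmokeTest` are hypothetical names:

```python
import os
import unittest

ENV_API_KEY = "COMFY_CLOUD_API_KEY"


def cloud_tests_enabled(env=None) -> bool:
    """True when a non-blank cloud API key is present in the environment."""
    env = os.environ if env is None else env
    return bool(env.get(ENV_API_KEY, "").strip())


@unittest.skipUnless(cloud_tests_enabled(), f"{ENV_API_KEY} not set")
class CloudSmokeTest(unittest.TestCase):
    def test_key_present(self):
        # Only runs when the key is set; real tests would call the cloud API here.
        self.assertTrue(os.environ[ENV_API_KEY].strip())
```

Without the variable the class is reported as skipped instead of failing, which keeps offline runs green.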

Backwards compatibility
- All existing CLI flags continue to work; new behavior is opt-in
  (--ws, --input-image, --randomize-seed, --flat-output, etc.)
2026-04-29 20:48:01 -07:00

797 lines
31 KiB
Python
Executable File

#!/usr/bin/env python3
"""
run_workflow.py — Inject parameters into a ComfyUI workflow, submit it, monitor
execution, and download outputs.

Improvements over v1:
  - Cloud-aware URL routing (handles /api prefix and /history_v2 / /experiment/models renames)
  - API key from CLI flag OR $COMFY_CLOUD_API_KEY env var
  - WebSocket progress monitoring (--ws), with HTTP polling fallback
  - Streaming download (no whole-file buffering — handles GB-size video outputs)
  - Path-traversal-safe output writes
  - Subfolder-aware download paths (no silent overwrites)
  - Retry with exponential backoff on transient errors
  - Status-error correctly classified before "completed: true"
  - Image upload helper (--input-image NAME=PATH)
  - Auto-randomize seed when the value is -1, or when it is omitted and
    --randomize-seed is set
  - Auto-extends timeout heuristically for video workflows
  - Editor-format detection with helpful error
  - Doesn't pollute extra_data.api_key_comfy_org with the cloud auth key
    unless --partner-key is provided (correct semantic per cloud docs)

Usage:
    # Local server
    python3 run_workflow.py --workflow workflow_api.json \
        --args '{"prompt": "a cat", "seed": 42}' \
        --output-dir ./outputs

    # Cloud server (API key from env var)
    export COMFY_CLOUD_API_KEY="comfyui-xxxxxxx"
    python3 run_workflow.py --workflow workflow_api.json \
        --args '{"prompt": "a cat"}' \
        --host https://cloud.comfy.org \
        --output-dir ./outputs

    # With image input (auto-uploads, then references)
    python3 run_workflow.py --workflow img2img.json \
        --input-image image=./photo.png \
        --args '{"prompt": "make it cyberpunk"}'

    # WebSocket real-time progress
    python3 run_workflow.py --workflow flux_dev.json \
        --args '{"prompt": "..."}' \
        --ws

Stdlib-only by default (Python 3.10+). Will use `requests`/`websocket-client`
if installed for nicer behavior.
"""

from __future__ import annotations

import argparse
import copy
import json
import sys
import time
from pathlib import Path
from typing import Any
from urllib.parse import urlencode, urlparse

# Local import — _common.py sits next to this script.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from _common import (  # noqa: E402
    DEFAULT_LOCAL_HOST, ENV_API_KEY,
    coerce_seed, emit_json, http_get, http_post, http_request,
    is_cloud_host, is_link, log, looks_like_video_workflow,
    media_type_from_filename, new_client_id, resolve_api_key, resolve_url,
    safe_path_join, unwrap_workflow,
)


# =============================================================================
# Runner
# =============================================================================

class WorkflowRunError(Exception):
    """Raised when a workflow run fails (validation, execution, timeout)."""

    def __init__(self, status: str, message: str, **details: Any):
        super().__init__(message)
        self.status = status
        self.message = message
        self.details = details

    def to_dict(self) -> dict:
        d = {"status": self.status, "error": self.message}
        d.update(self.details)
        return d


class ComfyRunner:
    def __init__(
        self,
        host: str = DEFAULT_LOCAL_HOST,
        api_key: str | None = None,
        client_id: str | None = None,
        partner_key: str | None = None,
    ):
        self.host = host.rstrip("/")
        self.api_key = api_key
        self.partner_key = partner_key
        self.is_cloud = is_cloud_host(self.host)
        self.client_id = client_id or new_client_id()

    @property
    def headers(self) -> dict[str, str]:
        h: dict[str, str] = {}
        if self.api_key:
            h["X-API-Key"] = self.api_key
        return h

    def _url(self, path: str) -> str:
        return resolve_url(self.host, path, is_cloud=self.is_cloud)

    # ---------- server health ----------
    def check_server(self) -> tuple[bool, dict | None]:
        try:
            r = http_get(self._url("/system_stats"), headers=self.headers, retries=2)
            if r.status == 200:
                try:
                    return True, r.json()
                except Exception:
                    return True, None
            return False, {"http_status": r.status, "body": r.text()[:500]}
        except Exception as e:
            return False, {"error": str(e)}

    # ---------- upload ----------
    def upload_image(self, path: Path, *, image_type: str = "input", overwrite: bool = True,
                     endpoint: str = "/upload/image", extra_form: dict | None = None) -> dict:
        """Upload an image file via multipart. Returns server-side ref dict."""
        if not path.exists():
            raise FileNotFoundError(f"input image not found: {path}")
        # Stream the file via a handle to avoid OOM on huge inputs (16MP+ photos).
        # The request must run inside the `with` block so the handle is still open.
        with path.open("rb") as fh:
            files = {"image": (path.name, fh)}
            form = {"type": image_type}
            if overwrite:
                form["overwrite"] = "true"
            if extra_form:
                form.update({k: str(v) for k, v in extra_form.items()})
            r = http_request(
                "POST", self._url(endpoint),
                headers=self.headers, files=files, form=form,
                timeout=300, retries=2,
            )
        if r.status != 200:
            raise WorkflowRunError(
                "upload_failed",
                f"Upload of {path.name} failed: HTTP {r.status}",
                body=r.text()[:500],
            )
        try:
            return r.json()
        except Exception:
            return {"name": path.name}

    def upload_mask(self, path: Path, original_ref: dict) -> dict:
        """Upload an inpaint mask, linked to a previously uploaded source image.

        `original_ref` should be the dict returned by `upload_image()` for the
        source image (or `{"filename": ..., "subfolder": ..., "type": "input"}`).
        """
        return self.upload_image(
            path,
            endpoint="/upload/mask",
            extra_form={
                "subfolder": "clipspace",
                "original_ref": json.dumps(original_ref),
            },
        )

    # ---------- submit ----------
    def submit(self, workflow: dict) -> dict:
        payload: dict[str, Any] = {"prompt": workflow, "client_id": self.client_id}
        if self.partner_key:
            payload["extra_data"] = {"api_key_comfy_org": self.partner_key}
        r = http_post(self._url("/prompt"), headers=self.headers, json_body=payload, timeout=120)
        try:
            body = r.json()
        except Exception:
            body = {"raw": r.text()[:500]}
        if r.status != 200:
            return {"_http_error": r.status, "body": body}
        return body

    # ---------- HTTP polling ----------
    def poll_status(self, prompt_id: str, *, timeout: float = 300.0,
                    initial_interval: float = 1.5, max_interval: float = 8.0) -> dict:
        start = time.time()
        interval = initial_interval
        while time.time() - start < timeout:
            if self.is_cloud:
                r = http_get(
                    self._url(f"/job/{prompt_id}/status"),
                    headers=self.headers, retries=2, timeout=30,
                )
                if r.status == 200:
                    try:
                        data = r.json()
                    except Exception:
                        data = {}
                    s = data.get("status")
                    if s == "completed":
                        return {"status": "success", "data": data}
                    if s in ("failed",):
                        return {"status": "error", "data": data}
                    if s == "cancelled":
                        return {"status": "cancelled", "data": data}
                    # pending / in_progress → continue
                elif r.status == 404:
                    # Cloud sometimes 404s briefly between submit and dispatcher pickup
                    pass
                else:
                    # transient error — retry loop covers it
                    pass
            else:
                # Local: /history/{id} grows once execution completes
                r = http_get(
                    self._url(f"/history/{prompt_id}"),
                    headers=self.headers, retries=2, timeout=30,
                )
                if r.status == 200:
                    try:
                        data = r.json() or {}
                    except Exception:
                        data = {}
                    entry = data.get(prompt_id)
                    if isinstance(entry, dict):
                        st = entry.get("status") or {}
                        # IMPORTANT: check error first — `completed: true` can coexist with errors
                        status_str = st.get("status_str")
                        if status_str == "error":
                            return {"status": "error", "data": entry}
                        if st.get("completed", False):
                            return {"status": "success", "outputs": entry.get("outputs", {})}
                    # not in history yet → continue polling
            time.sleep(interval)
            interval = min(max_interval, interval * 1.4)
        return {"status": "timeout", "elapsed": time.time() - start}

    # ---------- WebSocket monitoring ----------
    def monitor_ws(self, prompt_id: str, *, timeout: float = 300.0,
                   on_progress: Any = None) -> dict:
        """Connect to /ws and listen until execution_success / execution_error.

        Falls back to HTTP polling if `websocket-client` is not installed.
        Returns same shape as poll_status.
        """
        try:
            import websocket  # type: ignore[import-not-found]
        except ImportError:
            log("websocket-client not installed; falling back to HTTP polling")
            return self.poll_status(prompt_id, timeout=timeout)

        # Build WS URL. Preserve any base-path components the user gave us
        # (e.g. http://example.com/comfyui → ws://example.com/comfyui/ws).
        parsed = urlparse(self.host)
        scheme = "wss" if parsed.scheme == "https" else "ws"
        netloc = parsed.netloc
        base_path = parsed.path.rstrip("/")
        ws_url = f"{scheme}://{netloc}{base_path}/ws?clientId={self.client_id}"
        if self.is_cloud and self.api_key:
            ws_url += f"&token={self.api_key}"

        outputs: dict[str, Any] = {}
        error_payload: dict[str, Any] | None = None
        success = False
        seen_executed = False

        ws = websocket.create_connection(ws_url, timeout=timeout)
        try:
            deadline = time.time() + timeout
            while True:
                # Bound each recv() by the time left, so an idle socket yields a
                # "timeout" result instead of an uncaught WebSocketTimeoutException.
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                ws.settimeout(remaining)
                try:
                    msg = ws.recv()
                except websocket.WebSocketTimeoutException:
                    break
                if isinstance(msg, bytes):
                    # Binary preview frame — ignore for now; ws_monitor.py prints them
                    continue
                try:
                    payload = json.loads(msg)
                except Exception:
                    continue
                mtype = payload.get("type", "")
                mdata = payload.get("data", {}) or {}
                # Filter to our job (cloud broadcasts; local filters via client_id)
                pid = mdata.get("prompt_id")
                if pid is not None and pid != prompt_id:
                    continue

                if mtype == "progress":
                    if callable(on_progress):
                        on_progress({
                            "type": "progress",
                            "value": mdata.get("value"),
                            "max": mdata.get("max"),
                            "node": mdata.get("node"),
                        })
                elif mtype == "progress_state":
                    if callable(on_progress):
                        on_progress({"type": "progress_state", "nodes": mdata.get("nodes", {})})
                elif mtype == "executing":
                    node = mdata.get("node")
                    if callable(on_progress):
                        on_progress({"type": "executing", "node": node})
                    # When `node` is None on a local server, that signals end-of-run
                    if node is None and not self.is_cloud and seen_executed:
                        success = True
                        break
                elif mtype == "executed":
                    seen_executed = True
                    nid = mdata.get("node")
                    out = mdata.get("output") or {}
                    if nid:
                        outputs[nid] = out
                elif mtype == "notification":
                    if callable(on_progress):
                        on_progress({"type": "notification", "message": mdata.get("value", "")})
                elif mtype == "execution_success":
                    success = True
                    break
                elif mtype == "execution_error":
                    error_payload = mdata
                    break
                elif mtype == "execution_interrupted":
                    error_payload = {"interrupted": True, **mdata}
                    break
        finally:
            try:
                ws.close()
            except Exception:
                pass

        if error_payload is not None:
            return {"status": "error", "data": error_payload}
        if success:
            return {"status": "success", "outputs": outputs}
        return {"status": "timeout", "elapsed": timeout}

    # ---------- outputs ----------
    def get_outputs(self, prompt_id: str) -> dict:
        if self.is_cloud:
            # Try /jobs/{id} first (returns full job with outputs); fall back to
            # /history, which the URL resolver routes to /history_v2 on cloud.
            r = http_get(self._url(f"/jobs/{prompt_id}"), headers=self.headers, retries=2)
            if r.status == 200:
                try:
                    return (r.json() or {}).get("outputs", {}) or {}
                except Exception:
                    pass
            # Fallback
            r = http_get(self._url(f"/history/{prompt_id}"), headers=self.headers, retries=2)
            if r.status == 200:
                try:
                    body = r.json() or {}
                except Exception:
                    body = {}
                if isinstance(body, dict) and prompt_id in body:
                    return body[prompt_id].get("outputs", {}) or {}
                if isinstance(body, dict) and "outputs" in body:
                    return body["outputs"] or {}
            return {}

        # Local
        r = http_get(self._url(f"/history/{prompt_id}"), headers=self.headers, retries=2)
        if r.status != 200:
            return {}
        try:
            body = r.json() or {}
        except Exception:
            return {}
        entry = body.get(prompt_id) or {}
        return entry.get("outputs", {}) or {}

    def download_output(
        self, *, filename: str, subfolder: str, file_type: str,
        output_dir: Path, preserve_subfolder: bool = True, overwrite: bool = False,
    ) -> Path:
        """Stream a single output to disk. Path-traversal-safe."""
        params = {"filename": filename, "subfolder": subfolder, "type": file_type}
        url = self._url("/view") + "?" + urlencode(params)

        # Compute target path safely. If preserve_subfolder, include subfolder in the
        # local path; otherwise put the file in output_dir flat.
        target_parts: list[str] = []
        if preserve_subfolder and subfolder:
            target_parts.extend(p for p in subfolder.split("/") if p and p not in (".", ".."))
        target_parts.append(filename)
        out_path = safe_path_join(output_dir, *target_parts)

        # Never clobber an existing file unless asked: append _1, _2, ... instead.
        if out_path.exists() and not overwrite:
            stem, suffix = out_path.stem, out_path.suffix
            i = 1
            while True:
                candidate = out_path.with_name(f"{stem}_{i}{suffix}")
                if not candidate.exists():
                    out_path = candidate
                    break
                i += 1
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Stream the download. On cloud, /view answers with a 302 to a signed
        # storage URL. The HTTP transport already strips X-API-Key on cross-host
        # redirects (via _strip_api_key_on_redirect), so a single
        # follow_redirects=True call is safe and simpler than a manual two-step
        # fetch.
        r = http_request(
            "GET", url, headers=self.headers,
            timeout=600, retries=3, follow_redirects=True,
            stream=True, sink=out_path,
        )
        if r.status != 200:
            # Remove any partial file left behind by the failed download.
            try:
                if out_path.exists():
                    out_path.unlink()
            except Exception:
                pass
            raise WorkflowRunError(
                "download_failed",
                f"Download of {filename} failed: HTTP {r.status}",
                url=url,
            )
        return out_path

    # ---------- queue / cancel ----------
    def cancel(self, prompt_id: str | None = None) -> bool:
        if prompt_id:
            r = http_post(
                self._url("/queue"), headers=self.headers,
                json_body={"delete": [prompt_id]}, retries=1,
            )
            return r.status == 200
        # Interrupt currently running
        r = http_post(self._url("/interrupt"), headers=self.headers, retries=1)
        return r.status == 200


# =============================================================================
# Schema / parameter injection
# =============================================================================

def _inline_schema(workflow: dict) -> dict:
    """Generate schema using the sibling extract_schema module."""
    from extract_schema import extract_schema  # noqa: WPS433
    return extract_schema(workflow)


def load_schema(schema_path: str | None, workflow: dict) -> dict:
    if schema_path:
        with open(schema_path) as f:
            return json.load(f)
    return _inline_schema(workflow)


def inject_params(
    workflow: dict, schema: dict, args: dict,
    *, randomize_seed_if_unset: bool = False,
) -> tuple[dict, list[str]]:
    """Inject user args into the workflow. Returns (new_workflow, warnings)."""
    wf = copy.deepcopy(workflow)
    params = schema.get("parameters", {}) or {}
    warnings: list[str] = []

    # Auto-randomize seed when it's -1 in args, or when randomize_seed_if_unset
    # and user didn't pass a seed.
    if "seed" in params:
        if "seed" in args and args["seed"] in (None, -1, "-1"):
            args = dict(args)
            args["seed"] = coerce_seed(args["seed"])
            warnings.append(f"seed=-1 expanded to {args['seed']}")
        elif randomize_seed_if_unset and "seed" not in args:
            args = dict(args)
            args["seed"] = coerce_seed(None)
            warnings.append(f"seed auto-randomized to {args['seed']}")

    for name, value in args.items():
        if name not in params:
            warnings.append(f"unknown parameter '{name}' (not in schema), skipping")
            continue
        m = params[name]
        nid, field = m["node_id"], m["field"]
        node = wf.get(nid)
        if not isinstance(node, dict) or "inputs" not in node:
            warnings.append(f"node '{nid}' for parameter '{name}' missing in workflow")
            continue
        # Refuse to overwrite a link with a literal — would silently break wiring
        cur = node["inputs"].get(field)
        if is_link(cur):
            warnings.append(
                f"parameter '{name}' targets {nid}.{field} which is currently a link; "
                f"refusing to overwrite (set the schema to point at the source node instead)"
            )
            continue
        node["inputs"][field] = value
    return wf, warnings


# =============================================================================
# Output download helper
# =============================================================================

def download_outputs(
    runner: ComfyRunner, outputs: dict, output_dir: Path,
    *, preserve_subfolder: bool = True, overwrite: bool = False,
) -> list[dict]:
    """Walk the outputs dict and download every file. Cloud uses `video` (singular);
    local uses `videos` (plural). We accept both."""
    output_dir.mkdir(parents=True, exist_ok=True)
    downloaded: list[dict] = []
    OUTPUT_KEYS = ("images", "gifs", "videos", "video", "audio", "files", "models", "3d")

    for node_id, node_output in (outputs or {}).items():
        if not isinstance(node_output, dict):
            continue
        for key in OUTPUT_KEYS:
            entries = node_output.get(key)
            if not entries:
                continue
            if not isinstance(entries, list):
                entries = [entries]
            for fi in entries:
                if not isinstance(fi, dict):
                    continue
                filename = fi.get("filename") or ""
                if not filename:
                    continue
                subfolder = fi.get("subfolder") or ""
                file_type = fi.get("type") or "output"
                try:
                    out_path = runner.download_output(
                        filename=filename, subfolder=subfolder, file_type=file_type,
                        output_dir=output_dir, preserve_subfolder=preserve_subfolder,
                        overwrite=overwrite,
                    )
                    downloaded.append({
                        "file": str(out_path),
                        "node_id": node_id,
                        "type": media_type_from_filename(filename),
                        "filename": filename,
                        "subfolder": subfolder,
                        "source_type": file_type,
                    })
                except Exception as e:
                    # One failed file should not abort the remaining downloads.
                    log(f"WARN: failed to download {filename}: {e}")
    return downloaded


# =============================================================================
# CLI
# =============================================================================

def parse_input_image_arg(spec: str) -> tuple[str, Path]:
    """Parse `name=path` (or `path` alone, defaulting to name='image')."""
    if "=" in spec:
        name, path = spec.split("=", 1)
        return name.strip(), Path(path).expanduser()
    return "image", Path(spec).expanduser()


def main(argv: list[str] | None = None) -> int:
    p = argparse.ArgumentParser(
        description="Run a ComfyUI workflow with parameter injection.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--workflow", required=True, help="Path to workflow API JSON file")
    p.add_argument("--args", default="{}",
                   help="JSON parameters to inject (or `@/path/to/args.json`)")
    p.add_argument("--schema", help="Path to schema JSON (auto-generated if omitted)")
    p.add_argument("--host", default=DEFAULT_LOCAL_HOST, help="ComfyUI server URL")
    p.add_argument("--api-key",
                   help=f"API key for cloud (or set ${ENV_API_KEY} env var)")
    p.add_argument("--partner-key",
                   help="Partner-node API key (extra_data.api_key_comfy_org). "
                        "Required for Flux Pro / Ideogram / etc. Only sent when provided; "
                        "it does not default to --api-key.")
    p.add_argument("--output-dir", default="./outputs", help="Directory to save outputs")
    p.add_argument("--timeout", type=int, default=0,
                   help="Max seconds to wait (0=auto: 300 / 900 for video workflows)")
    p.add_argument("--input-image", action="append", default=[],
                   help="Upload local image before running. Format: `name=path` or `path`. "
                        "`name` is the schema parameter that receives the uploaded "
                        "file's server-side name.")
    p.add_argument("--randomize-seed", action="store_true",
                   help="If schema has a 'seed' parameter and --args didn't set one, randomize it")
    p.add_argument("--ws", action="store_true",
                   help="Use WebSocket for real-time progress (requires `websocket-client`)")
    p.add_argument("--no-download", action="store_true", help="Skip downloading outputs")
    p.add_argument("--flat-output", action="store_true",
                   help="Don't preserve server-side subfolder structure when saving outputs")
    p.add_argument("--overwrite", action="store_true",
                   help="Overwrite existing files instead of appending _1, _2, ...")
    p.add_argument("--submit-only", action="store_true",
                   help="Submit and return prompt_id without waiting")
    p.add_argument("--client-id", help="Override generated client_id (UUID)")
    p.add_argument("--use-partner-key-as-auth", action="store_true",
                   help="(Compat) Use --partner-key value as cloud X-API-Key. Don't use unless you know why.")
    args = p.parse_args(argv)

    # ---- Load workflow ----
    wf_path = Path(args.workflow).expanduser()
    if not wf_path.exists():
        emit_json({"error": f"Workflow file not found: {args.workflow}"})
        return 1
    try:
        with wf_path.open() as f:
            workflow_raw = json.load(f)
        workflow = unwrap_workflow(workflow_raw)
    except json.JSONDecodeError as e:
        # Must precede ValueError: JSONDecodeError is a ValueError subclass, so
        # the broader handler would otherwise swallow it.
        emit_json({"error": f"Invalid JSON in workflow file: {e}"})
        return 1
    except ValueError as e:
        emit_json({"error": str(e)})
        return 1

    # ---- Parse user args ----
    args_str = args.args
    if args_str.startswith("@"):
        try:
            args_str = Path(args_str[1:]).read_text()
        except OSError as e:
            emit_json({"error": f"Cannot read args file: {e}"})
            return 1
    try:
        user_args = json.loads(args_str) if args_str.strip() else {}
    except json.JSONDecodeError as e:
        emit_json({"error": f"Invalid --args JSON: {e}"})
        return 1
    if not isinstance(user_args, dict):
        emit_json({"error": "--args must be a JSON object"})
        return 1

    # ---- Resolve API key ----
    api_key = resolve_api_key(args.api_key)
    partner_key = args.partner_key or None
    if args.use_partner_key_as_auth and not api_key and partner_key:
        api_key = partner_key

    # ---- Connect ----
    runner = ComfyRunner(
        host=args.host, api_key=api_key, partner_key=partner_key,
        client_id=args.client_id,
    )

    # Server reachability
    ok, info = runner.check_server()
    if not ok:
        emit_json({
            "error": f"Cannot reach server at {args.host}",
            "details": info,
            "hint": (
                "Check `comfy launch --background` is running for local, "
                f"or set ${ENV_API_KEY} for cloud."
            ),
        })
        return 1

    # ---- Upload input images ----
    upload_warnings: list[str] = []
    for spec in args.input_image:
        try:
            param_name, path = parse_input_image_arg(spec)
        except Exception as e:
            emit_json({"error": f"Bad --input-image spec '{spec}': {e}"})
            return 1
        try:
            ref = runner.upload_image(path)
        except Exception as e:
            emit_json({"error": f"Upload failed for {path}: {e}"})
            return 1
        # Register as a user arg so inject_params consumes it through the schema
        uploaded_name = ref.get("name") or path.name
        if param_name not in user_args:
            user_args[param_name] = uploaded_name

    # ---- Inject params ----
    schema = load_schema(args.schema, workflow)
    workflow, inj_warnings = inject_params(
        workflow, schema, user_args, randomize_seed_if_unset=args.randomize_seed,
    )
    warnings = upload_warnings + inj_warnings
    for w in warnings:
        log(f"WARN: {w}")

    # ---- Submit ----
    submit_resp = runner.submit(workflow)
    if "_http_error" in submit_resp:
        emit_json({
            "error": "Submission HTTP error",
            "http_status": submit_resp["_http_error"],
            "body": submit_resp.get("body"),
        })
        return 1
    if isinstance(submit_resp.get("error"), dict):
        emit_json({
            "error": "Workflow validation failed",
            "details": submit_resp["error"],
            "node_errors": submit_resp.get("node_errors"),
        })
        return 1
    prompt_id = submit_resp.get("prompt_id")
    if not prompt_id:
        emit_json({"error": "No prompt_id in submit response", "response": submit_resp})
        return 1
    node_errors = submit_resp.get("node_errors") or {}
    if node_errors:
        emit_json({"error": "Workflow validation failed", "node_errors": node_errors})
        return 1

    if args.submit_only:
        emit_json({"status": "submitted", "prompt_id": prompt_id, "warnings": warnings})
        return 0

    # ---- Wait ----
    timeout = args.timeout
    if timeout <= 0:
        timeout = 900 if looks_like_video_workflow(workflow) else 300
    log(f"Submitted: prompt_id={prompt_id}, waiting (timeout={timeout}s)…")

    def _on_progress(evt: dict) -> None:
        t = evt.get("type")
        if t == "progress":
            log(f" step {evt.get('value')}/{evt.get('max')} on node {evt.get('node')}")
        elif t == "executing":
            node = evt.get("node")
            if node:
                log(f" executing node {node}")

    try:
        if args.ws:
            wait_result = runner.monitor_ws(prompt_id, timeout=timeout, on_progress=_on_progress)
        else:
            wait_result = runner.poll_status(prompt_id, timeout=timeout)
    except KeyboardInterrupt:
        log(f"Interrupted — cancelling job {prompt_id} on server…")
        try:
            runner.cancel(prompt_id)
        except Exception as e:
            log(f" (cancel request failed: {e})")
        emit_json({
            "status": "interrupted",
            "prompt_id": prompt_id,
            "note": "Ctrl+C received; sent cancellation to server.",
        })
        return 130  # conventional exit code for SIGINT (128 + 2)

    if wait_result["status"] == "timeout":
        emit_json({
            "status": "timeout",
            "prompt_id": prompt_id,
            "elapsed": wait_result.get("elapsed"),
            "hint": "Re-run with larger --timeout, or use --submit-only and check later.",
        })
        return 1
    if wait_result["status"] == "error":
        emit_json({"status": "error", "prompt_id": prompt_id, "details": wait_result.get("data")})
        return 1
    if wait_result["status"] == "cancelled":
        emit_json({"status": "cancelled", "prompt_id": prompt_id})
        return 1

    # ---- Outputs ----
    outputs = wait_result.get("outputs")
    if not outputs:
        outputs = runner.get_outputs(prompt_id)

    if args.no_download:
        emit_json({
            "status": "success", "prompt_id": prompt_id,
            "outputs": outputs, "warnings": warnings,
        })
        return 0

    downloaded = download_outputs(
        runner, outputs, Path(args.output_dir).expanduser(),
        preserve_subfolder=not args.flat_output, overwrite=args.overwrite,
    )
    emit_json({
        "status": "success",
        "prompt_id": prompt_id,
        "outputs": downloaded,
        "warnings": warnings,
    })
    return 0


if __name__ == "__main__":
    sys.exit(main())