mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
|
|
"""
|
||
|
|
Image Generation Provider Registry
|
||
|
|
==================================
|
||
|
|
|
||
|
|
Central map of registered providers. Populated by plugins at import-time via
|
||
|
|
``PluginContext.register_image_gen_provider()``; consumed by the
|
||
|
|
``image_generate`` tool to dispatch each call to the active backend.
|
||
|
|
|
||
|
|
Active selection
|
||
|
|
----------------
|
||
|
|
The active provider is chosen by ``image_gen.provider`` in ``config.yaml``.
|
||
|
|
If unset, :func:`get_active_provider` applies fallback logic:
|
||
|
|
|
||
|
|
1. If exactly one provider is registered, use it.
|
||
|
|
2. Otherwise if a provider named ``fal`` is registered, use it (legacy
|
||
|
|
default — matches pre-plugin behavior).
|
||
|
|
3. Otherwise return ``None`` (the tool surfaces a helpful error pointing
|
||
|
|
the user at ``hermes tools``).
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import threading
|
||
|
|
from typing import Dict, List, Optional
|
||
|
|
|
||
|
|
from agent.image_gen_provider import ImageGenProvider
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
_providers: Dict[str, ImageGenProvider] = {}
|
||
|
|
_lock = threading.Lock()
|
||
|
|
|
||
|
|
|
||
|
|
def register_provider(provider: ImageGenProvider) -> None:
|
||
|
|
"""Register an image generation provider.
|
||
|
|
|
||
|
|
Re-registration (same ``name``) overwrites the previous entry and logs
|
||
|
|
a debug message — this makes hot-reload scenarios (tests, dev loops)
|
||
|
|
behave predictably.
|
||
|
|
"""
|
||
|
|
if not isinstance(provider, ImageGenProvider):
|
||
|
|
raise TypeError(
|
||
|
|
f"register_provider() expects an ImageGenProvider instance, "
|
||
|
|
f"got {type(provider).__name__}"
|
||
|
|
)
|
||
|
|
name = provider.name
|
||
|
|
if not isinstance(name, str) or not name.strip():
|
||
|
|
raise ValueError("Image gen provider .name must be a non-empty string")
|
||
|
|
with _lock:
|
||
|
|
existing = _providers.get(name)
|
||
|
|
_providers[name] = provider
|
||
|
|
if existing is not None:
|
||
|
|
logger.debug("Image gen provider '%s' re-registered (was %r)", name, type(existing).__name__)
|
||
|
|
else:
|
||
|
|
logger.debug("Registered image gen provider '%s' (%s)", name, type(provider).__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def list_providers() -> List[ImageGenProvider]:
|
||
|
|
"""Return all registered providers, sorted by name."""
|
||
|
|
with _lock:
|
||
|
|
items = list(_providers.values())
|
||
|
|
return sorted(items, key=lambda p: p.name)
|
||
|
|
|
||
|
|
|
||
|
|
def get_provider(name: str) -> Optional[ImageGenProvider]:
|
||
|
|
"""Return the provider registered under *name*, or None."""
|
||
|
|
if not isinstance(name, str):
|
||
|
|
return None
|
||
|
|
with _lock:
|
||
|
|
return _providers.get(name.strip())
|
||
|
|
|
||
|
|
|
||
|
|
def get_active_provider() -> Optional[ImageGenProvider]:
|
||
|
|
"""Resolve the currently-active provider.
|
||
|
|
|
||
|
|
Reads ``image_gen.provider`` from config.yaml; falls back per the
|
||
|
|
module docstring.
|
||
|
|
"""
|
||
|
|
configured: Optional[str] = None
|
||
|
|
try:
|
||
|
|
from hermes_cli.config import load_config
|
||
|
|
|
||
|
|
cfg = load_config()
|
||
|
|
section = cfg.get("image_gen") if isinstance(cfg, dict) else None
|
||
|
|
if isinstance(section, dict):
|
||
|
|
raw = section.get("provider")
|
||
|
|
if isinstance(raw, str) and raw.strip():
|
||
|
|
configured = raw.strip()
|
||
|
|
except Exception as exc:
|
||
|
|
logger.debug("Could not read image_gen.provider from config: %s", exc)
|
||
|
|
|
||
|
|
with _lock:
|
||
|
|
snapshot = dict(_providers)
|
||
|
|
|
||
|
|
if configured:
|
||
|
|
provider = snapshot.get(configured)
|
||
|
|
if provider is not None:
|
||
|
|
return provider
|
||
|
|
logger.debug(
|
||
|
|
"image_gen.provider='%s' configured but not registered; falling back",
|
||
|
|
configured,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Fallback: single-provider case
|
||
|
|
if len(snapshot) == 1:
|
||
|
|
return next(iter(snapshot.values()))
|
||
|
|
|
||
|
|
# Fallback: prefer legacy FAL for backward compat
|
||
|
|
if "fal" in snapshot:
|
||
|
|
return snapshot["fal"]
|
||
|
|
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
def _reset_for_tests() -> None:
|
||
|
|
"""Clear the registry. **Test-only.**"""
|
||
|
|
with _lock:
|
||
|
|
_providers.clear()
|