mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 16:01:49 +08:00
Compare commits
1 Commits
fix/plugin
...
hermes/her
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd72285907 |
@@ -2039,6 +2039,66 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return SendResult(success=False, error=str(e))
|
return SendResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def send_model_picker(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
providers: list,
|
||||||
|
current_model: str,
|
||||||
|
current_provider: str,
|
||||||
|
session_key: str,
|
||||||
|
on_model_selected,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> SendResult:
|
||||||
|
"""Send an interactive select-menu model picker.
|
||||||
|
|
||||||
|
Two-step drill-down: provider dropdown → model dropdown.
|
||||||
|
Uses Discord embeds + Select menus via ``ModelPickerView``.
|
||||||
|
"""
|
||||||
|
if not self._client or not DISCORD_AVAILABLE:
|
||||||
|
return SendResult(success=False, error="Not connected")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Resolve target channel (use thread_id if present)
|
||||||
|
target_id = chat_id
|
||||||
|
if metadata and metadata.get("thread_id"):
|
||||||
|
target_id = metadata["thread_id"]
|
||||||
|
|
||||||
|
channel = self._client.get_channel(int(target_id))
|
||||||
|
if not channel:
|
||||||
|
channel = await self._client.fetch_channel(int(target_id))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.providers import get_label
|
||||||
|
provider_label = get_label(current_provider)
|
||||||
|
except Exception:
|
||||||
|
provider_label = current_provider
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="⚙ Model Configuration",
|
||||||
|
description=(
|
||||||
|
f"Current model: `{current_model or 'unknown'}`\n"
|
||||||
|
f"Provider: {provider_label}\n\n"
|
||||||
|
f"Select a provider:"
|
||||||
|
),
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
view = ModelPickerView(
|
||||||
|
providers=providers,
|
||||||
|
current_model=current_model,
|
||||||
|
current_provider=current_provider,
|
||||||
|
session_key=session_key,
|
||||||
|
on_model_selected=on_model_selected,
|
||||||
|
allowed_user_ids=self._allowed_user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await channel.send(embed=embed, view=view)
|
||||||
|
return SendResult(success=True, message_id=str(msg.id))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[%s] send_model_picker failed: %s", self.name, e)
|
||||||
|
return SendResult(success=False, error=str(e))
|
||||||
|
|
||||||
def _get_parent_channel_id(self, channel: Any) -> Optional[str]:
|
def _get_parent_channel_id(self, channel: Any) -> Optional[str]:
|
||||||
"""Return the parent channel ID for a Discord thread-like channel, if present."""
|
"""Return the parent channel ID for a Discord thread-like channel, if present."""
|
||||||
parent = getattr(channel, "parent", None)
|
parent = getattr(channel, "parent", None)
|
||||||
@@ -2530,3 +2590,219 @@ if DISCORD_AVAILABLE:
|
|||||||
self.resolved = True
|
self.resolved = True
|
||||||
for child in self.children:
|
for child in self.children:
|
||||||
child.disabled = True
|
child.disabled = True
|
||||||
|
|
||||||
|
class ModelPickerView(discord.ui.View):
|
||||||
|
"""Interactive select-menu view for model switching.
|
||||||
|
|
||||||
|
Two-step drill-down: provider dropdown → model dropdown.
|
||||||
|
Edits the original message in-place as the user navigates.
|
||||||
|
Times out after 2 minutes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
providers: list,
|
||||||
|
current_model: str,
|
||||||
|
current_provider: str,
|
||||||
|
session_key: str,
|
||||||
|
on_model_selected,
|
||||||
|
allowed_user_ids: set,
|
||||||
|
):
|
||||||
|
super().__init__(timeout=120)
|
||||||
|
self.providers = providers
|
||||||
|
self.current_model = current_model
|
||||||
|
self.current_provider = current_provider
|
||||||
|
self.session_key = session_key
|
||||||
|
self.on_model_selected = on_model_selected
|
||||||
|
self.allowed_user_ids = allowed_user_ids
|
||||||
|
self.resolved = False
|
||||||
|
self._selected_provider: str = ""
|
||||||
|
|
||||||
|
self._build_provider_select()
|
||||||
|
|
||||||
|
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||||||
|
if not self.allowed_user_ids:
|
||||||
|
return True
|
||||||
|
return str(interaction.user.id) in self.allowed_user_ids
|
||||||
|
|
||||||
|
def _build_provider_select(self):
|
||||||
|
"""Build the provider dropdown menu."""
|
||||||
|
self.clear_items()
|
||||||
|
options = []
|
||||||
|
for p in self.providers:
|
||||||
|
count = p.get("total_models", len(p.get("models", [])))
|
||||||
|
label = f"{p['name']} ({count} models)"
|
||||||
|
desc = "current" if p.get("is_current") else None
|
||||||
|
options.append(
|
||||||
|
discord.SelectOption(
|
||||||
|
label=label[:100],
|
||||||
|
value=p["slug"],
|
||||||
|
default=bool(p.get("is_current")),
|
||||||
|
description=desc,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not options:
|
||||||
|
return
|
||||||
|
|
||||||
|
select = discord.ui.Select(
|
||||||
|
placeholder="Choose a provider...",
|
||||||
|
options=options[:25],
|
||||||
|
custom_id="model_provider_select",
|
||||||
|
)
|
||||||
|
select.callback = self._on_provider_selected
|
||||||
|
self.add_item(select)
|
||||||
|
|
||||||
|
cancel_btn = discord.ui.Button(
|
||||||
|
label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel"
|
||||||
|
)
|
||||||
|
cancel_btn.callback = self._on_cancel
|
||||||
|
self.add_item(cancel_btn)
|
||||||
|
|
||||||
|
def _build_model_select(self, provider_slug: str):
|
||||||
|
"""Build the model dropdown for a specific provider."""
|
||||||
|
self.clear_items()
|
||||||
|
provider = next(
|
||||||
|
(p for p in self.providers if p["slug"] == provider_slug), None
|
||||||
|
)
|
||||||
|
if not provider:
|
||||||
|
return
|
||||||
|
|
||||||
|
models = provider.get("models", [])
|
||||||
|
options = []
|
||||||
|
for model_id in models[:25]:
|
||||||
|
short = model_id.split("/")[-1] if "/" in model_id else model_id
|
||||||
|
options.append(
|
||||||
|
discord.SelectOption(
|
||||||
|
label=short[:100],
|
||||||
|
value=model_id[:100],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not options:
|
||||||
|
return
|
||||||
|
|
||||||
|
select = discord.ui.Select(
|
||||||
|
placeholder=f"Choose a model from {provider.get('name', provider_slug)}...",
|
||||||
|
options=options,
|
||||||
|
custom_id="model_model_select",
|
||||||
|
)
|
||||||
|
select.callback = self._on_model_selected
|
||||||
|
self.add_item(select)
|
||||||
|
|
||||||
|
back_btn = discord.ui.Button(
|
||||||
|
label="◀ Back", style=discord.ButtonStyle.grey, custom_id="model_back"
|
||||||
|
)
|
||||||
|
back_btn.callback = self._on_back
|
||||||
|
self.add_item(back_btn)
|
||||||
|
|
||||||
|
cancel_btn = discord.ui.Button(
|
||||||
|
label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel2"
|
||||||
|
)
|
||||||
|
cancel_btn.callback = self._on_cancel
|
||||||
|
self.add_item(cancel_btn)
|
||||||
|
|
||||||
|
async def _on_provider_selected(self, interaction: discord.Interaction):
|
||||||
|
if not self._check_auth(interaction):
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"You're not authorized~", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
provider_slug = interaction.data["values"][0]
|
||||||
|
self._selected_provider = provider_slug
|
||||||
|
provider = next(
|
||||||
|
(p for p in self.providers if p["slug"] == provider_slug), None
|
||||||
|
)
|
||||||
|
pname = provider.get("name", provider_slug) if provider else provider_slug
|
||||||
|
|
||||||
|
self._build_model_select(provider_slug)
|
||||||
|
|
||||||
|
total = provider.get("total_models", 0) if provider else 0
|
||||||
|
shown = min(len(provider.get("models", [])), 25) if provider else 0
|
||||||
|
extra = f"\n*{total - shown} more available — type `/model <name>` directly*" if total > shown else ""
|
||||||
|
|
||||||
|
await interaction.response.edit_message(
|
||||||
|
embed=discord.Embed(
|
||||||
|
title="⚙ Model Configuration",
|
||||||
|
description=f"Provider: **{pname}**\nSelect a model:{extra}",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
),
|
||||||
|
view=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _on_model_selected(self, interaction: discord.Interaction):
|
||||||
|
if self.resolved:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"Already resolved~", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if not self._check_auth(interaction):
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"You're not authorized~", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.resolved = True
|
||||||
|
model_id = interaction.data["values"][0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_text = await self.on_model_selected(
|
||||||
|
str(interaction.channel_id),
|
||||||
|
model_id,
|
||||||
|
self._selected_provider,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
result_text = f"Error switching model: {exc}"
|
||||||
|
|
||||||
|
self.clear_items()
|
||||||
|
await interaction.response.edit_message(
|
||||||
|
embed=discord.Embed(
|
||||||
|
title="⚙ Model Switched",
|
||||||
|
description=result_text,
|
||||||
|
color=discord.Color.green(),
|
||||||
|
),
|
||||||
|
view=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _on_back(self, interaction: discord.Interaction):
|
||||||
|
if not self._check_auth(interaction):
|
||||||
|
await interaction.response.send_message(
|
||||||
|
"You're not authorized~", ephemeral=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._build_provider_select()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.providers import get_label
|
||||||
|
provider_label = get_label(self.current_provider)
|
||||||
|
except Exception:
|
||||||
|
provider_label = self.current_provider
|
||||||
|
|
||||||
|
await interaction.response.edit_message(
|
||||||
|
embed=discord.Embed(
|
||||||
|
title="⚙ Model Configuration",
|
||||||
|
description=(
|
||||||
|
f"Current model: `{self.current_model or 'unknown'}`\n"
|
||||||
|
f"Provider: {provider_label}\n\n"
|
||||||
|
f"Select a provider:"
|
||||||
|
),
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
),
|
||||||
|
view=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _on_cancel(self, interaction: discord.Interaction):
|
||||||
|
self.resolved = True
|
||||||
|
self.clear_items()
|
||||||
|
await interaction.response.edit_message(
|
||||||
|
embed=discord.Embed(
|
||||||
|
title="⚙ Model Configuration",
|
||||||
|
description="Model selection cancelled.",
|
||||||
|
color=discord.Color.greyple(),
|
||||||
|
),
|
||||||
|
view=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
self.resolved = True
|
||||||
|
self.clear_items()
|
||||||
|
|||||||
@@ -151,6 +151,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||||||
self._dm_topics: Dict[str, int] = {}
|
self._dm_topics: Dict[str, int] = {}
|
||||||
# DM Topics config from extra.dm_topics
|
# DM Topics config from extra.dm_topics
|
||||||
self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", [])
|
self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", [])
|
||||||
|
# Interactive model picker state per chat
|
||||||
|
self._model_picker_state: Dict[str, dict] = {}
|
||||||
|
|
||||||
def _fallback_ips(self) -> list[str]:
|
def _fallback_ips(self) -> list[str]:
|
||||||
"""Return validated fallback IPs from config (populated by _apply_env_overrides)."""
|
"""Return validated fallback IPs from config (populated by _apply_env_overrides)."""
|
||||||
@@ -1008,14 +1010,252 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||||||
logger.warning("[%s] send_update_prompt failed: %s", self.name, e)
|
logger.warning("[%s] send_update_prompt failed: %s", self.name, e)
|
||||||
return SendResult(success=False, error=str(e))
|
return SendResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def send_model_picker(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
providers: list,
|
||||||
|
current_model: str,
|
||||||
|
current_provider: str,
|
||||||
|
session_key: str,
|
||||||
|
on_model_selected,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> SendResult:
|
||||||
|
"""Send an interactive inline-keyboard model picker.
|
||||||
|
|
||||||
|
Two-step drill-down: provider selection → model selection.
|
||||||
|
Edits the same message in-place as the user navigates.
|
||||||
|
"""
|
||||||
|
if not self._bot:
|
||||||
|
return SendResult(success=False, error="Not connected")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.providers import get_label
|
||||||
|
except ImportError:
|
||||||
|
def get_label(slug):
|
||||||
|
return slug
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build provider buttons — 2 per row
|
||||||
|
buttons: list = []
|
||||||
|
for p in providers:
|
||||||
|
count = p.get("total_models", len(p.get("models", [])))
|
||||||
|
label = f"{p['name']} ({count})"
|
||||||
|
if p.get("is_current"):
|
||||||
|
label = f"✓ {label}"
|
||||||
|
# Compact callback data: mp:<slug> (max 64 bytes)
|
||||||
|
buttons.append(
|
||||||
|
InlineKeyboardButton(label, callback_data=f"mp:{p['slug']}")
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)]
|
||||||
|
rows.append([InlineKeyboardButton("✗ Cancel", callback_data="mx")])
|
||||||
|
keyboard = InlineKeyboardMarkup(rows)
|
||||||
|
|
||||||
|
provider_label = get_label(current_provider)
|
||||||
|
text = (
|
||||||
|
f"⚙ *Model Configuration*\n\n"
|
||||||
|
f"Current model: `{current_model or 'unknown'}`\n"
|
||||||
|
f"Provider: {provider_label}\n\n"
|
||||||
|
f"Select a provider:"
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_id = metadata.get("thread_id") if metadata else None
|
||||||
|
msg = await self._bot.send_message(
|
||||||
|
chat_id=int(chat_id),
|
||||||
|
text=text,
|
||||||
|
parse_mode=ParseMode.MARKDOWN,
|
||||||
|
reply_markup=keyboard,
|
||||||
|
message_thread_id=int(thread_id) if thread_id else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store picker state keyed by chat_id
|
||||||
|
self._model_picker_state[str(chat_id)] = {
|
||||||
|
"msg_id": msg.message_id,
|
||||||
|
"providers": providers,
|
||||||
|
"session_key": session_key,
|
||||||
|
"on_model_selected": on_model_selected,
|
||||||
|
"current_model": current_model,
|
||||||
|
"current_provider": current_provider,
|
||||||
|
}
|
||||||
|
|
||||||
|
return SendResult(success=True, message_id=str(msg.message_id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("[%s] send_model_picker failed: %s", self.name, e)
|
||||||
|
return SendResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def _handle_model_picker_callback(
|
||||||
|
self, query, data: str, chat_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Handle model picker inline keyboard callbacks (mp:/mm:/mb:/mx:)."""
|
||||||
|
state = self._model_picker_state.get(chat_id)
|
||||||
|
if not state:
|
||||||
|
await query.answer(text="Picker expired — use /model again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.providers import get_label
|
||||||
|
except ImportError:
|
||||||
|
def get_label(slug):
|
||||||
|
return slug
|
||||||
|
|
||||||
|
if data.startswith("mp:"):
|
||||||
|
# --- Provider selected: show model buttons ---
|
||||||
|
provider_slug = data[3:]
|
||||||
|
provider = next(
|
||||||
|
(p for p in state["providers"] if p["slug"] == provider_slug),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not provider:
|
||||||
|
await query.answer(text="Provider not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
models = provider.get("models", [])
|
||||||
|
state["selected_provider"] = provider_slug
|
||||||
|
state["selected_provider_name"] = provider.get("name", provider_slug)
|
||||||
|
state["model_list"] = models
|
||||||
|
|
||||||
|
buttons: list = []
|
||||||
|
for i, model_id in enumerate(models):
|
||||||
|
# Short display label: strip vendor prefix
|
||||||
|
short = model_id.split("/")[-1] if "/" in model_id else model_id
|
||||||
|
# Truncate long model names for button label (max ~40 chars)
|
||||||
|
if len(short) > 38:
|
||||||
|
short = short[:35] + "..."
|
||||||
|
buttons.append(
|
||||||
|
InlineKeyboardButton(short, callback_data=f"mm:{i}")
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)]
|
||||||
|
rows.append([
|
||||||
|
InlineKeyboardButton("◀ Back", callback_data="mb"),
|
||||||
|
InlineKeyboardButton("✗ Cancel", callback_data="mx"),
|
||||||
|
])
|
||||||
|
keyboard = InlineKeyboardMarkup(rows)
|
||||||
|
|
||||||
|
pname = provider.get("name", provider_slug)
|
||||||
|
total = provider.get("total_models", len(models))
|
||||||
|
shown = len(models)
|
||||||
|
extra = f"\n_{total - shown} more available — type `/model <name>` directly_" if total > shown else ""
|
||||||
|
|
||||||
|
await query.edit_message_text(
|
||||||
|
text=(
|
||||||
|
f"⚙ *Model Configuration*\n\n"
|
||||||
|
f"Provider: *{pname}*\n"
|
||||||
|
f"Select a model:{extra}"
|
||||||
|
),
|
||||||
|
parse_mode=ParseMode.MARKDOWN,
|
||||||
|
reply_markup=keyboard,
|
||||||
|
)
|
||||||
|
await query.answer()
|
||||||
|
|
||||||
|
elif data.startswith("mm:"):
|
||||||
|
# --- Model selected: perform the switch ---
|
||||||
|
try:
|
||||||
|
idx = int(data[3:])
|
||||||
|
except ValueError:
|
||||||
|
await query.answer(text="Invalid selection.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_list = state.get("model_list", [])
|
||||||
|
if idx < 0 or idx >= len(model_list):
|
||||||
|
await query.answer(text="Invalid model index.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model_id = model_list[idx]
|
||||||
|
provider_slug = state.get("selected_provider", "")
|
||||||
|
callback = state.get("on_model_selected")
|
||||||
|
|
||||||
|
if not callback:
|
||||||
|
await query.answer(text="Picker expired.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_text = await callback(chat_id, model_id, provider_slug)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Model picker switch failed: %s", exc)
|
||||||
|
result_text = f"Error switching model: {exc}"
|
||||||
|
|
||||||
|
# Edit message to show confirmation, remove buttons
|
||||||
|
try:
|
||||||
|
await query.edit_message_text(
|
||||||
|
text=result_text,
|
||||||
|
parse_mode=ParseMode.MARKDOWN,
|
||||||
|
reply_markup=None,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Markdown parse failure — retry as plain text
|
||||||
|
try:
|
||||||
|
await query.edit_message_text(
|
||||||
|
text=result_text,
|
||||||
|
parse_mode=None,
|
||||||
|
reply_markup=None,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await query.answer(text="Model switched!")
|
||||||
|
|
||||||
|
# Clean up state
|
||||||
|
self._model_picker_state.pop(chat_id, None)
|
||||||
|
|
||||||
|
elif data == "mb":
|
||||||
|
# --- Back to provider list ---
|
||||||
|
buttons = []
|
||||||
|
for p in state["providers"]:
|
||||||
|
count = p.get("total_models", len(p.get("models", [])))
|
||||||
|
label = f"{p['name']} ({count})"
|
||||||
|
if p.get("is_current"):
|
||||||
|
label = f"✓ {label}"
|
||||||
|
buttons.append(
|
||||||
|
InlineKeyboardButton(label, callback_data=f"mp:{p['slug']}")
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)]
|
||||||
|
rows.append([InlineKeyboardButton("✗ Cancel", callback_data="mx")])
|
||||||
|
keyboard = InlineKeyboardMarkup(rows)
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_label = get_label(state["current_provider"])
|
||||||
|
except Exception:
|
||||||
|
provider_label = state["current_provider"]
|
||||||
|
|
||||||
|
await query.edit_message_text(
|
||||||
|
text=(
|
||||||
|
f"⚙ *Model Configuration*\n\n"
|
||||||
|
f"Current model: `{state['current_model'] or 'unknown'}`\n"
|
||||||
|
f"Provider: {provider_label}\n\n"
|
||||||
|
f"Select a provider:"
|
||||||
|
),
|
||||||
|
parse_mode=ParseMode.MARKDOWN,
|
||||||
|
reply_markup=keyboard,
|
||||||
|
)
|
||||||
|
await query.answer()
|
||||||
|
|
||||||
|
elif data == "mx":
|
||||||
|
# --- Cancel ---
|
||||||
|
self._model_picker_state.pop(chat_id, None)
|
||||||
|
await query.edit_message_text(
|
||||||
|
text="Model selection cancelled.",
|
||||||
|
reply_markup=None,
|
||||||
|
)
|
||||||
|
await query.answer()
|
||||||
|
|
||||||
async def _handle_callback_query(
|
async def _handle_callback_query(
|
||||||
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
|
self, update: "Update", context: "ContextTypes.DEFAULT_TYPE"
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle inline keyboard button clicks (update prompts)."""
|
"""Handle inline keyboard button clicks."""
|
||||||
query = update.callback_query
|
query = update.callback_query
|
||||||
if not query or not query.data:
|
if not query or not query.data:
|
||||||
return
|
return
|
||||||
data = query.data
|
data = query.data
|
||||||
|
|
||||||
|
# --- Model picker callbacks ---
|
||||||
|
if data.startswith(("mp:", "mm:", "mb", "mx")):
|
||||||
|
chat_id = str(query.message.chat_id) if query.message else None
|
||||||
|
if chat_id:
|
||||||
|
await self._handle_model_picker_callback(query, data, chat_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Update prompt callbacks ---
|
||||||
if not data.startswith("update_prompt:"):
|
if not data.startswith("update_prompt:"):
|
||||||
return
|
return
|
||||||
answer = data.split(":", 1)[1] # "y" or "n"
|
answer = data.split(":", 1)[1] # "y" or "n"
|
||||||
|
|||||||
116
gateway/run.py
116
gateway/run.py
@@ -3464,11 +3464,11 @@ class GatewayRunner:
|
|||||||
lines.append(f"_(Requested page {requested_page} was out of range, showing page {page}.)_")
|
lines.append(f"_(Requested page {requested_page} was out of range, showing page {page}.)_")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
async def _handle_model_command(self, event: MessageEvent) -> str:
|
async def _handle_model_command(self, event: MessageEvent) -> Optional[str]:
|
||||||
"""Handle /model command — switch model for this session.
|
"""Handle /model command — switch model for this session.
|
||||||
|
|
||||||
Supports:
|
Supports:
|
||||||
/model — show current model info
|
/model — interactive picker (Telegram/Discord) or text list
|
||||||
/model <name> — switch for this session only
|
/model <name> — switch for this session only
|
||||||
/model <name> --global — switch and persist to config.yaml
|
/model <name> --global — switch and persist to config.yaml
|
||||||
/model <name> --provider <provider> — switch provider + model
|
/model <name> --provider <provider> — switch provider + model
|
||||||
@@ -3516,8 +3516,118 @@ class GatewayRunner:
|
|||||||
current_base_url = override.get("base_url", current_base_url)
|
current_base_url = override.get("base_url", current_base_url)
|
||||||
current_api_key = override.get("api_key", current_api_key)
|
current_api_key = override.get("api_key", current_api_key)
|
||||||
|
|
||||||
# No args: show authenticated providers with models
|
# No args: show interactive picker (Telegram/Discord) or text list
|
||||||
if not model_input and not explicit_provider:
|
if not model_input and not explicit_provider:
|
||||||
|
# Try interactive picker if the platform supports it
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
has_picker = (
|
||||||
|
adapter is not None
|
||||||
|
and getattr(type(adapter), "send_model_picker", None) is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_picker:
|
||||||
|
try:
|
||||||
|
providers = list_authenticated_providers(
|
||||||
|
current_provider=current_provider,
|
||||||
|
user_providers=user_provs,
|
||||||
|
max_models=8,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
providers = []
|
||||||
|
|
||||||
|
if providers:
|
||||||
|
# Build a callback closure for when the user picks a model.
|
||||||
|
# Captures self + locals needed for the switch logic.
|
||||||
|
_self = self
|
||||||
|
_session_key = session_key
|
||||||
|
_cur_model = current_model
|
||||||
|
_cur_provider = current_provider
|
||||||
|
_cur_base_url = current_base_url
|
||||||
|
_cur_api_key = current_api_key
|
||||||
|
|
||||||
|
async def _on_model_selected(
|
||||||
|
_chat_id: str, model_id: str, provider_slug: str
|
||||||
|
) -> str:
|
||||||
|
"""Perform the model switch and return confirmation text."""
|
||||||
|
result = _switch_model(
|
||||||
|
raw_input=model_id,
|
||||||
|
current_provider=_cur_provider,
|
||||||
|
current_model=_cur_model,
|
||||||
|
current_base_url=_cur_base_url,
|
||||||
|
current_api_key=_cur_api_key,
|
||||||
|
is_global=False,
|
||||||
|
explicit_provider=provider_slug,
|
||||||
|
)
|
||||||
|
if not result.success:
|
||||||
|
return f"Error: {result.error_message}"
|
||||||
|
|
||||||
|
# Update cached agent in-place
|
||||||
|
cached_entry = None
|
||||||
|
_cache_lock = getattr(_self, "_agent_cache_lock", None)
|
||||||
|
_cache = getattr(_self, "_agent_cache", None)
|
||||||
|
if _cache_lock and _cache is not None:
|
||||||
|
with _cache_lock:
|
||||||
|
cached_entry = _cache.get(_session_key)
|
||||||
|
if cached_entry and cached_entry[0] is not None:
|
||||||
|
try:
|
||||||
|
cached_entry[0].switch_model(
|
||||||
|
new_model=result.new_model,
|
||||||
|
new_provider=result.target_provider,
|
||||||
|
api_key=result.api_key,
|
||||||
|
base_url=result.base_url,
|
||||||
|
api_mode=result.api_mode,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Picker model switch failed for cached agent: %s", exc)
|
||||||
|
|
||||||
|
# Store model note + session override
|
||||||
|
if not hasattr(_self, "_pending_model_notes"):
|
||||||
|
_self._pending_model_notes = {}
|
||||||
|
_self._pending_model_notes[_session_key] = (
|
||||||
|
f"[Note: model was just switched from {_cur_model} to {result.new_model} "
|
||||||
|
f"via {result.provider_label or result.target_provider}. "
|
||||||
|
f"Adjust your self-identification accordingly.]"
|
||||||
|
)
|
||||||
|
if not hasattr(_self, "_session_model_overrides"):
|
||||||
|
_self._session_model_overrides = {}
|
||||||
|
_self._session_model_overrides[_session_key] = {
|
||||||
|
"model": result.new_model,
|
||||||
|
"provider": result.target_provider,
|
||||||
|
"api_key": result.api_key,
|
||||||
|
"base_url": result.base_url,
|
||||||
|
"api_mode": result.api_mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build confirmation text
|
||||||
|
plabel = result.provider_label or result.target_provider
|
||||||
|
lines = [f"Model switched to `{result.new_model}`"]
|
||||||
|
lines.append(f"Provider: {plabel}")
|
||||||
|
mi = result.model_info
|
||||||
|
if mi:
|
||||||
|
if mi.context_window:
|
||||||
|
lines.append(f"Context: {mi.context_window:,} tokens")
|
||||||
|
if mi.max_output:
|
||||||
|
lines.append(f"Max output: {mi.max_output:,} tokens")
|
||||||
|
if mi.has_cost_data():
|
||||||
|
lines.append(f"Cost: {mi.format_cost()}")
|
||||||
|
lines.append(f"Capabilities: {mi.format_capabilities()}")
|
||||||
|
lines.append("_(session only — use `/model <name> --global` to persist)_")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
metadata = {"thread_id": source.thread_id} if source.thread_id else None
|
||||||
|
result = await adapter.send_model_picker(
|
||||||
|
chat_id=source.chat_id,
|
||||||
|
providers=providers,
|
||||||
|
current_model=current_model,
|
||||||
|
current_provider=current_provider,
|
||||||
|
session_key=session_key,
|
||||||
|
on_model_selected=_on_model_selected,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
if result.success:
|
||||||
|
return None # Picker sent — adapter handles the response
|
||||||
|
|
||||||
|
# Fallback: text list (for platforms without picker or if picker failed)
|
||||||
provider_label = get_label(current_provider)
|
provider_label = get_label(current_provider)
|
||||||
lines = [f"Current: `{current_model or 'unknown'}` on {provider_label}", ""]
|
lines = [f"Current: `{current_model or 'unknown'}` on {provider_label}", ""]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user