Compare commits


8 Commits

Author     SHA1         Message                                                          Date
dmahan93   be43bee11a   final changes from successful run                                2026-04-22 14:57:57 -05:00
dmahan93   721e0b96cd   add length eviction if no compression                            2026-04-16 01:10:11 -05:00
dmahan93   d988343570   fixup some compression stuff                                     2026-04-14 00:22:52 -05:00
dmahan93   43dee2e1cf   update for rl overrides                                          2026-04-14 00:16:04 -05:00
dmahan93   637a214820   fix: token ID extraction bugs in run_agent.py                    2026-04-04 14:59:18 -05:00
                        - hasattr() returns bool, not None — changed 'is not None' to proper check
                        - Fixed variable name typo: assistant_msg -> assistant_message
                        - Trajectory format: use 'in' dict check instead of hasattr on dicts
dmahan93   f168a4f1bf   add prompt_tokens/ generation logprobs to run_agent              2026-04-04 13:35:42 -05:00
dmahan93   6442255f83   clean up agent_loop.py: remove debug print and dead comments     2026-04-03 18:11:26 -05:00
dmahan93   44371a9bbb   add nemo gym support                                             2026-04-03 18:02:08 -05:00
3 changed files with 252 additions and 48 deletions
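
Note on commit 637a214820: hasattr() already returns a bool, so a check like hasattr(obj, "x") is not None is always true even when the attribute is missing, and plain dicts need an "in" membership test rather than hasattr(). A minimal, illustrative sketch of the corrected pattern (the object and field names here are placeholders, not code from this repo):

    class FakeAssistantMessage:
        generation_token_ids = [1, 2, 3]

    msg_obj = FakeAssistantMessage()

    # Buggy: hasattr() returns a bool, so "is not None" is always True,
    # even when the attribute does not exist.
    if hasattr(msg_obj, "prompt_token_ids") is not None:
        pass  # runs unconditionally

    # Fixed: test the bool directly, optionally also checking the value.
    if hasattr(msg_obj, "generation_token_ids") and msg_obj.generation_token_ids is not None:
        token_ids = msg_obj.generation_token_ids

    # Fixed for plain dicts: hasattr() checks attributes, not keys, so use "in".
    msg_dict = {"role": "assistant", "prompt_token_ids": [7, 8, 9]}
    if "prompt_token_ids" in msg_dict:
        ids = msg_dict["prompt_token_ids"]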

environments/agent_loop.py

@@ -193,6 +193,10 @@ class HermesAgentLoop:
         import time as _time
+        prompt_token_ids = None
+        generation_token_ids = None
+        generation_log_probs = None
         for turn in range(self.max_turns):
             turn_start = _time.monotonic()
@@ -246,6 +250,12 @@ class HermesAgentLoop:
             )
             assistant_msg = response.choices[0].message
+            if hasattr(assistant_msg, "prompt_token_ids"):
+                prompt_token_ids = assistant_msg.prompt_token_ids
+            if hasattr(assistant_msg, "generation_token_ids"):
+                generation_token_ids = assistant_msg.generation_token_ids
+            if hasattr(assistant_msg, "generation_log_probs"):
+                generation_log_probs = assistant_msg.generation_log_probs
             # Extract reasoning content from the response (all provider formats)
             reasoning = _extract_reasoning_from_message(assistant_msg)
@@ -308,7 +318,10 @@ class HermesAgentLoop:
                 "content": assistant_msg.content or "",
                 "tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
             }
+            if prompt_token_ids is not None:
+                msg_dict["prompt_token_ids"] = prompt_token_ids
+                msg_dict["generation_token_ids"] = generation_token_ids
+                msg_dict["generation_log_probs"] = generation_log_probs
             # Preserve reasoning_content for multi-turn chat template handling
             # (e.g., Kimi-K2's template renders <think> blocks differently
             # for history vs. the latest turn based on this field)
@@ -471,6 +484,10 @@ class HermesAgentLoop:
             }
             if reasoning:
                 msg_dict["reasoning_content"] = reasoning
+            if prompt_token_ids is not None:
+                msg_dict["prompt_token_ids"] = prompt_token_ids
+                msg_dict["generation_token_ids"] = generation_token_ids
+                msg_dict["generation_log_probs"] = generation_log_probs
             messages.append(msg_dict)
             turn_elapsed = _time.monotonic() - turn_start
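
For reference, a sketch (not part of the diff) of the assistant message dict these hunks produce downstream, assuming the serving backend actually attaches the extra fields to response.choices[0].message; all values below are made up:

    msg_dict = {
        "role": "assistant",
        "content": "The current directory is /workspace.",
        "tool_calls": [],
        "reasoning_content": "User wants the cwd; call the terminal tool.",
        "prompt_token_ids": [151644, 872, 198],         # ids of the rendered prompt
        "generation_token_ids": [785, 1482, 6220],      # ids the model generated
        "generation_log_probs": [-0.12, -0.48, -0.03],  # one log-prob per generated token
    }
    assert len(msg_dict["generation_token_ids"]) == len(msg_dict["generation_log_probs"])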

environments/check_gym_compat.py (new file)

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Quick compatibility check: connect to a local OpenAI-compatible endpoint
and run a single agent turn via HermesAgentLoop with all standard tools.

Usage:
    python environments/check_gym_compat.py                              # auto-detect model
    python environments/check_gym_compat.py --model my-model             # explicit model
    python environments/check_gym_compat.py --base-url http://... --model ...
"""
import asyncio
import argparse
import json
import logging
import sys
from pathlib import Path

# Ensure repo root is on sys.path when run as a standalone script
_repo_root = str(Path(__file__).resolve().parent.parent)
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

import requests
from openai import AsyncOpenAI

from environments.agent_loop import HermesAgentLoop, AgentResult
from model_tools import get_tool_definitions

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Thin server wrapper — gives HermesAgentLoop the chat_completion() it wants
# ---------------------------------------------------------------------------
class OpenAIServer:
    """Minimal async server wrapping an OpenAI-compatible endpoint."""

    def __init__(self, base_url: str, model: str, api_key: str = "dummy"):
        self.model = model
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)

    async def chat_completion(self, **kwargs):
        kwargs.setdefault("model", self.model)
        return await self.client.chat.completions.create(**kwargs)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def detect_model(base_url: str) -> str:
    try:
        resp = requests.get(f"{base_url}/models", timeout=10)
        resp.raise_for_status()
        models = resp.json().get("data", [])
        if not models:
            print("WARNING: /v1/models returned no models")
            return "default"
        model_id = models[0]["id"]
        print(f"Auto-detected model: {model_id}")
        return model_id
    except Exception as e:
        print(f"Could not auto-detect model ({e}), falling back to 'default'")
        return "default"


async def run_check(base_url: str, model: str, message: str) -> AgentResult:
    server = OpenAIServer(base_url=base_url, model=model)

    # Get all default hermes tools
    tool_schemas = get_tool_definitions(quiet_mode=False)
    valid_names = {t["function"]["name"] for t in tool_schemas}

    agent = HermesAgentLoop(
        server=server,
        tool_schemas=tool_schemas,
        valid_tool_names=valid_names,
        max_turns=5,
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant with access to tools."},
        {"role": "user", "content": message},
    ]
    return await agent.run(messages)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description="Check gym endpoint compatibility")
    parser.add_argument("--base-url", default="http://127.0.0.1:11746/v1")
    parser.add_argument("--model", default=None)
    parser.add_argument("--message", default="Hello! What's the current directory you're in?")
    args = parser.parse_args()

    model = args.model or detect_model(args.base_url)

    print(f"\n{'='*60}")
    print(f"Endpoint: {args.base_url}")
    print(f"Model: {model}")
    print(f"Message: {args.message}")
    print(f"{'='*60}\n")

    try:
        result = asyncio.run(run_check(args.base_url, model, args.message))

        print(f"\n{'='*60}")
        print(f"Turns used: {result.turns_used}")
        print(f"Finished naturally: {result.finished_naturally}")
        print(f"Tool errors: {len(result.tool_errors)}")
        print(f"{'='*60}")

        # Print the final assistant response
        for msg in reversed(result.messages):
            # if msg.get("role") == "assistant" and msg.get("content"):
            #     print("\nRESPONSE:")
            #     print(msg["content"])
            #     break
            print(msg)

        if result.tool_errors:
            print("\nTOOL ERRORS:")
            for err in result.tool_errors:
                print(f"  turn {err.turn}: {err.tool_name}: {err.error}")

        status = "✅ passed" if result.finished_naturally else "⚠️ hit max turns"
        print(f"\nGym compatibility check {status}")
    except Exception as e:
        print(f"\n❌ Gym compatibility check failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
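
The docstring above covers CLI usage; if it is more convenient to drive the same check from Python (say, inside a test), something along these lines should work against the file as written; the endpoint URL and message are placeholders:

    import asyncio
    from environments.check_gym_compat import run_check, detect_model

    base_url = "http://127.0.0.1:11746/v1"   # same default as the --base-url flag
    model = detect_model(base_url)           # or pass an explicit model id
    result = asyncio.run(run_check(base_url, model, "List the files in the current directory."))
    print(result.turns_used, result.finished_naturally, len(result.tool_errors))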

run_agent.py

@@ -516,6 +516,9 @@ class AIAgent:
         checkpoint_max_snapshots: int = 50,
         pass_session_id: bool = False,
         persist_session: bool = True,
+        use_streaming: bool = True,
+        temperature: float = None,
+        insert_reasoning: bool = True,
     ):
         """
         Initialize the AI Agent.
@@ -559,11 +562,17 @@ class AIAgent:
                 When provided and Honcho is enabled in config, enables persistent cross-session user modeling.
             honcho_manager: Optional shared HonchoSessionManager owned by the caller.
             honcho_config: Optional HonchoClientConfig corresponding to honcho_manager.
+            use_streaming (bool): Whether to use streaming for API calls (default: True)
+            temperature (float): Temperature for model responses (optional, uses model default if not set)
+            insert_reasoning (bool): Whether to insert reasoning into the API response (default: True)
         """
         _install_safe_stdio()
         self.model = model
         self.max_iterations = max_iterations
+        self.use_streaming = use_streaming
+        self.temperature = temperature
+        self.insert_reasoning = insert_reasoning
         # Shared iteration budget — parent creates, children inherit.
         # Consumed by every LLM turn across parent + all subagents.
         self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
@@ -1917,6 +1926,10 @@ class AIAgent:
                     "value": content.rstrip()
                 })
+                if "prompt_token_ids" in msg:
+                    trajectory[-1]["prompt_token_ids"] = msg["prompt_token_ids"]
+                    trajectory[-1]["generation_token_ids"] = msg["generation_token_ids"]
+                    trajectory[-1]["generation_log_probs"] = msg["generation_log_probs"]
                 # Collect all subsequent tool responses
                 tool_responses = []
                 j = i + 1
@@ -1978,6 +1991,10 @@ class AIAgent:
                     "from": "gpt",
                     "value": content.strip()
                 })
+                if "prompt_token_ids" in msg:
+                    trajectory[-1]["prompt_token_ids"] = msg["prompt_token_ids"]
+                    trajectory[-1]["generation_token_ids"] = msg["generation_token_ids"]
+                    trajectory[-1]["generation_log_probs"] = msg["generation_log_probs"]
             elif msg["role"] == "user":
                 trajectory.append({
@@ -5055,6 +5072,8 @@ class AIAgent:
             "messages": sanitized_messages,
             "timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
         }
+        if self.temperature is not None:
+            api_kwargs["temperature"] = self.temperature
         if self.tools:
             api_kwargs["tools"] = self.tools
@@ -5230,6 +5249,11 @@ class AIAgent:
             "finish_reason": finish_reason,
         }
+        if hasattr(assistant_message, "prompt_token_ids") and assistant_message.prompt_token_ids is not None:
+            msg["prompt_token_ids"] = assistant_message.prompt_token_ids
+            msg["generation_token_ids"] = assistant_message.generation_token_ids
+            msg["generation_log_probs"] = assistant_message.generation_log_probs
         if hasattr(assistant_message, 'reasoning_details') and assistant_message.reasoning_details:
             # Pass reasoning_details back unmodified so providers (OpenRouter,
             # Anthropic, OpenAI) can maintain reasoning continuity across turns.
@@ -5377,7 +5401,7 @@ class AIAgent:
             api_msg = msg.copy()
             if msg.get("role") == "assistant":
                 reasoning = msg.get("reasoning")
-                if reasoning:
+                if reasoning and self.insert_reasoning:
                     api_msg["reasoning_content"] = reasoning
                 api_msg.pop("reasoning", None)
                 api_msg.pop("finish_reason", None)
@@ -6374,6 +6398,7 @@ class AIAgent:
         stream_callback: Optional[callable] = None,
         persist_user_message: Optional[str] = None,
         sync_honcho: bool = True,
+        dont_review: bool = False,
     ) -> Dict[str, Any]:
         """
         Run a complete conversation with tool calling until completion.
@@ -6391,7 +6416,7 @@ class AIAgent:
                 synthetic prefixes.
             sync_honcho: When False, skip writing the final synthetic turn back
                 to Honcho or queuing follow-up prefetch work.
+            dont_review: When True, skip reviewing memory and skills.

         Returns:
             Dict: Complete conversation result with final response and message history
         """
@@ -6728,7 +6753,7 @@ class AIAgent:
                 # This ensures multi-turn reasoning context is preserved
                 if msg.get("role") == "assistant":
                     reasoning_text = msg.get("reasoning")
-                    if reasoning_text:
+                    if reasoning_text and self.insert_reasoning:
                         # Add reasoning_content for API compatibility (Moonshot AI, Novita, OpenRouter)
                         api_msg["reasoning_content"] = reasoning_text
@@ -6856,7 +6881,7 @@ class AIAgent:
             if self.thinking_callback:
                 self.thinking_callback("")
-            _use_streaming = True
+            _use_streaming = self.use_streaming
             if not self._has_stream_consumers():
                 # No display/TTS consumer. Still prefer streaming for
                 # health checking, but skip for Mock clients in tests
@@ -7034,6 +7059,15 @@ class AIAgent:
                 finish_reason = response.choices[0].finish_reason
                 if finish_reason == "length":
+                    if not self.compression_enabled:
+                        return {
+                            "final_response": None,
+                            "messages": messages,
+                            "api_calls": api_call_count,
+                            "completed": False,
+                            "partial": True,
+                            "error": "Response truncated due to output length limit",
+                        }
                     self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True)
                     # ── Detect thinking-budget exhaustion ──────────────
@@ -7433,7 +7467,7 @@ class AIAgent:
                     or 'error code: 413' in error_msg
                 )
-                if is_payload_too_large:
+                if is_payload_too_large and self.compression_enabled:
                     compression_attempts += 1
                     if compression_attempts > max_compression_attempts:
                         self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.", force=True)
@@ -7448,30 +7482,14 @@ class AIAgent:
                             "partial": True
                         }
                     self._emit_status(f"⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
-                    original_len = len(messages)
-                    messages, active_system_prompt = self._compress_context(
-                        messages, system_message, approx_tokens=approx_tokens,
-                        task_id=effective_task_id,
-                    )
-                    if len(messages) < original_len:
-                        self._emit_status(f"🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
-                        time.sleep(2)  # Brief pause between compression retries
-                        restart_with_compressed_messages = True
-                        break
-                    else:
-                        self._vprint(f"{self.log_prefix}❌ Payload too large and cannot compress further.", force=True)
-                        self._vprint(f"{self.log_prefix}   💡 Try /new to start a fresh conversation, or /compress to retry compression.", force=True)
-                        logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.")
-                        self._persist_session(messages, conversation_history)
-                        return {
-                            "messages": messages,
-                            "completed": False,
-                            "api_calls": api_call_count,
-                            "error": "Request payload too large (413). Cannot compress further.",
-                            "partial": True
-                        }
+                elif is_payload_too_large and not self.compression_enabled:
+                    return {
+                        "messages": messages,
+                        "completed": False,
+                        "api_calls": api_call_count,
+                        "error": "Request payload too large (413). Cannot compress further.",
+                        "partial": True
+                    }
                 # Check for context-length errors BEFORE generic 4xx handler.
                 # Local backends (LM Studio, Ollama, llama.cpp) often return
@@ -7507,7 +7525,7 @@ class AIAgent:
                     force=True,
                 )
-                if is_context_length_error:
+                if is_context_length_error and self.compression_enabled:
                     compressor = self.context_compressor
                     old_ctx = compressor.context_length
@@ -7576,6 +7594,14 @@ class AIAgent:
                             "error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.",
                             "partial": True
                         }
+                elif is_context_length_error and not self.compression_enabled:
+                    return {
+                        "messages": messages,
+                        "completed": False,
+                        "api_calls": api_call_count,
+                        "error": f"Context length exceeded ({approx_tokens:,} tokens). Cannot compress further.",
+                        "partial": True
+                    }
                 # Check for non-retryable client errors (4xx HTTP status codes).
                 # These indicate a problem with the request itself (bad model ID,
@@ -7789,6 +7815,9 @@ class AIAgent:
                     break
             try:
+                prompt_token_ids = None
+                generation_token_ids = None
+                generation_log_probs = None
                 if self.api_mode == "codex_responses":
                     assistant_message, finish_reason = self._normalize_codex_response(response)
                 elif self.api_mode == "anthropic_messages":
@@ -7798,6 +7827,12 @@ class AIAgent:
                     )
                 else:
                     assistant_message = response.choices[0].message
+                    if hasattr(assistant_message, "prompt_token_ids") and assistant_message.prompt_token_ids is not None:
+                        prompt_token_ids = assistant_message.prompt_token_ids
+                    if hasattr(assistant_message, "generation_token_ids") and assistant_message.generation_token_ids is not None:
+                        generation_token_ids = assistant_message.generation_token_ids
+                    if hasattr(assistant_message, "generation_log_probs") and assistant_message.generation_log_probs is not None:
+                        generation_log_probs = assistant_message.generation_log_probs
                 # Normalize content to string — some OpenAI-compatible servers
                 # (llama-server, etc.) return content as a dict or list instead
@@ -8240,28 +8275,34 @@ class AIAgent:
                        self._response_was_previewed = True
                    break
-            # No fallback -- if reasoning_text exists, the model put its
-            # entire response inside <think> tags; use that as the content.
+            # No fallback -- the model kept emitting <think>...</think>
+            # with empty content for 3 retries. Preserve token IDs from
+            # the last API attempt (reasoning-only generation) so RL can
+            # train on this trajectory instead of dropping it entirely.
+            # Using _build_assistant_message ensures prompt_token_ids,
+            # generation_token_ids, and generation_log_probs are attached
+            # when present on the assistant_message object.
             if reasoning_text:
                 self._vprint(f"{self.log_prefix}Using reasoning as response content (model wrapped entire response in think tags).", force=True)
                 final_response = reasoning_text
-                empty_msg = {
-                    "role": "assistant",
-                    "content": final_response,
-                    "reasoning": reasoning_text,
-                    "finish_reason": finish_reason,
-                }
-                messages.append(empty_msg)
+                # Preserve token IDs from the last API attempt by building the
+                # assistant message from the live API response object. This
+                # avoids the all-empty-output-items ValueError in NeMo RL's
+                # nemo_gym postprocessor when every turn was reasoning-only.
+                try:
+                    _last_msg = self._build_assistant_message(assistant_message, finish_reason)
+                    messages.append(_last_msg)
+                except Exception:
+                    # If assistant_message is out of scope or _build fails,
+                    # fall back to a message without token IDs (matches
+                    # original behavior).
+                    messages.append({
+                        "role": "assistant",
+                        "content": final_response,
+                        "reasoning": reasoning_text,
+                        "finish_reason": finish_reason,
+                    })
                 break
-            # Truly empty -- no reasoning and no content
-            empty_msg = {
-                "role": "assistant",
-                "content": final_response,
-                "reasoning": reasoning_text,
-                "finish_reason": finish_reason,
-            }
-            messages.append(empty_msg)
             self._cleanup_task_resources(effective_task_id)
             self._persist_session(messages, conversation_history)
@@ -8471,7 +8512,9 @@ class AIAgent:
                 and "skill_manage" in self.valid_tool_names):
             _should_review_skills = True
             self._iters_since_skill = 0
+        if dont_review:
+            _should_review_memory = False
+            _should_review_skills = False
         # Background memory/skill review — runs AFTER the response is delivered
         # so it never competes with the user's task for model attention.
         if final_response and not interrupted and (_should_review_memory or _should_review_skills):
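
Taken together, the new flags read like RL-rollout overrides: use_streaming=False skips streaming, temperature pins sampling, insert_reasoning=False keeps reasoning_content out of replayed history, and dont_review skips the post-response memory/skill review, while the compression_enabled checks make truncation and oversized-payload errors return cleanly instead of retrying. A rough sketch of how a caller might wire this up; only the keyword names shown in the hunks come from the diff, everything else (model id, other arguments) is a placeholder:

    agent = AIAgent(
        model="my-rl-policy",
        max_iterations=8,
        use_streaming=False,      # replaces the hard-coded _use_streaming = True
        temperature=1.0,          # forwarded into api_kwargs only when not None
        insert_reasoning=False,   # keep reasoning_content out of replayed assistant turns
    )
    # The conversation entry point whose signature appears at @@ -6374 additionally
    # accepts dont_review=True to skip the background memory/skill review.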