Compare commits

...

1 Commits

Author SHA1 Message Date
Shannon Sands
ae6435f787 Env robustness: context-safe prompting + tool arg normalization
- Preserve full trajectory while truncating prompt view per turn (avoids context overflow)
- Add max_context_tokens support and wire from env config
- Normalize tool call arguments robustly (dict / stringified JSON / plain string)
- Avoid double-encoding tool arguments in Hermes parser
- Add tool-call metrics to AgentResult for debugging/optional shaping

Scope: environments/* only
2026-02-14 13:13:00 +10:00
3 changed files with 219 additions and 39 deletions

View File

@@ -73,6 +73,12 @@ class AgentResult:
# Tool errors encountered during the loop
tool_errors: List[ToolError] = field(default_factory=list)
# Tool-call metrics (debugging / optional reward shaping)
tool_calls_attempted: int = 0
tool_calls_schema_valid: int = 0
tool_calls_executed_ok: int = 0
tool_calls_exec_error: int = 0
def _extract_reasoning_from_message(message) -> Optional[str]:
"""
@@ -136,6 +142,8 @@ class HermesAgentLoop:
temperature: float = 1.0,
max_tokens: Optional[int] = None,
extra_body: Optional[Dict[str, Any]] = None,
tool_handler=None,
max_context_tokens: Optional[int] = None,
):
"""
Initialize the agent loop.
@@ -152,6 +160,13 @@ class HermesAgentLoop:
extra_body: Extra parameters passed to the OpenAI client's create() call.
Used for OpenRouter provider preferences, transforms, etc.
e.g. {"provider": {"ignore": ["DeepInfra"]}}
tool_handler: Optional async callable(tool_name, args, task_id) -> str.
When provided, used INSTEAD of handle_function_call() for
tool dispatch. This allows sandbox backends (Modal, Nomad)
to route tool calls through their slot-based execution.
max_context_tokens: Maximum prompt tokens before truncation.
If None, no truncation is applied.
Recommended: set to max_model_len - max_tokens - 512 (safety margin).
"""
self.server = server
self.tool_schemas = tool_schemas
@@ -161,6 +176,123 @@ class HermesAgentLoop:
self.temperature = temperature
self.max_tokens = max_tokens
self.extra_body = extra_body
self.tool_handler = tool_handler
self.max_context_tokens = max_context_tokens
def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Truncate conversation history to fit within max_context_tokens.
Strategy:
- Keep system message (index 0) and initial user message (index 1) always
- Keep last 6 messages (recent context) always
- For everything in between, progressively truncate tool result content
- If still too long, drop oldest middle messages entirely
Uses rough char/4 token estimate (fast, no tokenizer needed).
NOTE: This function mutates the provided list (it may pop/replace entries).
Call it on a copy when you want to preserve the full trajectory.
"""
if self.max_context_tokens is None:
return messages
def estimate_tokens(msgs):
total = 0
for m in msgs:
content = m.get("content", "") or ""
total += len(content) // 4 + 10 # ~4 chars per token + overhead
if "tool_calls" in m:
total += 50 * len(m["tool_calls"]) # tool call overhead
return total
if estimate_tokens(messages) <= self.max_context_tokens:
return messages
protect_head = 2
protect_tail = max(0, min(6, len(messages) - protect_head))
middle_start = protect_head
middle_end = len(messages) - protect_tail
# Phase 1: truncate tool outputs in the middle
if middle_start < middle_end:
for i in range(middle_start, middle_end):
if messages[i].get("role") == "tool":
content = messages[i].get("content", "") or ""
if len(content) > 200:
messages[i] = dict(messages[i])
messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
if estimate_tokens(messages) <= self.max_context_tokens:
return messages
# Phase 2: drop oldest middle messages (try to keep assistant+tool pairs)
while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
msg = messages[middle_start]
messages.pop(middle_start)
middle_end -= 1
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_ids = {
tc.get("id") or tc.get("tool_call_id", "")
for tc in msg.get("tool_calls", [])
if isinstance(tc, dict)
}
i = middle_start
while i < middle_end:
if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
messages.pop(i)
middle_end -= 1
else:
i += 1
return messages
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool):
"""Normalize tool arguments into a dict.
Returns: (args_dict, schema_valid)
schema_valid is True only when arguments decode directly into a dict
(no double-decoding and no coercion/wrapping required).
Goal: keep environments robust (never crash on args format drift) while
still allowing reward functions to penalize malformed formats if desired.
"""
try:
decoded = json.loads(tool_args_raw)
except json.JSONDecodeError:
# Not JSON at all — treat as a plain string
if tool_name == "terminal":
return {"command": tool_args_raw}, False
return {"input": tool_args_raw}, False
if isinstance(decoded, dict):
if tool_name == "terminal":
cmd = decoded.get("command")
if isinstance(cmd, str) and cmd.strip():
return decoded, True
if isinstance(decoded.get("input"), str):
return {"command": decoded.get("input")}, False
return decoded, False
return decoded, True
if isinstance(decoded, str):
s = decoded.strip()
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
try:
decoded2 = json.loads(s)
except json.JSONDecodeError:
decoded2 = None
if isinstance(decoded2, dict):
return decoded2, False
if tool_name == "terminal":
return {"command": decoded}, False
return {"input": decoded}, False
if tool_name == "terminal":
return {"command": str(decoded)}, False
return {"input": decoded}, False
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
"""
@@ -176,14 +308,22 @@ class HermesAgentLoop:
reasoning_per_turn = []
tool_errors: List[ToolError] = []
tool_calls_attempted = 0
tool_calls_schema_valid = 0
tool_calls_executed_ok = 0
tool_calls_exec_error = 0
import time as _time
for turn in range(self.max_turns):
turn_start = _time.monotonic()
# Truncate prompt view on a copy (preserve full trajectory in `messages`)
prompt_messages = self._truncate_context(list(messages))
# Build the chat_completion kwargs
chat_kwargs = {
"messages": messages,
"messages": prompt_messages,
"n": 1,
"temperature": self.temperature,
}
@@ -215,6 +355,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
api_elapsed = _time.monotonic() - api_start
@@ -228,6 +372,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
assistant_msg = response.choices[0].message
@@ -270,6 +418,7 @@ class HermesAgentLoop:
# Validate tool name
if tool_name not in self.valid_tool_names:
tool_calls_exec_error += 1
tool_result = json.dumps(
{
"error": f"Unknown tool '{tool_name}'. "
@@ -287,35 +436,35 @@ class HermesAgentLoop:
tool_name, turn + 1,
)
else:
# Parse arguments and dispatch
try:
args = json.loads(tool_args_raw)
except json.JSONDecodeError:
args = {}
logger.warning(
"Invalid JSON in tool call arguments for '%s': %s",
tool_name, tool_args_raw[:200],
)
tool_calls_attempted += 1
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
if schema_valid:
tool_calls_schema_valid += 1
try:
if tool_name == "terminal":
backend = os.getenv("TERMINAL_ENV", "local")
cmd_preview = args.get("command", "")[:80]
cmd_preview = str(args.get("command", ""))[:80]
logger.info(
"[%s] $ %s", self.task_id[:8], cmd_preview,
)
# Run tool calls in a thread pool so backends that use
# asyncio.run() internally (modal, docker) get a clean
# event loop instead of deadlocking inside Atropos's loop.
tool_submit_time = _time.monotonic()
loop = asyncio.get_event_loop()
tool_result = await loop.run_in_executor(
_tool_executor,
lambda: handle_function_call(
tool_name, args, task_id=self.task_id
),
)
if self.tool_handler:
tool_result = await self.tool_handler(tool_name, args, self.task_id)
else:
# Run tool calls in a thread pool so backends that use
# asyncio.run() internally (modal, docker) get a clean
# event loop instead of deadlocking inside Atropos's loop.
loop = asyncio.get_event_loop()
tool_result = await loop.run_in_executor(
_tool_executor,
lambda: handle_function_call(
tool_name, args, task_id=self.task_id
),
)
tool_elapsed = _time.monotonic() - tool_submit_time
# Log slow tools and thread pool stats for debugging
@@ -327,6 +476,7 @@ class HermesAgentLoop:
tool_elapsed, pool_active,
)
except Exception as e:
tool_calls_exec_error += 1
tool_result = json.dumps(
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
)
@@ -340,22 +490,31 @@ class HermesAgentLoop:
"Tool '%s' execution failed on turn %d: %s",
tool_name, turn + 1, e,
)
else:
tool_err = False
try:
result_data = json.loads(tool_result)
if isinstance(result_data, dict):
err = result_data.get("error")
if err:
tool_err = True
# Also check if the tool returned an error in its JSON result
try:
result_data = json.loads(tool_result)
if isinstance(result_data, dict):
err = result_data.get("error")
exit_code = result_data.get("exit_code")
if err and exit_code and exit_code < 0:
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err),
tool_result=tool_result[:500],
))
except (json.JSONDecodeError, TypeError):
pass
exit_code = result_data.get("exit_code")
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
tool_err = True
tool_errors.append(ToolError(
turn=turn + 1, tool_name=tool_name,
arguments=tool_args_raw[:200],
error=str(err) if err else "nonzero exit_code",
tool_result=tool_result[:500],
))
except (json.JSONDecodeError, TypeError):
pass
if tool_err:
tool_calls_exec_error += 1
else:
tool_calls_executed_ok += 1
# Add tool response to conversation
messages.append(
@@ -396,6 +555,10 @@ class HermesAgentLoop:
finished_naturally=True,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
# Hit max turns without the model stopping
@@ -407,6 +570,10 @@ class HermesAgentLoop:
finished_naturally=False,
reasoning_per_turn=reasoning_per_turn,
tool_errors=tool_errors,
tool_calls_attempted=tool_calls_attempted,
tool_calls_schema_valid=tool_calls_schema_valid,
tool_calls_executed_ok=tool_calls_executed_ok,
tool_calls_exec_error=tool_calls_exec_error,
)
def _get_managed_state(self) -> Optional[Dict[str, Any]]:

View File

@@ -478,6 +478,7 @@ class HermesAgentBaseEnv(BaseEnv):
tokenizer=self.tokenizer,
tool_call_parser=tc_parser,
) as managed:
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
agent = HermesAgentLoop(
server=managed,
tool_schemas=tools,
@@ -487,6 +488,7 @@ class HermesAgentBaseEnv(BaseEnv):
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
max_context_tokens=_max_ctx,
)
result = await agent.run(messages)
except NotImplementedError:
@@ -495,6 +497,7 @@ class HermesAgentBaseEnv(BaseEnv):
"ManagedServer not available (OpenAI server?). "
"Falling back to direct server mode."
)
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
@@ -504,10 +507,12 @@ class HermesAgentBaseEnv(BaseEnv):
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
max_context_tokens=_max_ctx,
)
result = await agent.run(messages)
else:
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
@@ -517,6 +522,7 @@ class HermesAgentBaseEnv(BaseEnv):
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
max_context_tokens=_max_ctx,
)
result = await agent.run(messages)

View File

@@ -49,15 +49,22 @@ class HermesToolCallParser(ToolCallParser):
continue
tc_data = json.loads(raw_json)
# Handle arguments: could be dict or already a JSON string
raw_args = tc_data.get("arguments", {})
if isinstance(raw_args, str):
# Already a string — pass through as-is.
# It may be a JSON string ("{...}") or a plain string ("ls").
args_str = raw_args
else:
# Dict — serialize to JSON
args_str = json.dumps(raw_args, ensure_ascii=False)
tool_calls.append(
ChatCompletionMessageToolCall(
id=f"call_{uuid.uuid4().hex[:8]}",
type="function",
function=Function(
name=tc_data["name"],
arguments=json.dumps(
tc_data.get("arguments", {}), ensure_ascii=False
),
arguments=args_str,
),
)
)