mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-02 16:57:36 +08:00
fix(tools): serialize concurrent hermes_tools RPC calls from execute_code
The sandbox-side `_call()` in both the UDS and file-based transports was not thread-safe, so scripts that call tools from multiple threads (e.g. a `ThreadPoolExecutor` over `terminal()`) inside a single `execute_code` run could silently receive each other's responses.

Root cause:

* UDS transport — a single module-level `_sock` was shared across all threads, the newline-framed protocol has no request id, and the server-side RPC loop handles one connection serially. With concurrent callers, each thread would `sendall()` and then race to `recv()` the next newline-terminated response from the shared buffer, so responses were delivered to the wrong caller.
* File transport — `_seq += 1` is a non-atomic read-modify-write, so two threads could allocate the same sequence number and clobber each other's request/response files.

Fix: guard `_call()` with a `threading.Lock` in the UDS case (covering send + recv), and guard `_seq` allocation with a lock in the file case. No protocol change.

Regression tests cover both the generated-source level (the lock is present and used) and an end-to-end concurrency test: a sandboxed `ThreadPoolExecutor` runs 10 `terminal()` calls against a slow mock dispatcher, and the test asserts that every caller sees its own tagged response. The test fails without the fix (10/10 mismatched, matching the real-world repro) and passes with it.
This commit is contained in:
@@ -114,14 +114,30 @@ class TestHermesToolsGeneration(unittest.TestCase):
|
||||
self.assertIn("def json_parse(", src)
|
||||
self.assertIn("def shell_quote(", src)
|
||||
self.assertIn("def retry(", src)
|
||||
self.assertIn("import json, os, socket, shlex, time", src)
|
||||
self.assertIn("import json, os, socket, shlex, threading, time", src)
|
||||
|
||||
def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
|
||||
src = generate_hermes_tools_module(["terminal"], transport="file")
|
||||
self.assertIn("import json, os, shlex, tempfile, time", src)
|
||||
self.assertIn("import json, os, shlex, tempfile, threading, time", src)
|
||||
self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
|
||||
self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)
|
||||
|
||||
def test_uds_transport_serializes_concurrent_calls(self):
|
||||
"""Regression: UDS _call() must hold a lock across send+recv so that
|
||||
concurrent tool calls from multiple threads don't interleave on the
|
||||
shared socket and receive each other's responses."""
|
||||
src = generate_hermes_tools_module(["terminal"], transport="uds")
|
||||
self.assertIn("_call_lock = threading.Lock()", src)
|
||||
self.assertIn("with _call_lock:", src)
|
||||
|
||||
def test_file_transport_serializes_seq_allocation(self):
|
||||
"""Regression: file transport _call() must allocate `_seq` under a
|
||||
lock, otherwise concurrent threads can pick the same seq and clobber
|
||||
each other's request files."""
|
||||
src = generate_hermes_tools_module(["terminal"], transport="file")
|
||||
self.assertIn("_seq_lock = threading.Lock()", src)
|
||||
self.assertIn("with _seq_lock:", src)
|
||||
|
||||
|
||||
class TestExecuteCodeRemoteTempDir(unittest.TestCase):
|
||||
def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
|
||||
@@ -226,6 +242,64 @@ print(f"file lines: {r2['total_lines']}")
|
||||
result = self._run("raise ValueError('test error')")
|
||||
self.assertEqual(result["status"], "error")
|
||||
|
||||
def test_concurrent_tool_calls_match_responses(self):
|
||||
"""Regression for the UDS RPC race: multiple threads inside the
|
||||
sandbox calling terminal() concurrently must each receive their own
|
||||
response, not another thread's.
|
||||
|
||||
Before the fix, `_sock` and the recv-loop were shared without a
|
||||
lock, so responses (written FIFO by the single-threaded server)
|
||||
got delivered to whichever client thread happened to win the
|
||||
recv() race. That surfaced as each thread seeing another thread's
|
||||
output.
|
||||
|
||||
The mock dispatcher sleeps briefly to guarantee the requests
|
||||
overlap on the socket.
|
||||
"""
|
||||
code = '''
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from hermes_tools import terminal
|
||||
|
||||
N = 10
|
||||
|
||||
def call(i):
|
||||
r = terminal(f"echo TAG-{i}")
|
||||
return i, r.get("output", "")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=N) as ex:
|
||||
results = list(ex.map(call, range(N)))
|
||||
|
||||
mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
|
||||
if mismatches:
|
||||
print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
|
||||
else:
|
||||
print(f"OK {N}/{N}")
|
||||
'''
|
||||
|
||||
def slow_mock(function_name, function_args, task_id=None, user_task=None):
|
||||
import time as _t
|
||||
if function_name == "terminal":
|
||||
_t.sleep(0.05) # ensure requests overlap on the socket
|
||||
cmd = function_args.get("command", "")
|
||||
# Echo semantics: strip leading "echo " and return the rest
|
||||
out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
|
||||
return json.dumps({"output": out, "exit_code": 0})
|
||||
return _mock_handle_function_call(
|
||||
function_name, function_args, task_id=task_id, user_task=user_task
|
||||
)
|
||||
|
||||
with patch("model_tools.handle_function_call", side_effect=slow_mock):
|
||||
raw = execute_code(
|
||||
code=code,
|
||||
task_id="test-concurrent",
|
||||
enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
|
||||
)
|
||||
result = json.loads(raw)
|
||||
self.assertEqual(result["status"], "success", msg=result)
|
||||
self.assertIn("OK 10/10", result["output"],
|
||||
msg=f"Concurrent tool calls mismatched: {result['output']!r}")
|
||||
|
||||
def test_excluded_tool_returns_error(self):
|
||||
"""Script calling a tool not in the allow-list gets an error from RPC."""
|
||||
code = """
|
||||
|
||||
Reference in New Issue
Block a user