fix(tools): serialize concurrent hermes_tools RPC calls from execute_code

The sandbox-side `_call()` in both the UDS and file-based transports was
not thread-safe, so scripts that call tools from multiple threads (e.g.
`ThreadPoolExecutor` over `terminal()`) inside a single `execute_code`
run could silently receive each other's responses.

Root cause:

* UDS transport — a single module-level `_sock` was shared across all
  threads; the newline-framed protocol has no request-id; and the
  server-side RPC loop handles one connection serially. With concurrent
  callers, each thread would `sendall()` then race to `recv()` the next
  newline-terminated response from the shared buffer, so responses got
  delivered to the wrong caller.

* File transport — `_seq += 1` is a non-atomic read-modify-write, so
  two threads could allocate the same sequence number and clobber each
  other's request/response files.

Fix: guard `_call()` with a `threading.Lock` in the UDS case (covering
send+recv), and guard `_seq` allocation with a lock in the file case.
No protocol change.

Regression tests cover both the generated-source level (lock is present
and used) and an end-to-end concurrency test: running a sandboxed
ThreadPoolExecutor of 10 `terminal()` calls against a slow mock
dispatcher, asserting every caller sees its own tagged response. The
test fails without the fix (10/10 mismatched, matching real-world
repro) and passes with it.
This commit is contained in:
Heltman
2026-04-30 13:20:05 +08:00
committed by Teknium
parent 3858f9419e
commit 19f9be1dff
2 changed files with 103 additions and 17 deletions

View File

@@ -114,14 +114,30 @@ class TestHermesToolsGeneration(unittest.TestCase):
self.assertIn("def json_parse(", src)
self.assertIn("def shell_quote(", src)
self.assertIn("def retry(", src)
self.assertIn("import json, os, socket, shlex, time", src)
self.assertIn("import json, os, socket, shlex, threading, time", src)
def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
    """Check the generated file-transport module's imports and RPC dir fallback.

    The generated source must:
      * import ``threading`` alongside ``tempfile`` (the file transport
        takes a ``threading.Lock`` around ``_seq`` allocation),
      * build its fallback RPC directory from ``tempfile.gettempdir()``,
      * not hard-code ``/tmp/hermes_rpc`` as the environment fallback.
    """
    src = generate_hermes_tools_module(["terminal"], transport="file")
    # Only the import line that includes `threading` is asserted. The older
    # line without it is not a substring of this one, so asserting both
    # could never pass against a single generated import statement.
    self.assertIn("import json, os, shlex, tempfile, threading, time", src)
    self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
    self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)
def test_uds_transport_serializes_concurrent_calls(self):
    """Regression: the generated UDS module must define and use `_call_lock`.

    The intent is that `_call()` holds the lock across send+recv, so that
    tool calls issued from several threads cannot interleave on the shared
    socket and pick up one another's responses. This test checks it at the
    generated-source level: the lock is created and a `with` block takes it.
    """
    generated = generate_hermes_tools_module(["terminal"], transport="uds")
    required_snippets = (
        "_call_lock = threading.Lock()",
        "with _call_lock:",
    )
    for snippet in required_snippets:
        self.assertIn(snippet, generated)
def test_file_transport_serializes_seq_allocation(self):
    """Regression: the generated file-transport module must lock `_seq` allocation.

    Without a lock, threads calling `_call()` concurrently can pick the
    same sequence number and overwrite each other's request files. This
    test checks the generated source for the lock definition and for a
    `with` block that acquires it.
    """
    module_text = generate_hermes_tools_module(["terminal"], transport="file")
    lock_definition = "_seq_lock = threading.Lock()"
    lock_usage = "with _seq_lock:"
    self.assertIn(lock_definition, module_text)
    self.assertIn(lock_usage, module_text)
class TestExecuteCodeRemoteTempDir(unittest.TestCase):
def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
@@ -226,6 +242,64 @@ print(f"file lines: {r2['total_lines']}")
result = self._run("raise ValueError('test error')")
self.assertEqual(result["status"], "error")
def test_concurrent_tool_calls_match_responses(self):
    """Regression for the UDS RPC race: multiple threads inside the
    sandbox calling terminal() concurrently must each receive their own
    response, not another thread's.

    Before the fix, `_sock` and the recv-loop were shared without a
    lock, so responses (written FIFO by the single-threaded server)
    got delivered to whichever client thread happened to win the
    recv() race. That surfaced as each thread seeing another thread's
    output.

    The mock dispatcher sleeps briefly to guarantee the requests
    overlap on the socket.
    """
    # Sandbox script: fan out N terminal() calls, each echoing a unique tag,
    # then print a one-line verdict that the assertions below look for.
    # The script text starts at column 0 so it is valid Python regardless of
    # how deeply this test method itself is indented.
    code = '''
import threading
from concurrent.futures import ThreadPoolExecutor
from hermes_tools import terminal

N = 10

def call(i):
    r = terminal(f"echo TAG-{i}")
    return i, r.get("output", "")

with ThreadPoolExecutor(max_workers=N) as ex:
    results = list(ex.map(call, range(N)))

mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
if mismatches:
    print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
else:
    print(f"OK {N}/{N}")
'''

    def slow_mock(function_name, function_args, task_id=None, user_task=None):
        """Answer `terminal` with echo semantics after a short delay.

        Any other tool name is forwarded to `_mock_handle_function_call`.
        """
        import time as _t
        if function_name == "terminal":
            _t.sleep(0.05)  # ensure requests overlap on the socket
            cmd = function_args.get("command", "")
            # Echo semantics: strip leading "echo " and return the rest
            out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
            return json.dumps({"output": out, "exit_code": 0})
        return _mock_handle_function_call(
            function_name, function_args, task_id=task_id, user_task=user_task
        )

    with patch("model_tools.handle_function_call", side_effect=slow_mock):
        raw = execute_code(
            code=code,
            task_id="test-concurrent",
            enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
        )

    result = json.loads(raw)
    self.assertEqual(result["status"], "success", msg=result)
    self.assertIn("OK 10/10", result["output"],
                  msg=f"Concurrent tool calls mismatched: {result['output']!r}")
def test_excluded_tool_returns_error(self):
"""Script calling a tool not in the allow-list gets an error from RPC."""
code = """