diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py
index 6f6260ffe24..a5806046583 100644
--- a/tests/tools/test_code_execution.py
+++ b/tests/tools/test_code_execution.py
@@ -114,14 +114,30 @@ class TestHermesToolsGeneration(unittest.TestCase):
         self.assertIn("def json_parse(", src)
         self.assertIn("def shell_quote(", src)
         self.assertIn("def retry(", src)
-        self.assertIn("import json, os, socket, shlex, time", src)
+        self.assertIn("import json, os, socket, shlex, threading, time", src)
 
     def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
         src = generate_hermes_tools_module(["terminal"], transport="file")
-        self.assertIn("import json, os, shlex, tempfile, time", src)
+        self.assertIn("import json, os, shlex, tempfile, threading, time", src)
         self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
         self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)
 
+    def test_uds_transport_serializes_concurrent_calls(self):
+        """Regression: UDS _call() must hold a lock across send+recv so that
+        concurrent tool calls from multiple threads don't interleave on the
+        shared socket and receive each other's responses."""
+        src = generate_hermes_tools_module(["terminal"], transport="uds")
+        self.assertIn("_call_lock = threading.Lock()", src)
+        self.assertIn("with _call_lock:", src)
+
+    def test_file_transport_serializes_seq_allocation(self):
+        """Regression: file transport _call() must allocate `_seq` under a
+        lock, otherwise concurrent threads can pick the same seq and clobber
+        each other's request files."""
+        src = generate_hermes_tools_module(["terminal"], transport="file")
+        self.assertIn("_seq_lock = threading.Lock()", src)
+        self.assertIn("with _seq_lock:", src)
+
 
 class TestExecuteCodeRemoteTempDir(unittest.TestCase):
     def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
@@ -226,6 +242,64 @@ print(f"file lines: {r2['total_lines']}")
         result = self._run("raise ValueError('test error')")
         self.assertEqual(result["status"], "error")
 
+    def test_concurrent_tool_calls_match_responses(self):
+        """Regression for the UDS RPC race: multiple threads inside the
+        sandbox calling terminal() concurrently must each receive their own
+        response, not another thread's.
+
+        Before the fix, `_sock` and the recv-loop were shared without a
+        lock, so responses (written FIFO by the single-threaded server)
+        got delivered to whichever client thread happened to win the
+        recv() race. That surfaced as each thread seeing another thread's
+        output.
+
+        The mock dispatcher sleeps briefly to guarantee the requests
+        overlap on the socket.
+        """
+        code = '''
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from hermes_tools import terminal
+
+N = 10
+
+def call(i):
+    r = terminal(f"echo TAG-{i}")
+    return i, r.get("output", "")
+
+with ThreadPoolExecutor(max_workers=N) as ex:
+    results = list(ex.map(call, range(N)))
+
+mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
+if mismatches:
+    print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
+else:
+    print(f"OK {N}/{N}")
+'''
+
+        def slow_mock(function_name, function_args, task_id=None, user_task=None):
+            import time as _t
+            if function_name == "terminal":
+                _t.sleep(0.05)  # ensure requests overlap on the socket
+                cmd = function_args.get("command", "")
+                # Echo semantics: strip leading "echo " and return the rest
+                out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
+                return json.dumps({"output": out, "exit_code": 0})
+            return _mock_handle_function_call(
+                function_name, function_args, task_id=task_id, user_task=user_task
+            )
+
+        with patch("model_tools.handle_function_call", side_effect=slow_mock):
+            raw = execute_code(
+                code=code,
+                task_id="test-concurrent",
+                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
+            )
+        result = json.loads(raw)
+        self.assertEqual(result["status"], "success", msg=result)
+        self.assertIn("OK 10/10", result["output"],
+                      msg=f"Concurrent tool calls mismatched: {result['output']!r}")
+
     def test_excluded_tool_returns_error(self):
         """Script calling a tool not in the allow-list gets an error from RPC."""
         code = """
diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py
index c91907c4d12..ffcf726fcd5 100644
--- a/tools/code_execution_tool.py
+++ b/tools/code_execution_tool.py
@@ -224,9 +224,14 @@ def retry(fn, max_attempts=3, delay=2):
 
 _UDS_TRANSPORT_HEADER = '''\
 """Auto-generated Hermes tools RPC stubs."""
-import json, os, socket, shlex, time
+import json, os, socket, shlex, threading, time
 
 _sock = None
+# The RPC server handles a single client connection serially and has no
+# request-id in the protocol, so concurrent _call() invocations from multiple
+# threads (e.g. ThreadPoolExecutor) would race on the shared socket and get
+# each other's responses. Serialize the entire send+recv round-trip.
+_call_lock = threading.Lock()
 ''' + _COMMON_HELPERS + '''\
 
 def _connect():
@@ -239,17 +244,18 @@ def _connect():
 
 def _call(tool_name, args):
     """Send a tool call to the parent process and return the parsed result."""
-    conn = _connect()
     request = json.dumps({"tool": tool_name, "args": args}) + "\\n"
-    conn.sendall(request.encode())
-    buf = b""
-    while True:
-        chunk = conn.recv(65536)
-        if not chunk:
-            raise RuntimeError("Agent process disconnected")
-        buf += chunk
-        if buf.endswith(b"\\n"):
-            break
+    with _call_lock:
+        conn = _connect()
+        conn.sendall(request.encode())
+        buf = b""
+        while True:
+            chunk = conn.recv(65536)
+            if not chunk:
+                raise RuntimeError("Agent process disconnected")
+            buf += chunk
+            if buf.endswith(b"\\n"):
+                break
     raw = buf.decode().strip()
     result = json.loads(raw)
     if isinstance(result, str):
@@ -265,24 +271,30 @@ def _call(tool_name, args):
 
 _FILE_TRANSPORT_HEADER = '''\
 """Auto-generated Hermes tools RPC stubs (file-based transport)."""
-import json, os, shlex, tempfile, time
+import json, os, shlex, tempfile, threading, time
 
 _RPC_DIR = os.environ.get("HERMES_RPC_DIR") or os.path.join(tempfile.gettempdir(), "hermes_rpc")
 _seq = 0
+# `_seq += 1` is not atomic (read-modify-write), so concurrent _call()
+# invocations from multiple threads could allocate the same sequence number
+# and clobber each other's request files. Guard seq allocation with a lock.
+_seq_lock = threading.Lock()
 ''' + _COMMON_HELPERS + '''\
 
 def _call(tool_name, args):
     """Send a tool call request via file-based RPC and wait for response."""
     global _seq
-    _seq += 1
-    seq_str = f"{_seq:06d}"
+    with _seq_lock:
+        _seq += 1
+        seq = _seq
+    seq_str = f"{seq:06d}"
     req_file = os.path.join(_RPC_DIR, f"req_{seq_str}")
     res_file = os.path.join(_RPC_DIR, f"res_{seq_str}")
 
     # Write request atomically (write to .tmp, then rename)
     tmp = req_file + ".tmp"
     with open(tmp, "w") as f:
-        json.dump({"tool": tool_name, "args": args, "seq": _seq}, f)
+        json.dump({"tool": tool_name, "args": args, "seq": seq}, f)
     os.rename(tmp, req_file)
 
     # Wait for response with adaptive polling