Compare commits

...

1 Commit

Author SHA1 Message Date
Teknium
2555cb76b4 fix: scope tool interrupt signal per-thread to prevent cross-session leaks
The interrupt mechanism in tools/interrupt.py used a process-global
threading.Event. In the gateway, multiple agents run concurrently in
the same process via run_in_executor. When any agent was interrupted
(user sends a follow-up message), the global flag killed ALL agents'
running tools — terminal commands, browser ops, web requests — across
all sessions.

Changes:
- tools/interrupt.py: Replace single threading.Event with a set of
  interrupted thread IDs. set_interrupt() targets a specific thread;
  is_interrupted() checks the current thread. Includes a backward-
  compatible _ThreadAwareEventProxy for legacy _interrupt_event usage.
- run_agent.py: Store execution thread ID at start of run_conversation().
  interrupt() and clear_interrupt() pass it to set_interrupt() so only
  this agent's thread is affected.
- tools/code_execution_tool.py: Use is_interrupted() instead of
  directly checking _interrupt_event.is_set().
- tools/process_registry.py: Same — use is_interrupted().
- tests: Update interrupt tests for per-thread semantics. Add new
  TestPerThreadInterruptIsolation with two tests verifying cross-thread
  isolation.
2026-04-11 13:58:02 -07:00
6 changed files with 183 additions and 78 deletions

View File

@@ -739,6 +739,7 @@ class AIAgent:
# Interrupt mechanism for breaking out of tool loops
self._interrupt_requested = False
self._interrupt_message = None # Optional message that triggered interrupt
self._execution_thread_id: int | None = None # Set at run_conversation() start
self._client_lock = threading.RLock()
# Subagent delegation state
@@ -2832,8 +2833,10 @@ class AIAgent:
"""
self._interrupt_requested = True
self._interrupt_message = message
# Signal all tools to abort any in-flight operations immediately
_set_interrupt(True)
# Signal all tools to abort any in-flight operations immediately.
# Scope the interrupt to this agent's execution thread so other
# agents running in the same process (gateway) are not affected.
_set_interrupt(True, self._execution_thread_id)
# Propagate interrupt to any running child agents (subagent delegation)
with self._active_children_lock:
children_copy = list(self._active_children)
@@ -2846,10 +2849,10 @@ class AIAgent:
print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
def clear_interrupt(self) -> None:
"""Clear any pending interrupt request and the global tool interrupt signal."""
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
self._interrupt_requested = False
self._interrupt_message = None
_set_interrupt(False)
_set_interrupt(False, self._execution_thread_id)
def _touch_activity(self, desc: str) -> None:
"""Update the last-activity timestamp and description (thread-safe)."""
@@ -7799,6 +7802,11 @@ class AIAgent:
compression_attempts = 0
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
# Record the execution thread so interrupt()/clear_interrupt() can
# scope the tool-level interrupt signal to THIS agent's thread only.
# Must be set before clear_interrupt() which uses it.
self._execution_thread_id = threading.current_thread().ident
# Clear any stale interrupt state at start
self.clear_interrupt()

View File

@@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def tearDown(self):
set_interrupt(False)
def _make_bare_agent(self):
    """Return an AIAgent built via __new__ with only interrupt-related state.

    __init__ is bypassed, so the instance carries just the attributes
    assigned here.
    """
    from run_agent import AIAgent

    bare = AIAgent.__new__(AIAgent)
    interrupt_state = {
        "_interrupt_requested": False,
        "_interrupt_message": None,
        # None makes set_interrupt() fall back to the calling thread.
        "_execution_thread_id": None,
        "_active_children": [],
        "_active_children_lock": threading.Lock(),
        "quiet_mode": True,
    }
    for attr_name, attr_value in interrupt_state.items():
        setattr(bare, attr_name, attr_value)
    return bare
def test_parent_interrupt_sets_child_flag(self):
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
parent._active_children.append(child)
@@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
assert child._interrupt_message == "new user message"
assert is_interrupted() is True
def test_child_clear_interrupt_at_start_clears_global(self):
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
This is the intended behavior at startup, but verify it doesn't
accidentally clear an interrupt intended for a running child.
def test_child_clear_interrupt_at_start_clears_thread(self):
"""child.clear_interrupt() at start of run_conversation clears the
per-thread interrupt flag for the current thread.
"""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child = self._make_bare_agent()
child._interrupt_requested = True
child._interrupt_message = "msg"
child.quiet_mode = True
child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set
# Interrupt for current thread is set
set_interrupt(True)
assert is_interrupted() is True
# child.clear_interrupt() clears both
# child.clear_interrupt() clears both instance flag and thread flag
child.clear_interrupt()
assert child._interrupt_requested is False
assert is_interrupted() is False
def test_interrupt_during_child_api_call_detected(self):
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
child = self._make_bare_agent()
child.api_mode = "chat_completions"
child.log_prefix = ""
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
@@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def test_concurrent_interrupt_propagation(self):
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
# Register child (simulating what _run_single_child does)
parent._active_children.append(child)
@@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase):
set_interrupt(False)
class TestPerThreadInterruptIsolation(unittest.TestCase):
    """Verify that interrupting one agent does NOT affect another agent's thread.

    This is the core fix for the gateway cross-session interrupt leak:
    multiple agents run in separate threads within the same process, and
    interrupting agent A must not kill agent B's running tools.
    """

    def setUp(self):
        # With no thread_id, set_interrupt() targets the calling thread.
        set_interrupt(False)

    def tearDown(self):
        set_interrupt(False)

    def test_interrupt_only_affects_target_thread(self):
        """set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
        results = {}
        ready = {"a": threading.Event(), "b": threading.Event()}
        interrupt_sent = threading.Event()

        def worker(name):
            """Record this thread's ident, then sample its flag after the interrupt."""
            results[f"{name}_tid"] = threading.current_thread().ident
            ready[name].set()
            # Sample only after the main thread has delivered the interrupt,
            # so the outcome does not depend on sleep timing.
            if interrupt_sent.wait(timeout=5):
                results[f"{name}_interrupted"] = is_interrupted()

        # Daemon threads so a stuck worker cannot keep the test process alive.
        ta = threading.Thread(target=worker, args=("a",), daemon=True)
        tb = threading.Thread(target=worker, args=("b",), daemon=True)
        ta.start()
        tb.start()
        try:
            # Bounded wait for both workers to publish their thread idents.
            for name, flag in ready.items():
                assert flag.wait(timeout=5), f"Thread {name.upper()} never started"

            # Interrupt ONLY thread A (simulates gateway interrupting agent A)
            set_interrupt(True, results["a_tid"])
            interrupt_sent.set()

            ta.join(timeout=3)
            tb.join(timeout=3)
            assert not ta.is_alive() and not tb.is_alive(), "Worker threads did not finish"

            assert results.get("a_interrupted") is True, "Thread A should see the interrupt"
            assert results.get("b_interrupted") is False, "Thread B must NOT see thread A's interrupt"
        finally:
            # Release any worker still waiting, then drop thread A's flag:
            # tearDown only clears the calling thread's flag, so A's ident
            # would otherwise stay in the interrupted set.
            interrupt_sent.set()
            a_tid = results.get("a_tid")
            if a_tid is not None:
                set_interrupt(False, a_tid)

    def test_clear_interrupt_only_clears_target_thread(self):
        """Clearing one thread's interrupt doesn't clear another's."""
        from tools.interrupt import _interrupted_threads, _lock

        # Made-up idents; no real thread is involved in this test.
        tid_a = 99990001
        tid_b = 99990002
        set_interrupt(True, tid_a)
        set_interrupt(True, tid_b)
        try:
            # Clear only A
            set_interrupt(False, tid_a)

            # Inspect the shared set directly, since is_interrupted() can only
            # report on the calling thread.
            with _lock:
                assert tid_a not in _interrupted_threads
                assert tid_b in _interrupted_threads
        finally:
            # Cleanup runs even when an assertion fails, so no flag leaks
            # into later tests.
            set_interrupt(False, tid_a)
            set_interrupt(False, tid_b)
if __name__ == "__main__":
unittest.main()

View File

@@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestInterruptHandling(unittest.TestCase):
def test_interrupt_event_stops_execution(self):
"""When _interrupt_event is set, execute_code should stop the script."""
"""When interrupt is set for the execution thread, execute_code should stop."""
code = "import time; time.sleep(60); print('should not reach')"
from tools.interrupt import set_interrupt
# Capture the main thread ID so we can target the interrupt correctly.
# execute_code runs in the current thread; set_interrupt needs its ID.
main_tid = threading.current_thread().ident
def set_interrupt_after_delay():
import time as _t
_t.sleep(1)
from tools.terminal_tool import _interrupt_event
_interrupt_event.set()
set_interrupt(True, main_tid)
t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
t.start()
@@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase):
self.assertEqual(result["status"], "interrupted")
self.assertIn("interrupted", result["output"])
finally:
from tools.terminal_tool import _interrupt_event
_interrupt_event.clear()
set_interrupt(False, main_tid)
t.join(timeout=3)

View File

@@ -924,8 +924,8 @@ def execute_code(
# --- Local execution path (UDS) --- below this line is unchanged ---
# Import interrupt event from terminal_tool (cooperative cancellation)
from tools.terminal_tool import _interrupt_event
# Import per-thread interrupt check (cooperative cancellation)
from tools.interrupt import is_interrupted as _is_interrupted
# Resolve config
_cfg = _load_config()
@@ -1114,7 +1114,7 @@ def execute_code(
status = "success"
while proc.poll() is None:
if _interrupt_event.is_set():
if _is_interrupted():
_kill_process_group(proc)
status = "interrupted"
break

View File

@@ -1,8 +1,12 @@
"""Shared interrupt signaling for all tools.
"""Per-thread interrupt signaling for all tools.
Provides a global threading.Event that any tool can check to determine
if the user has requested an interrupt. The agent's interrupt() method
sets this event, and tools poll it during long-running operations.
Provides thread-scoped interrupt tracking so that interrupting one agent
session does not kill tools running in other sessions. This is critical
in the gateway where multiple agents run concurrently in the same process.
The agent stores its execution thread ID at the start of run_conversation();
its interrupt() and clear_interrupt() methods pass that ID to set_interrupt().
Tools call is_interrupted(), which checks the CURRENT thread — no argument needed.
Usage in tools:
from tools.interrupt import is_interrupted
@@ -12,17 +16,61 @@ Usage in tools:
import threading
_interrupt_event = threading.Event()
# Set of thread idents that have been interrupted.
_interrupted_threads: set[int] = set()
_lock = threading.Lock()
def set_interrupt(active: bool) -> None:
"""Called by the agent to signal or clear the interrupt."""
if active:
_interrupt_event.set()
else:
_interrupt_event.clear()
def set_interrupt(active: bool, thread_id: int | None = None) -> None:
    """Raise or drop the interrupt flag of one thread.

    Args:
        active: True marks the thread as interrupted; False removes the mark.
        thread_id: Ident of the thread to affect. If omitted (None), the
            calling thread's own ident is used (backward compat for
            CLI/tests).
    """
    if thread_id is None:
        target = threading.current_thread().ident
    else:
        target = thread_id
    with _lock:
        # discard() rather than remove(): clearing an unset flag is a no-op.
        mutate = _interrupted_threads.add if active else _interrupted_threads.discard
        mutate(target)
def is_interrupted() -> bool:
"""Check if an interrupt has been requested. Safe to call from any thread."""
return _interrupt_event.is_set()
"""Check if an interrupt has been requested for the current thread.
Safe to call from any thread — each thread only sees its own
interrupt state.
"""
tid = threading.current_thread().ident
with _lock:
return tid in _interrupted_threads
# ---------------------------------------------------------------------------
# Backward-compatible _interrupt_event proxy
# ---------------------------------------------------------------------------
# Legacy call sites (including tests) may still import _interrupt_event
# directly and call .is_set() / .set() / .clear() on it.
# This shim maps those calls to the per-thread functions above so existing
# code keeps working while the underlying mechanism is thread-scoped.
class _ThreadAwareEventProxy:
    """Drop-in proxy that maps threading.Event methods to per-thread state.

    Every method acts on the CALLING thread's interrupt flag, so code
    written against a shared ``threading.Event`` keeps working while each
    thread only observes and changes its own state.
    """

    # Seconds between flag checks while wait() is blocking.
    _POLL_INTERVAL = 0.05

    def is_set(self) -> bool:
        """Return True if the calling thread has been interrupted."""
        return is_interrupted()

    def set(self) -> None:  # noqa: A003
        """Mark the calling thread as interrupted."""
        set_interrupt(True)

    def clear(self) -> None:
        """Clear the calling thread's interrupt flag."""
        set_interrupt(False)

    def wait(self, timeout: float | None = None) -> bool:
        """Block until the calling thread is interrupted or *timeout* elapses.

        Follows the ``threading.Event.wait`` contract: returns True as soon
        as the flag is set and False if the timeout expires first. With
        ``timeout=None`` it blocks until the flag is set. The flag is
        polled, so a wake-up can lag the interrupt by up to
        ``_POLL_INTERVAL`` seconds.

        Args:
            timeout: Maximum seconds to wait, or None to wait indefinitely.

        Returns:
            True if the calling thread's interrupt flag is set, else False.
        """
        import time  # local import: the module header only imports threading

        deadline = None if timeout is None else time.monotonic() + timeout
        while not self.is_set():
            if deadline is None:
                time.sleep(self._POLL_INTERVAL)
                continue
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline.
            time.sleep(min(self._POLL_INTERVAL, remaining))
        return True
_interrupt_event = _ThreadAwareEventProxy()

View File

@@ -686,7 +686,7 @@ class ProcessRegistry:
and output snapshot.
"""
from tools.ansi_strip import strip_ansi
from tools.terminal_tool import _interrupt_event
from tools.interrupt import is_interrupted as _is_interrupted
try:
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
@@ -723,7 +723,7 @@ class ProcessRegistry:
result["timeout_note"] = timeout_note
return result
if _interrupt_event.is_set():
if _is_interrupted():
result = {
"status": "interrupted",
"output": strip_ansi(session.output_buffer[-1000:]),