fix(process): correct detached crash recovery state

Previously crash recovery recreated detached sessions as if they were
fully managed, so polls and kills could lie about liveness and the
checkpoint could forget recovered jobs after the next restart.
This commit refreshes recovered host-backed sessions from real PID
state, keeps checkpoint data durable, and preserves notify watcher
metadata while treating sandbox-only PIDs as non-recoverable.

- Persist `pid_scope` in `tools/process_registry.py` and skip
  recovering sandbox-backed entries without a host-visible PID handle
- Refresh detached sessions on access so `get`/`poll`/`wait` and active
  session queries observe exited processes instead of hanging forever
- Allow recovered host PIDs to be terminated honestly and requeue
  `notify_on_complete` watchers during checkpoint recovery
- Add regression tests for durable checkpoints, detached exit/kill
  behavior, sandbox skip logic, and recovered notify watchers
This commit is contained in:
mrshu
2026-04-08 08:59:52 +02:00
committed by Teknium
parent 383db35925
commit 19b0ddce40
3 changed files with 241 additions and 13 deletions

View File

@@ -197,6 +197,26 @@ class TestCheckpointNotify:
s = registry.get("proc_live") s = registry.get("proc_live")
assert s.notify_on_complete is True assert s.notify_on_complete is True
def test_recover_requeues_notify_watchers(self, registry, tmp_path):
    """Recovering a live process re-queues its notify-on-complete watcher."""
    checkpoint_file = tmp_path / "procs.json"
    entry = {
        "session_id": "proc_live",
        "command": "sleep 999",
        "pid": os.getpid(),
        "task_id": "t1",
        "session_key": "sk1",
        "watcher_platform": "telegram",
        "watcher_chat_id": "123",
        "watcher_thread_id": "42",
        "watcher_interval": 5,
        "notify_on_complete": True,
    }
    checkpoint_file.write_text(json.dumps([entry]))
    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        count = registry.recover_from_checkpoint()
        assert count == 1
        assert len(registry.pending_watchers) == 1
        assert registry.pending_watchers[0]["notify_on_complete"] is True
def test_recover_defaults_false(self, registry, tmp_path): def test_recover_defaults_false(self, registry, tmp_path):
"""Old checkpoint entries without the field default to False.""" """Old checkpoint entries without the field default to False."""
checkpoint = tmp_path / "procs.json" checkpoint = tmp_path / "procs.json"

View File

@@ -2,6 +2,9 @@
import json import json
import os import os
import signal
import subprocess
import sys
import time import time
import pytest import pytest
from pathlib import Path from pathlib import Path
@@ -45,6 +48,23 @@ def _make_session(
return s return s
def _spawn_python_sleep(seconds: float) -> subprocess.Popen:
"""Spawn a portable short-lived Python sleep process."""
return subprocess.Popen(
[sys.executable, "-c", f"import time; time.sleep({seconds})"],
)
def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
"""Poll a predicate until it returns truthy or the timeout elapses."""
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if predicate():
return True
time.sleep(interval)
return False
# ========================================================================= # =========================================================================
# Get / Poll # Get / Poll
# ========================================================================= # =========================================================================
@@ -349,6 +369,88 @@ class TestCheckpoint:
assert recovered == 1 assert recovered == 1
assert len(registry.pending_watchers) == 0 assert len(registry.pending_watchers) == 0
def test_recovery_keeps_live_checkpoint_entries(self, registry, tmp_path):
    """A recovered live process must survive in the rewritten checkpoint."""
    checkpoint_file = tmp_path / "procs.json"
    checkpoint_file.write_text(json.dumps([{
        "session_id": "proc_live",
        "command": "sleep 999",
        "pid": os.getpid(),
        "task_id": "t1",
        "session_key": "sk1",
    }]))
    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        count = registry.recover_from_checkpoint()
        assert count == 1
        assert registry.get("proc_live") is not None
        # The checkpoint must stay durable: the recovered entry is rewritten,
        # not wiped, so a second restart can still find it.
        persisted = json.loads(checkpoint_file.read_text())
        assert persisted != []
        assert len(persisted) == 1
        assert persisted[0]["session_id"] == "proc_live"
        assert persisted[0]["pid"] == os.getpid()
def test_recovery_skips_explicit_sandbox_backed_entries(self, registry, tmp_path):
    """Entries marked pid_scope=sandbox are dropped rather than recovered."""
    checkpoint_file = tmp_path / "procs.json"
    sandbox_entries = [{
        "session_id": "proc_remote",
        "command": "sleep 999",
        "pid": os.getpid(),
        "task_id": "t1",
        "pid_scope": "sandbox",
    }]
    checkpoint_file.write_text(json.dumps(sandbox_entries))
    with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
        count = registry.recover_from_checkpoint()
        assert count == 0
        assert registry.get("proc_remote") is None
        # The unrecoverable entry is also purged from the checkpoint file.
        assert json.loads(checkpoint_file.read_text()) == []
def test_detached_recovered_process_eventually_exits(self, registry, tmp_path):
    """A recovered detached host PID is reported as exited once it dies."""
    child = _spawn_python_sleep(0.4)
    checkpoint_file = tmp_path / "procs.json"
    checkpoint_file.write_text(json.dumps([{
        "session_id": "proc_live",
        "command": "python -c 'import time; time.sleep(0.4)'",
        "pid": child.pid,
        "task_id": "t1",
        "session_key": "sk1",
    }]))
    try:
        with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint_file):
            assert registry.recover_from_checkpoint() == 1
            recovered = registry.get("proc_live")
            assert recovered is not None
            assert recovered.detached is True
            child.wait(timeout=5)

            def _observed_exit():
                s = registry.get("proc_live")
                return s is not None and s.exited

            # The registry refreshes detached sessions lazily on access, so
            # poll until it observes the exit instead of asserting instantly.
            assert _wait_until(_observed_exit, timeout=5)
            assert registry.poll("proc_live")["status"] == "exited"
            assert registry.wait("proc_live", timeout=1)["status"] == "exited"
    finally:
        # Never leak the child process, even if an assertion fails above.
        if child.poll() is None:
            child.terminate()
            try:
                child.wait(timeout=5)
            except Exception:
                child.kill()
                child.wait(timeout=5)
# ========================================================================= # =========================================================================
# Kill process # Kill process
@@ -365,6 +467,27 @@ class TestKillProcess:
result = registry.kill_process(s.id) result = registry.kill_process(s.id)
assert result["status"] == "already_exited" assert result["status"] == "already_exited"
def test_kill_detached_session_uses_host_pid(self, registry):
    """Killing a detached session signals the recorded host PID directly."""
    session = _make_session(sid="proc_detached", command="sleep 999")
    session.pid = 424242
    session.detached = True
    registry._running[session.id] = session
    sent = []

    def record_kill(pid, sig):
        sent.append((pid, sig))

    try:
        with patch("tools.process_registry.os.kill", side_effect=record_kill):
            outcome = registry.kill_process(session.id)
            assert outcome["status"] == "killed"
            # Liveness probe (signal 0) followed by the actual SIGTERM.
            assert (424242, 0) in sent
            assert (424242, signal.SIGTERM) in sent
    finally:
        registry._running.pop(session.id, None)
# ========================================================================= # =========================================================================
# Tool handler # Tool handler

View File

@@ -76,6 +76,7 @@ class ProcessSession:
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS) output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe) detached: bool = False # True if recovered from crash (no pipe)
pid_scope: str = "host" # "host" for local/PTY PIDs, "sandbox" for env-local PIDs
# Watcher/notification metadata (persisted for crash recovery) # Watcher/notification metadata (persisted for crash recovery)
watcher_platform: str = "" watcher_platform: str = ""
watcher_chat_id: str = "" watcher_chat_id: str = ""
@@ -127,6 +128,48 @@ class ProcessRegistry:
lines.pop(0) lines.pop(0)
return "\n".join(lines) return "\n".join(lines)
@staticmethod
def _is_host_pid_alive(pid: Optional[int]) -> bool:
"""Best-effort liveness check for host-visible PIDs."""
if not pid:
return False
try:
os.kill(pid, 0)
return True
except (ProcessLookupError, PermissionError):
return False
def _refresh_detached_session(self, session: Optional[ProcessSession]) -> Optional[ProcessSession]:
    """Update recovered host-PID sessions when the underlying process has exited."""
    # Only detached (crash-recovered) sessions with host-visible PIDs can go
    # stale without notice; everything else is tracked through a live handle.
    if session is None or session.exited or not session.detached or session.pid_scope != "host":
        return session
    if self._is_host_pid_alive(session.pid):
        return session
    with session._lock:
        # Re-check under the session lock: another caller may have marked
        # this session exited between the probe above and acquiring the lock.
        if session.exited:
            return session
        session.exited = True
        # Recovered sessions no longer have a waitable handle, so the real
        # exit code is unavailable once the original process object is gone.
        session.exit_code = None
        # NOTE(review): _move_to_finished runs while session._lock is held,
        # and callers such as get() invoke this method under the registry
        # lock -- confirm the lock types/ordering cannot deadlock.
        self._move_to_finished(session)
    return session
@staticmethod
def _terminate_host_pid(pid: int) -> None:
    """Terminate a host-visible PID without requiring the original process handle.

    On Windows there are no POSIX process groups, so the PID is signalled
    directly.  On POSIX the whole process group is signalled first so that
    children spawned by the recovered command are stopped too; if the group
    cannot be resolved or signalled, fall back to the single PID.

    NOTE(review): if the recovered process shares its process group with the
    registry's own process (e.g. spawned without a new session), the
    group-wide SIGTERM would also hit us -- confirm spawn paths detach the
    group (setsid/start_new_session) before relying on killpg here.
    """
    if _IS_WINDOWS:
        os.kill(pid, signal.SIGTERM)
        return
    try:
        os.killpg(os.getpgid(pid), signal.SIGTERM)
    except OSError:
        # ProcessLookupError and PermissionError are OSError subclasses, so a
        # single OSError catch covers the original tuple; fall back to
        # signalling just the PID itself.
        os.kill(pid, signal.SIGTERM)
# ----- Spawn ----- # ----- Spawn -----
def spawn_local( def spawn_local(
@@ -269,6 +312,7 @@ class ProcessRegistry:
cwd=cwd, cwd=cwd,
started_at=time.time(), started_at=time.time(),
env_ref=env, env_ref=env,
pid_scope="sandbox",
) )
# Run the command in the sandbox with output capture # Run the command in the sandbox with output capture
@@ -439,7 +483,8 @@ class ProcessRegistry:
def get(self, session_id: str) -> Optional[ProcessSession]: def get(self, session_id: str) -> Optional[ProcessSession]:
"""Get a session by ID (running or finished).""" """Get a session by ID (running or finished)."""
with self._lock: with self._lock:
return self._running.get(session_id) or self._finished.get(session_id) session = self._running.get(session_id) or self._finished.get(session_id)
return self._refresh_detached_session(session)
def poll(self, session_id: str) -> dict: def poll(self, session_id: str) -> dict:
"""Check status and get new output for a background process.""" """Check status and get new output for a background process."""
@@ -531,6 +576,7 @@ class ProcessRegistry:
deadline = time.monotonic() + effective_timeout deadline = time.monotonic() + effective_timeout
while time.monotonic() < deadline: while time.monotonic() < deadline:
session = self._refresh_detached_session(session)
if session.exited: if session.exited:
result = { result = {
"status": "exited", "status": "exited",
@@ -596,6 +642,25 @@ class ProcessRegistry:
elif session.env_ref and session.pid: elif session.env_ref and session.pid:
# Non-local -- kill inside sandbox # Non-local -- kill inside sandbox
session.env_ref.execute(f"kill {session.pid} 2>/dev/null", timeout=5) session.env_ref.execute(f"kill {session.pid} 2>/dev/null", timeout=5)
elif session.detached and session.pid_scope == "host" and session.pid:
if not self._is_host_pid_alive(session.pid):
with session._lock:
session.exited = True
session.exit_code = None
self._move_to_finished(session)
return {
"status": "already_exited",
"exit_code": session.exit_code,
}
self._terminate_host_pid(session.pid)
else:
return {
"status": "error",
"error": (
"Recovered process cannot be killed after restart because "
"its original runtime handle is no longer available"
),
}
session.exited = True session.exited = True
session.exit_code = -15 # SIGTERM session.exit_code = -15 # SIGTERM
self._move_to_finished(session) self._move_to_finished(session)
@@ -640,6 +705,8 @@ class ProcessRegistry:
with self._lock: with self._lock:
all_sessions = list(self._running.values()) + list(self._finished.values()) all_sessions = list(self._running.values()) + list(self._finished.values())
all_sessions = [self._refresh_detached_session(s) for s in all_sessions]
if task_id: if task_id:
all_sessions = [s for s in all_sessions if s.task_id == task_id] all_sessions = [s for s in all_sessions if s.task_id == task_id]
@@ -666,6 +733,12 @@ class ProcessRegistry:
def has_active_processes(self, task_id: str) -> bool: def has_active_processes(self, task_id: str) -> bool:
"""Check if there are active (running) processes for a task_id.""" """Check if there are active (running) processes for a task_id."""
with self._lock:
sessions = list(self._running.values())
for session in sessions:
self._refresh_detached_session(session)
with self._lock: with self._lock:
return any( return any(
s.task_id == task_id and not s.exited s.task_id == task_id and not s.exited
@@ -674,6 +747,12 @@ class ProcessRegistry:
def has_active_for_session(self, session_key: str) -> bool: def has_active_for_session(self, session_key: str) -> bool:
"""Check if there are active processes for a gateway session key.""" """Check if there are active processes for a gateway session key."""
with self._lock:
sessions = list(self._running.values())
for session in sessions:
self._refresh_detached_session(session)
with self._lock: with self._lock:
return any( return any(
s.session_key == session_key and not s.exited s.session_key == session_key and not s.exited
@@ -727,6 +806,7 @@ class ProcessRegistry:
"session_id": s.id, "session_id": s.id,
"command": s.command, "command": s.command,
"pid": s.pid, "pid": s.pid,
"pid_scope": s.pid_scope,
"cwd": s.cwd, "cwd": s.cwd,
"started_at": s.started_at, "started_at": s.started_at,
"task_id": s.task_id, "task_id": s.task_id,
@@ -764,13 +844,21 @@ class ProcessRegistry:
if not pid: if not pid:
continue continue
pid_scope = entry.get("pid_scope", "host")
if pid_scope != "host":
# Sandbox-backed processes keep only in-sandbox PIDs in the
# checkpoint, which are not meaningful to the restarted host
# process once the original environment handle is gone.
logger.info(
"Skipping recovery for non-host process: %s (pid=%s, scope=%s)",
entry.get("command", "unknown")[:60],
pid,
pid_scope,
)
continue
# Check if PID is still alive # Check if PID is still alive
alive = False alive = self._is_host_pid_alive(pid)
try:
os.kill(pid, 0)
alive = True
except (ProcessLookupError, PermissionError):
pass
if alive: if alive:
session = ProcessSession( session = ProcessSession(
@@ -779,6 +867,7 @@ class ProcessRegistry:
task_id=entry.get("task_id", ""), task_id=entry.get("task_id", ""),
session_key=entry.get("session_key", ""), session_key=entry.get("session_key", ""),
pid=pid, pid=pid,
pid_scope=pid_scope,
cwd=entry.get("cwd"), cwd=entry.get("cwd"),
started_at=entry.get("started_at", time.time()), started_at=entry.get("started_at", time.time()),
detached=True, # Can't read output, but can report status + kill detached=True, # Can't read output, but can report status + kill
@@ -802,14 +891,10 @@ class ProcessRegistry:
"platform": session.watcher_platform, "platform": session.watcher_platform,
"chat_id": session.watcher_chat_id, "chat_id": session.watcher_chat_id,
"thread_id": session.watcher_thread_id, "thread_id": session.watcher_thread_id,
"notify_on_complete": session.notify_on_complete,
}) })
# Clear the checkpoint (will be rewritten as processes finish) self._write_checkpoint()
try:
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, [])
except Exception as e:
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
return recovered return recovered