refactor: extract atomic_json_write helper, add 24 checkpoint tests

Extract the duplicated temp-file + fsync + os.replace pattern from
batch_runner.py (1 instance) and process_registry.py (2 instances) into
a shared utils.atomic_json_write() function.

Add 12 tests for atomic_json_write covering: valid JSON, parent dir
creation, overwrite, crash safety (original preserved on error), no temp
file leaks, string paths, unicode, custom indent, concurrent writes.

Add 12 tests for batch_runner checkpoint behavior covering:
_save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock,
parent dirs, no temp leaks), _load_checkpoint (missing file, existing
data, corrupt JSON), and resume logic (preserves prior progress,
different run_name starts fresh).
This commit is contained in:
teknium1
2026-03-06 05:50:12 -08:00
parent c05c60665e
commit d63b363cde
5 changed files with 340 additions and 64 deletions

View File

@@ -29,8 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime from datetime import datetime
from multiprocessing import Pool, Lock from multiprocessing import Pool, Lock
import traceback import traceback
import tempfile
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console from rich.console import Console
import fire import fire
@@ -703,32 +701,12 @@ class BatchRunner:
""" """
checkpoint_data["last_updated"] = datetime.now().isoformat() checkpoint_data["last_updated"] = datetime.now().isoformat()
def _atomic_write(): from utils import atomic_json_write
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(self.checkpoint_file.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, self.checkpoint_file)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
if lock: if lock:
with lock: with lock:
_atomic_write() atomic_json_write(self.checkpoint_file, checkpoint_data)
else: else:
_atomic_write() atomic_json_write(self.checkpoint_file, checkpoint_data)
def _scan_completed_prompts_by_content(self) -> set: def _scan_completed_prompts_by_content(self) -> set:
""" """

View File

@@ -0,0 +1,132 @@
"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""
import json
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from utils import atomic_json_write
class TestAtomicJsonWrite:
"""Core atomic write behavior."""
def test_writes_valid_json(self, tmp_path):
target = tmp_path / "data.json"
data = {"key": "value", "nested": {"a": 1}}
atomic_json_write(target, data)
result = json.loads(target.read_text(encoding="utf-8"))
assert result == data
def test_creates_parent_directories(self, tmp_path):
target = tmp_path / "deep" / "nested" / "dir" / "data.json"
atomic_json_write(target, {"ok": True})
assert target.exists()
assert json.loads(target.read_text())["ok"] is True
def test_overwrites_existing_file(self, tmp_path):
target = tmp_path / "data.json"
target.write_text('{"old": true}')
atomic_json_write(target, {"new": True})
result = json.loads(target.read_text())
assert result == {"new": True}
def test_preserves_original_on_serialization_error(self, tmp_path):
target = tmp_path / "data.json"
original = {"preserved": True}
target.write_text(json.dumps(original))
# Try to write non-serializable data — should fail
with pytest.raises(TypeError):
atomic_json_write(target, {"bad": object()})
# Original file should be untouched
result = json.loads(target.read_text())
assert result == original
def test_no_leftover_temp_files_on_success(self, tmp_path):
target = tmp_path / "data.json"
atomic_json_write(target, [1, 2, 3])
# No .tmp files should be left behind
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
assert len(tmp_files) == 0
assert target.exists()
def test_no_leftover_temp_files_on_failure(self, tmp_path):
target = tmp_path / "data.json"
with pytest.raises(TypeError):
atomic_json_write(target, {"bad": object()})
# No temp files should be left behind
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
assert len(tmp_files) == 0
def test_accepts_string_path(self, tmp_path):
target = str(tmp_path / "string_path.json")
atomic_json_write(target, {"string": True})
result = json.loads(Path(target).read_text())
assert result == {"string": True}
def test_writes_list_data(self, tmp_path):
target = tmp_path / "list.json"
data = [1, "two", {"three": 3}]
atomic_json_write(target, data)
result = json.loads(target.read_text())
assert result == data
def test_empty_list(self, tmp_path):
target = tmp_path / "empty.json"
atomic_json_write(target, [])
result = json.loads(target.read_text())
assert result == []
def test_custom_indent(self, tmp_path):
target = tmp_path / "custom.json"
atomic_json_write(target, {"a": 1}, indent=4)
text = target.read_text()
assert ' "a"' in text # 4-space indent
def test_unicode_content(self, tmp_path):
target = tmp_path / "unicode.json"
data = {"emoji": "🎉", "japanese": "日本語"}
atomic_json_write(target, data)
result = json.loads(target.read_text(encoding="utf-8"))
assert result["emoji"] == "🎉"
assert result["japanese"] == "日本語"
def test_concurrent_writes_dont_corrupt(self, tmp_path):
"""Multiple rapid writes should each produce valid JSON."""
import threading
target = tmp_path / "concurrent.json"
errors = []
def writer(n):
try:
atomic_json_write(target, {"writer": n, "data": list(range(100))})
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors
# File should contain valid JSON from one of the writers
result = json.loads(target.read_text())
assert "writer" in result
assert len(result["data"]) == 100

View File

@@ -0,0 +1,159 @@
"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
import json
import os
from pathlib import Path
from multiprocessing import Lock
from unittest.mock import patch, MagicMock
import pytest
# batch_runner uses relative imports, ensure project root is on path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from batch_runner import BatchRunner
@pytest.fixture
def runner(tmp_path):
"""Create a BatchRunner with all paths pointing at tmp_path."""
prompts_file = tmp_path / "prompts.jsonl"
prompts_file.write_text("")
output_file = tmp_path / "output.jsonl"
checkpoint_file = tmp_path / "checkpoint.json"
r = BatchRunner.__new__(BatchRunner)
r.run_name = "test_run"
r.checkpoint_file = checkpoint_file
r.output_file = output_file
r.prompts_file = prompts_file
return r
class TestSaveCheckpoint:
"""Verify _save_checkpoint writes valid, atomic JSON."""
def test_writes_valid_json(self, runner):
data = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
runner._save_checkpoint(data)
result = json.loads(runner.checkpoint_file.read_text())
assert result["run_name"] == "test"
assert result["completed_prompts"] == [1, 2, 3]
def test_adds_last_updated(self, runner):
data = {"run_name": "test", "completed_prompts": []}
runner._save_checkpoint(data)
result = json.loads(runner.checkpoint_file.read_text())
assert "last_updated" in result
assert result["last_updated"] is not None
def test_overwrites_previous_checkpoint(self, runner):
runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})
result = json.loads(runner.checkpoint_file.read_text())
assert result["completed_prompts"] == [1, 2, 3]
def test_with_lock(self, runner):
lock = Lock()
data = {"run_name": "test", "completed_prompts": [42]}
runner._save_checkpoint(data, lock=lock)
result = json.loads(runner.checkpoint_file.read_text())
assert result["completed_prompts"] == [42]
def test_without_lock(self, runner):
data = {"run_name": "test", "completed_prompts": [99]}
runner._save_checkpoint(data, lock=None)
result = json.loads(runner.checkpoint_file.read_text())
assert result["completed_prompts"] == [99]
def test_creates_parent_dirs(self, tmp_path):
runner_deep = BatchRunner.__new__(BatchRunner)
runner_deep.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"
data = {"run_name": "test", "completed_prompts": []}
runner_deep._save_checkpoint(data)
assert runner_deep.checkpoint_file.exists()
def test_no_temp_files_left(self, runner):
runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
tmp_files = [f for f in runner.checkpoint_file.parent.iterdir()
if ".tmp" in f.name]
assert len(tmp_files) == 0
class TestLoadCheckpoint:
"""Verify _load_checkpoint reads existing data or returns defaults."""
def test_returns_empty_when_no_file(self, runner):
result = runner._load_checkpoint()
assert result.get("completed_prompts", []) == []
def test_loads_existing_checkpoint(self, runner):
data = {"run_name": "test_run", "completed_prompts": [5, 10, 15],
"batch_stats": {"0": {"processed": 3}}}
runner.checkpoint_file.write_text(json.dumps(data))
result = runner._load_checkpoint()
assert result["completed_prompts"] == [5, 10, 15]
assert result["batch_stats"]["0"]["processed"] == 3
def test_handles_corrupt_json(self, runner):
runner.checkpoint_file.write_text("{broken json!!")
result = runner._load_checkpoint()
# Should return empty/default, not crash
assert isinstance(result, dict)
class TestResumePreservesProgress:
"""Verify that initializing a run with resume=True loads prior checkpoint."""
def test_completed_prompts_loaded_from_checkpoint(self, runner):
# Simulate a prior run that completed prompts 0-4
prior = {
"run_name": "test_run",
"completed_prompts": [0, 1, 2, 3, 4],
"batch_stats": {"0": {"processed": 5}},
"last_updated": "2026-01-01T00:00:00",
}
runner.checkpoint_file.write_text(json.dumps(prior))
# Load checkpoint like run() does
checkpoint_data = runner._load_checkpoint()
if checkpoint_data.get("run_name") != runner.run_name:
checkpoint_data = {
"run_name": runner.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None,
}
completed_set = set(checkpoint_data.get("completed_prompts", []))
assert completed_set == {0, 1, 2, 3, 4}
def test_different_run_name_starts_fresh(self, runner):
prior = {
"run_name": "different_run",
"completed_prompts": [0, 1, 2],
"batch_stats": {},
}
runner.checkpoint_file.write_text(json.dumps(prior))
checkpoint_data = runner._load_checkpoint()
if checkpoint_data.get("run_name") != runner.run_name:
checkpoint_data = {
"run_name": runner.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None,
}
assert checkpoint_data["completed_prompts"] == []
assert checkpoint_data["run_name"] == "test_run"

View File

@@ -37,7 +37,6 @@ import shlex
import shutil import shutil
import signal import signal
import subprocess import subprocess
import tempfile
import threading import threading
import time import time
import uuid import uuid
@@ -707,25 +706,9 @@ class ProcessRegistry:
"session_key": s.session_key, "session_key": s.session_key,
}) })
# Atomic write: temp file + os.replace to avoid corruption on crash # Atomic write to avoid corruption on crash
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True) from utils import atomic_json_write
fd, tmp_path = tempfile.mkstemp( atomic_json_write(CHECKPOINT_PATH, entries)
dir=str(CHECKPOINT_PATH.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump(entries, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, CHECKPOINT_PATH)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
except Exception as e: except Exception as e:
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True) logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
@@ -774,26 +757,9 @@ class ProcessRegistry:
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid) logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
# Clear the checkpoint (will be rewritten as processes finish) # Clear the checkpoint (will be rewritten as processes finish)
# Use atomic write to avoid corruption
try: try:
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True) from utils import atomic_json_write
fd, tmp_path = tempfile.mkstemp( atomic_json_write(CHECKPOINT_PATH, [])
dir=str(CHECKPOINT_PATH.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write("[]")
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, CHECKPOINT_PATH)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
except Exception as e: except Exception as e:
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True) logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)

41
utils.py Normal file
View File

@@ -0,0 +1,41 @@
"""Shared utility functions for hermes-agent."""
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Union
def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> None:
"""Write JSON data to a file atomically.
Uses temp file + fsync + os.replace to ensure the target file is never
left in a partially-written state. If the process crashes mid-write,
the previous version of the file remains intact.
Args:
path: Target file path (will be created or overwritten).
data: JSON-serializable data to write.
indent: JSON indentation (default 2).
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(path.parent),
prefix=f".{path.stem}_",
suffix=".tmp",
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(data, f, indent=indent, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, path)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise