mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
refactor: extract atomic_json_write helper, add 24 checkpoint tests
Extract the duplicated temp-file + fsync + os.replace pattern from batch_runner.py (1 instance) and process_registry.py (2 instances) into a shared utils.atomic_json_write() function. Add 12 tests for atomic_json_write covering: valid JSON, parent dir creation, overwrite, crash safety (original preserved on error), no temp file leaks, string paths, unicode, custom indent, concurrent writes. Add 12 tests for batch_runner checkpoint behavior covering: _save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock, parent dirs, no temp leaks), _load_checkpoint (missing file, existing data, corrupt JSON), and resume logic (preserves prior progress, different run_name starts fresh).
This commit is contained in:
@@ -29,8 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from multiprocessing import Pool, Lock
|
from multiprocessing import Pool, Lock
|
||||||
import traceback
|
import traceback
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
import fire
|
import fire
|
||||||
@@ -703,32 +701,12 @@ class BatchRunner:
|
|||||||
"""
|
"""
|
||||||
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
||||||
|
|
||||||
def _atomic_write():
|
from utils import atomic_json_write
|
||||||
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
|
|
||||||
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fd, tmp_path = tempfile.mkstemp(
|
|
||||||
dir=str(self.checkpoint_file.parent),
|
|
||||||
prefix='.checkpoint_',
|
|
||||||
suffix='.tmp',
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
os.replace(tmp_path, self.checkpoint_file)
|
|
||||||
except BaseException:
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
|
|
||||||
if lock:
|
if lock:
|
||||||
with lock:
|
with lock:
|
||||||
_atomic_write()
|
atomic_json_write(self.checkpoint_file, checkpoint_data)
|
||||||
else:
|
else:
|
||||||
_atomic_write()
|
atomic_json_write(self.checkpoint_file, checkpoint_data)
|
||||||
|
|
||||||
def _scan_completed_prompts_by_content(self) -> set:
|
def _scan_completed_prompts_by_content(self) -> set:
|
||||||
"""
|
"""
|
||||||
|
|||||||
132
tests/test_atomic_json_write.py
Normal file
132
tests/test_atomic_json_write.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from utils import atomic_json_write
|
||||||
|
|
||||||
|
|
||||||
|
class TestAtomicJsonWrite:
    """Core behavior of ``atomic_json_write``.

    Every test writes below pytest's ``tmp_path`` and reads the target back
    as UTF-8, the encoding the helper writes with.
    """

    def test_writes_valid_json(self, tmp_path):
        """A nested dict round-trips through the written file."""
        target = tmp_path / "data.json"
        data = {"key": "value", "nested": {"a": 1}}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_creates_parent_directories(self, tmp_path):
        """Missing parent directories of the target are created."""
        target = tmp_path / "deep" / "nested" / "dir" / "data.json"
        atomic_json_write(target, {"ok": True})

        assert target.exists()
        assert json.loads(target.read_text(encoding="utf-8"))["ok"] is True

    def test_overwrites_existing_file(self, tmp_path):
        """Existing content is replaced by the new data."""
        target = tmp_path / "data.json"
        target.write_text('{"old": true}')

        atomic_json_write(target, {"new": True})
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == {"new": True}

    def test_preserves_original_on_serialization_error(self, tmp_path):
        """A failed write leaves the previous file content untouched."""
        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original))

        # Try to write non-serializable data — should fail
        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # Original file should be untouched
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == original

    def test_no_leftover_temp_files_on_success(self, tmp_path):
        """A successful write leaves only the target behind."""
        target = tmp_path / "data.json"
        atomic_json_write(target, [1, 2, 3])

        # No .tmp files should be left behind
        tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
        assert len(tmp_files) == 0
        assert target.exists()

    def test_no_leftover_temp_files_on_failure(self, tmp_path):
        """A failed write cleans up its temporary file."""
        target = tmp_path / "data.json"

        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # No temp files should be left behind
        tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
        assert len(tmp_files) == 0

    def test_accepts_string_path(self, tmp_path):
        """The target may be given as a plain string instead of a Path."""
        target = str(tmp_path / "string_path.json")
        atomic_json_write(target, {"string": True})

        result = json.loads(Path(target).read_text(encoding="utf-8"))
        assert result == {"string": True}

    def test_writes_list_data(self, tmp_path):
        """Top-level lists are supported, not only dicts."""
        target = tmp_path / "list.json"
        data = [1, "two", {"three": 3}]
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_empty_list(self, tmp_path):
        """An empty list is written as valid JSON."""
        target = tmp_path / "empty.json"
        atomic_json_write(target, [])

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == []

    def test_custom_indent(self, tmp_path):
        """The ``indent`` keyword controls the indentation of the output."""
        target = tmp_path / "custom.json"
        atomic_json_write(target, {"a": 1}, indent=4)

        text = target.read_text(encoding="utf-8")
        # Compare the whole document: a substring check for the indented key
        # would also match output written with a different indent width.
        assert text == '{\n    "a": 1\n}'

    def test_unicode_content(self, tmp_path):
        """Non-ASCII text survives the round trip."""
        target = tmp_path / "unicode.json"
        data = {"emoji": "🎉", "japanese": "日本語"}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result["emoji"] == "🎉"
        assert result["japanese"] == "日本語"

    def test_concurrent_writes_dont_corrupt(self, tmp_path):
        """Multiple rapid writes should each produce valid JSON."""
        import threading

        target = tmp_path / "concurrent.json"
        errors = []

        def writer(n):
            # Collect failures instead of raising: an exception raised inside
            # a worker thread would not fail the test on its own.
            try:
                atomic_json_write(target, {"writer": n, "data": list(range(100))})
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        # File should contain valid JSON from one of the writers
        result = json.loads(target.read_text(encoding="utf-8"))
        assert "writer" in result
        assert len(result["data"]) == 100
|
||||||
159
tests/test_batch_runner_checkpoint.py
Normal file
159
tests/test_batch_runner_checkpoint.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from multiprocessing import Lock
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# batch_runner and the modules it imports (e.g. utils) are top-level modules in the project root; make sure that directory is on sys.path
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from batch_runner import BatchRunner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def runner(tmp_path):
    """Return a bare BatchRunner whose files all live under ``tmp_path``.

    ``__init__`` is bypassed through ``__new__``, so only the attributes
    assigned here exist on the instance. The prompts file is created empty;
    the output and checkpoint files are not created.
    """
    instance = BatchRunner.__new__(BatchRunner)
    instance.run_name = "test_run"

    empty_prompts = tmp_path / "prompts.jsonl"
    empty_prompts.write_text("")
    instance.prompts_file = empty_prompts

    instance.output_file = tmp_path / "output.jsonl"
    instance.checkpoint_file = tmp_path / "checkpoint.json"
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveCheckpoint:
    """Check the file that ``_save_checkpoint`` leaves on disk."""

    @staticmethod
    def _saved(instance):
        """Parse and return the JSON currently stored in the checkpoint file."""
        return json.loads(instance.checkpoint_file.read_text())

    def test_writes_valid_json(self, runner):
        """The saved file parses as JSON and contains the given fields."""
        payload = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(payload)

        saved = self._saved(runner)
        assert saved["run_name"] == "test"
        assert saved["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        """Saving stamps the checkpoint with a non-null ``last_updated``."""
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        saved = self._saved(runner)
        assert "last_updated" in saved
        assert saved["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        """A second save replaces the content of the first."""
        for completed in ([1], [1, 2, 3]):
            runner._save_checkpoint({"run_name": "test", "completed_prompts": completed})

        assert self._saved(runner)["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        """Saving works when a multiprocessing lock is supplied."""
        guard = Lock()
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [42]}, lock=guard)

        assert self._saved(runner)["completed_prompts"] == [42]

    def test_without_lock(self, runner):
        """Saving works when ``lock`` is explicitly ``None``."""
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [99]}, lock=None)

        assert self._saved(runner)["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        """Missing parent directories of the checkpoint file are created."""
        nested_target = tmp_path / "deep" / "nested" / "checkpoint.json"
        runner_deep = BatchRunner.__new__(BatchRunner)
        runner_deep.checkpoint_file = nested_target

        runner_deep._save_checkpoint({"run_name": "test", "completed_prompts": []})

        assert runner_deep.checkpoint_file.exists()

    def test_no_temp_files_left(self, runner):
        """No ``.tmp`` file remains next to the checkpoint after a save."""
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        siblings = runner.checkpoint_file.parent.iterdir()
        assert not any(".tmp" in entry.name for entry in siblings)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadCheckpoint:
    """Check what ``_load_checkpoint`` returns for missing, valid and broken files."""

    def test_returns_empty_when_no_file(self, runner):
        """With no checkpoint on disk, no completed prompts are reported."""
        loaded = runner._load_checkpoint()
        assert loaded.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        """Fields stored in an existing checkpoint are returned as written."""
        stored = {
            "run_name": "test_run",
            "completed_prompts": [5, 10, 15],
            "batch_stats": {"0": {"processed": 3}},
        }
        runner.checkpoint_file.write_text(json.dumps(stored))

        loaded = runner._load_checkpoint()
        assert loaded["completed_prompts"] == [5, 10, 15]
        assert loaded["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        """Unparseable content must yield a dict rather than an exception."""
        runner.checkpoint_file.write_text("{broken json!!")

        loaded = runner._load_checkpoint()
        # Only the type is pinned here; the exact fallback content is not.
        assert isinstance(loaded, dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumePreservesProgress:
    """Check the resume decision: keep a checkpoint only if its run_name matches.

    The decision is re-implemented in ``_resume_state`` below; these tests
    do not call the runner's own ``run()``.
    """

    @staticmethod
    def _resume_state(instance):
        """Load the checkpoint, falling back to a fresh state on a run_name mismatch."""
        loaded = instance._load_checkpoint()
        if loaded.get("run_name") == instance.run_name:
            return loaded
        return {
            "run_name": instance.run_name,
            "completed_prompts": [],
            "batch_stats": {},
            "last_updated": None,
        }

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        """A checkpoint from the same run keeps its completed prompts."""
        # Simulate a prior run that completed prompts 0-4
        earlier = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(earlier))

        state = self._resume_state(runner)

        assert set(state.get("completed_prompts", [])) == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        """A checkpoint from another run is discarded in favor of a fresh state."""
        earlier = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(earlier))

        state = self._resume_state(runner)

        assert state["completed_prompts"] == []
        assert state["run_name"] == "test_run"
|
||||||
@@ -37,7 +37,6 @@ import shlex
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -707,25 +706,9 @@ class ProcessRegistry:
|
|||||||
"session_key": s.session_key,
|
"session_key": s.session_key,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Atomic write: temp file + os.replace to avoid corruption on crash
|
# Atomic write to avoid corruption on crash
|
||||||
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
|
from utils import atomic_json_write
|
||||||
fd, tmp_path = tempfile.mkstemp(
|
atomic_json_write(CHECKPOINT_PATH, entries)
|
||||||
dir=str(CHECKPOINT_PATH.parent),
|
|
||||||
prefix='.checkpoint_',
|
|
||||||
suffix='.tmp',
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(entries, f, indent=2, ensure_ascii=False)
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
os.replace(tmp_path, CHECKPOINT_PATH)
|
|
||||||
except BaseException:
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
|
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
|
||||||
|
|
||||||
@@ -774,26 +757,9 @@ class ProcessRegistry:
|
|||||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||||
|
|
||||||
# Clear the checkpoint (will be rewritten as processes finish)
|
# Clear the checkpoint (will be rewritten as processes finish)
|
||||||
# Use atomic write to avoid corruption
|
|
||||||
try:
|
try:
|
||||||
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)
|
from utils import atomic_json_write
|
||||||
fd, tmp_path = tempfile.mkstemp(
|
atomic_json_write(CHECKPOINT_PATH, [])
|
||||||
dir=str(CHECKPOINT_PATH.parent),
|
|
||||||
prefix='.checkpoint_',
|
|
||||||
suffix='.tmp',
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
|
||||||
f.write("[]")
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
os.replace(tmp_path, CHECKPOINT_PATH)
|
|
||||||
except BaseException:
|
|
||||||
try:
|
|
||||||
os.unlink(tmp_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
|
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
|||||||
41
utils.py
Normal file
41
utils.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Shared utility functions for hermes-agent."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
|
||||||
|
def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> None:
    """Write JSON data to a file atomically.

    The data is serialized to a temporary file in the target's directory,
    flushed and fsynced, and then moved over the target with ``os.replace``.
    If serialization or the write fails, the previous version of the target
    is left intact and the temporary file is removed.

    When the target already exists, its permission bits are copied onto the
    replacement (best effort). ``tempfile.mkstemp`` creates files that only
    the owner can access, so without this step every rewrite would silently
    tighten the target's mode. A newly created target keeps the mode that
    ``mkstemp`` gave it.

    Args:
        path: Target file path (will be created or overwritten). Missing
            parent directories are created.
        data: JSON-serializable data to write.
        indent: JSON indentation (default 2).

    Raises:
        TypeError: If ``data`` contains a value ``json.dump`` cannot serialize.
        OSError: If creating the directory, writing the temporary file, or
            replacing the target fails.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Remember the current permission bits, if there is a file to replace.
    try:
        existing_mode = path.stat().st_mode & 0o7777
    except OSError:
        existing_mode = None

    # The temp file lives in the target's directory so that os.replace is a
    # same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(
        dir=str(path.parent),
        prefix=f".{path.stem}_",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=indent, ensure_ascii=False)
            f.flush()
            os.fsync(f.fileno())
        if existing_mode is not None:
            # Apply the mode before the rename so the target never appears
            # with the wrong permissions. Failure here must not lose the write.
            try:
                os.chmod(tmp_path, existing_mode)
            except OSError:
                pass
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers KeyboardInterrupt/SystemExit: never leave the temp file.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
|
||||||
Reference in New Issue
Block a user