mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-28 06:51:16 +08:00
feat: add inline diff previews for write actions
Show inline diffs in the CLI transcript when write_file, patch, or skill_manage modifies files. Captures a filesystem snapshot before the tool runs, computes a unified diff after, and renders it with ANSI coloring in the activity feed. Adds tool_start_callback and tool_complete_callback hooks to AIAgent for pre/post tool execution notifications. Also fixes _extract_parallel_scope_path to normalize relative paths to absolute, preventing the parallel overlap detection from missing conflicts when the same file is referenced with different path styles. Gated by display.inline_diffs config option (default: true). Based on PR #3774 by @kshitijk4poor.
This commit is contained in:
313
agent/display.py
313
agent/display.py
@@ -10,6 +10,9 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
@@ -17,6 +20,22 @@ _RESET = "\033[0m"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ANSI_RESET = "\033[0m"
|
||||
_ANSI_DIM = "\033[38;2;150;150;150m"
|
||||
_ANSI_FILE = "\033[38;2;180;160;255m"
|
||||
_ANSI_HUNK = "\033[38;2;120;120;140m"
|
||||
_ANSI_MINUS = "\033[38;2;255;255;255;48;2;120;20;20m"
|
||||
_ANSI_PLUS = "\033[38;2;255;255;255;48;2;20;90;20m"
|
||||
_MAX_INLINE_DIFF_FILES = 6
|
||||
_MAX_INLINE_DIFF_LINES = 80
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalEditSnapshot:
|
||||
"""Pre-tool filesystem snapshot used to render diffs locally after writes."""
|
||||
paths: list[Path] = field(default_factory=list)
|
||||
before: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
# =========================================================================
|
||||
# Configurable tool preview length (0 = no limit)
|
||||
# Set once at startup by CLI or gateway from display.tool_preview_length config.
|
||||
@@ -218,6 +237,300 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -
|
||||
return preview
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Inline diff previews for write actions
|
||||
# =========================================================================
|
||||
|
||||
def _resolved_path(path: str) -> Path:
|
||||
"""Resolve a possibly-relative filesystem path against the current cwd."""
|
||||
candidate = Path(os.path.expanduser(path))
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
return Path.cwd() / candidate
|
||||
|
||||
|
||||
def _snapshot_text(path: Path) -> str | None:
|
||||
"""Return UTF-8 file content, or None for missing/unreadable files."""
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _display_diff_path(path: Path) -> str:
|
||||
"""Prefer cwd-relative paths in diffs when available."""
|
||||
try:
|
||||
return str(path.resolve().relative_to(Path.cwd().resolve()))
|
||||
except Exception:
|
||||
return str(path)
|
||||
|
||||
|
||||
def _resolve_skill_manage_paths(args: dict) -> list[Path]:
|
||||
"""Resolve skill_manage write targets to filesystem paths."""
|
||||
action = args.get("action")
|
||||
name = args.get("name")
|
||||
if not action or not name:
|
||||
return []
|
||||
|
||||
from tools.skill_manager_tool import _find_skill, _resolve_skill_dir
|
||||
|
||||
if action == "create":
|
||||
skill_dir = _resolve_skill_dir(name, args.get("category"))
|
||||
return [skill_dir / "SKILL.md"]
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return []
|
||||
|
||||
skill_dir = Path(existing["path"])
|
||||
if action in {"edit", "patch"}:
|
||||
file_path = args.get("file_path")
|
||||
return [skill_dir / file_path] if file_path else [skill_dir / "SKILL.md"]
|
||||
if action in {"write_file", "remove_file"}:
|
||||
file_path = args.get("file_path")
|
||||
return [skill_dir / file_path] if file_path else []
|
||||
if action == "delete":
|
||||
files = [path for path in sorted(skill_dir.rglob("*")) if path.is_file()]
|
||||
return files
|
||||
return []
|
||||
|
||||
|
||||
def _resolve_local_edit_paths(tool_name: str, function_args: dict | None) -> list[Path]:
|
||||
"""Resolve local filesystem targets for write-capable tools."""
|
||||
if not isinstance(function_args, dict):
|
||||
return []
|
||||
|
||||
if tool_name == "write_file":
|
||||
path = function_args.get("path")
|
||||
return [_resolved_path(path)] if path else []
|
||||
|
||||
if tool_name == "patch":
|
||||
path = function_args.get("path")
|
||||
return [_resolved_path(path)] if path else []
|
||||
|
||||
if tool_name == "skill_manage":
|
||||
return _resolve_skill_manage_paths(function_args)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def capture_local_edit_snapshot(tool_name: str, function_args: dict | None) -> LocalEditSnapshot | None:
    """Record the current text of every file a tool call is about to write.

    Returns None when the call has no resolvable local targets; otherwise a
    snapshot whose ``before`` mapping is keyed by ``str(path)``.
    """
    targets = _resolve_local_edit_paths(tool_name, function_args)
    if targets:
        return LocalEditSnapshot(
            paths=targets,
            before={str(target): _snapshot_text(target) for target in targets},
        )
    return None
|
||||
|
||||
|
||||
def _result_succeeded(result: str | None) -> bool:
|
||||
"""Conservatively detect whether a tool result represents success."""
|
||||
if not result:
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
if data.get("error"):
|
||||
return False
|
||||
if "success" in data:
|
||||
return bool(data.get("success"))
|
||||
return True
|
||||
|
||||
|
||||
def _diff_from_snapshot(snapshot: LocalEditSnapshot | None) -> str | None:
    """Generate unified diff text from a stored before-state and current files.

    Args:
        snapshot: Before-state captured ahead of the tool run, or None.

    Returns:
        The per-file unified diffs joined together (each line
        newline-terminated), or None when there is no snapshot or no file
        changed.
    """
    if not snapshot:
        return None

    chunks: list[str] = []
    for path in snapshot.paths:
        before = snapshot.before.get(str(path))
        after = _snapshot_text(path)
        if before == after:
            continue

        display_path = _display_diff_path(path)
        diff_lines: list[str] = []
        for line in unified_diff(
            [] if before is None else before.splitlines(keepends=True),
            [] if after is None else after.splitlines(keepends=True),
            fromfile=f"a/{display_path}",
            tofile=f"b/{display_path}",
        ):
            if line.endswith("\n"):
                diff_lines.append(line)
                continue
            # unified_diff echoes input lines verbatim, so a final line with
            # no trailing newline would be glued onto the next diff line
            # ("-old+new") when the output is joined. Terminate it explicitly.
            diff_lines.append(line + "\n")
            if line.splitlines() == [line]:
                # Truly unterminated (not merely ended by an exotic line
                # separator): flag it the way diff/git do.
                diff_lines.append("\\ No newline at end of file\n")
        if diff_lines:
            chunks.append("".join(diff_lines))

    return "".join(chunks) or None
|
||||
|
||||
|
||||
def extract_edit_diff(
    tool_name: str,
    result: str | None,
    *,
    function_args: dict | None = None,
    snapshot: LocalEditSnapshot | None = None,
) -> str | None:
    """Obtain a unified diff for a finished file-edit tool call.

    A non-blank ``diff`` string embedded in a ``patch`` result is returned
    as-is. Otherwise, for ``write_file``, ``patch`` and ``skill_manage``
    calls whose result reports success, the diff is computed from
    *snapshot*. Every other case yields None.
    """
    if tool_name == "patch" and result:
        try:
            payload = json.loads(result)
        except (json.JSONDecodeError, TypeError):
            payload = None
        embedded = payload.get("diff") if isinstance(payload, dict) else None
        if isinstance(embedded, str) and embedded.strip():
            return embedded

    if tool_name in {"write_file", "patch", "skill_manage"} and _result_succeeded(result):
        return _diff_from_snapshot(snapshot)
    return None
|
||||
|
||||
|
||||
def _emit_inline_diff(diff_text: str, print_fn) -> bool:
|
||||
"""Emit rendered diff text through the CLI's prompt_toolkit-safe printer."""
|
||||
if print_fn is None or not diff_text:
|
||||
return False
|
||||
try:
|
||||
print_fn(" ┊ review diff")
|
||||
for line in diff_text.rstrip("\n").splitlines():
|
||||
print_fn(line)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _render_inline_unified_diff(diff: str) -> list[str]:
|
||||
"""Render unified diff lines in Hermes' inline transcript style."""
|
||||
rendered: list[str] = []
|
||||
from_file = None
|
||||
to_file = None
|
||||
|
||||
for raw_line in diff.splitlines():
|
||||
if raw_line.startswith("--- "):
|
||||
from_file = raw_line[4:].strip()
|
||||
continue
|
||||
if raw_line.startswith("+++ "):
|
||||
to_file = raw_line[4:].strip()
|
||||
if from_file or to_file:
|
||||
rendered.append(f"{_ANSI_FILE}{from_file or 'a/?'} → {to_file or 'b/?'}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("@@"):
|
||||
rendered.append(f"{_ANSI_HUNK}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("-"):
|
||||
rendered.append(f"{_ANSI_MINUS}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("+"):
|
||||
rendered.append(f"{_ANSI_PLUS}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith(" "):
|
||||
rendered.append(f"{_ANSI_DIM}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line:
|
||||
rendered.append(raw_line)
|
||||
|
||||
return rendered
|
||||
|
||||
|
||||
def _split_unified_diff_sections(diff: str) -> list[str]:
|
||||
"""Split a unified diff into per-file sections."""
|
||||
sections: list[list[str]] = []
|
||||
current: list[str] = []
|
||||
|
||||
for line in diff.splitlines():
|
||||
if line.startswith("--- ") and current:
|
||||
sections.append(current)
|
||||
current = [line]
|
||||
continue
|
||||
current.append(line)
|
||||
|
||||
if current:
|
||||
sections.append(current)
|
||||
|
||||
return ["\n".join(section) for section in sections if section]
|
||||
|
||||
|
||||
def _summarize_rendered_diff_sections(
    diff: str,
    *,
    max_files: int = _MAX_INLINE_DIFF_FILES,
    max_lines: int = _MAX_INLINE_DIFF_LINES,
) -> list[str]:
    """Render a diff section by section under file-count and line-count caps.

    At most *max_files* sections and *max_lines* rendered lines are kept.
    When anything is left out, one trailing summary line reports how many
    rendered lines and how many sections were omitted; a section that is cut
    off part-way counts as omitted too.
    """
    sections = _split_unified_diff_sections(diff)
    output: list[str] = []
    skipped_files = 0
    skipped_lines = 0

    for position, section in enumerate(sections):
        lines = _render_inline_unified_diff(section)
        budget = max_lines - len(output)

        # Past the file cap, or no line budget left: count only, show nothing.
        if position >= max_files or budget <= 0:
            skipped_files += 1
            skipped_lines += len(lines)
            continue

        if len(lines) > budget:
            # Show what fits, then account for the rest of this section and
            # for every section after it in one go.
            output.extend(lines[:budget])
            remainder = sections[position + 1:]
            skipped_lines += len(lines) - budget
            skipped_files += 1 + len(remainder)
            skipped_lines += sum(len(_render_inline_unified_diff(later)) for later in remainder)
            break

        output.extend(lines)

    if skipped_files or skipped_lines:
        summary = f"… omitted {skipped_lines} diff line(s)"
        if skipped_files:
            summary += f" across {skipped_files} additional file(s)/section(s)"
        output.append(f"{_ANSI_HUNK}{summary}{_ANSI_RESET}")

    return output
|
||||
|
||||
|
||||
def render_edit_diff_with_delta(
    tool_name: str,
    result: str | None,
    *,
    function_args: dict | None = None,
    snapshot: LocalEditSnapshot | None = None,
    print_fn=None,
) -> bool:
    """Print an inline diff for a finished edit tool call, if one is available.

    Returns True only when a diff was extracted, rendered and handed to
    *print_fn* successfully. Rendering failures are logged at debug level
    and reported as False.
    """
    diff = extract_edit_diff(
        tool_name,
        result,
        function_args=function_args,
        snapshot=snapshot,
    )
    if diff:
        try:
            lines = _summarize_rendered_diff_sections(diff)
        except Exception as exc:
            logger.debug("Could not render inline diff: %s", exc)
            return False
        return _emit_inline_diff("\n".join(lines), print_fn)
    return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# KawaiiSpinner
|
||||
# =========================================================================
|
||||
|
||||
33
cli.py
33
cli.py
@@ -1077,12 +1077,16 @@ class HermesCLI:
|
||||
# streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml)
|
||||
self.streaming_enabled = CLI_CONFIG["display"].get("streaming", False)
|
||||
|
||||
# Inline diff previews for write actions (display.inline_diffs in config.yaml)
|
||||
self._inline_diffs_enabled = CLI_CONFIG["display"].get("inline_diffs", True)
|
||||
|
||||
# Streaming display state
|
||||
self._stream_buf = "" # Partial line buffer for line-buffered rendering
|
||||
self._stream_started = False # True once first delta arrives
|
||||
self._stream_box_opened = False # True once the response box header is printed
|
||||
self._reasoning_stream_started = False # True once live reasoning starts streaming
|
||||
self._reasoning_preview_buf = "" # Coalesce tiny reasoning chunks for [thinking] output
|
||||
self._pending_edit_snapshots = {}
|
||||
|
||||
# Configuration - priority: CLI args > env vars > config file
|
||||
# Model comes from: CLI arg or config.yaml (single source of truth).
|
||||
@@ -2132,6 +2136,8 @@ class HermesCLI:
|
||||
checkpoint_max_snapshots=self.checkpoint_max_snapshots,
|
||||
pass_session_id=self.pass_session_id,
|
||||
tool_progress_callback=self._on_tool_progress,
|
||||
tool_start_callback=self._on_tool_start if self._inline_diffs_enabled else None,
|
||||
tool_complete_callback=self._on_tool_complete if self._inline_diffs_enabled else None,
|
||||
stream_delta_callback=self._stream_delta if self.streaming_enabled else None,
|
||||
tool_gen_callback=self._on_tool_gen_start if self.streaming_enabled else None,
|
||||
)
|
||||
@@ -5034,6 +5040,33 @@ class HermesCLI:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _on_tool_start(self, tool_call_id: str, function_name: str, function_args: dict):
    """Record the pre-write file state for a tool call that is about to run.

    A captured snapshot is stored under *tool_call_id* in
    ``self._pending_edit_snapshots``. Any failure is logged at debug level
    and otherwise ignored.
    """
    try:
        from agent.display import capture_local_edit_snapshot

        captured = capture_local_edit_snapshot(function_name, function_args)
        if captured is None:
            return
        self._pending_edit_snapshots[tool_call_id] = captured
    except Exception:
        logger.debug("Edit snapshot capture failed for %s", function_name, exc_info=True)
|
||||
|
||||
def _on_tool_complete(self, tool_call_id: str, function_name: str, function_args: dict, function_result: str):
    """Show an inline diff once a tool call has finished.

    The snapshot stored for *tool_call_id* (if any) is removed from
    ``self._pending_edit_snapshots`` before rendering is attempted, so it is
    dropped whether or not rendering succeeds. Rendering failures are logged
    at debug level and otherwise ignored.
    """
    pending = self._pending_edit_snapshots.pop(tool_call_id, None)
    try:
        from agent.display import render_edit_diff_with_delta as render_diff

        render_diff(
            function_name,
            function_result,
            function_args=function_args,
            snapshot=pending,
            print_fn=_cprint,
        )
    except Exception:
        logger.debug("Edit diff preview failed for %s", function_name, exc_info=True)
|
||||
|
||||
# ====================================================================
|
||||
# Voice mode methods
|
||||
# ====================================================================
|
||||
|
||||
@@ -352,6 +352,7 @@ DEFAULT_CONFIG = {
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"inline_diffs": True, # Show inline diff previews for write actions (write_file, patch, skill_manage)
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
"tool_progress_command": False, # Enable /verbose command in messaging gateway
|
||||
|
||||
37
run_agent.py
37
run_agent.py
@@ -320,8 +320,12 @@ def _extract_parallel_scope_path(tool_name: str, function_args: dict) -> Path |
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
return None
|
||||
|
||||
expanded = Path(raw_path).expanduser()
|
||||
if expanded.is_absolute():
|
||||
return Path(os.path.abspath(str(expanded)))
|
||||
|
||||
# Avoid resolve(); the file may not exist yet.
|
||||
return Path(raw_path).expanduser()
|
||||
return Path(os.path.abspath(str(Path.cwd() / expanded)))
|
||||
|
||||
|
||||
def _paths_overlap(left: Path, right: Path) -> bool:
|
||||
@@ -486,6 +490,8 @@ class AIAgent:
|
||||
provider_data_collection: str = None,
|
||||
session_id: str = None,
|
||||
tool_progress_callback: callable = None,
|
||||
tool_start_callback: callable = None,
|
||||
tool_complete_callback: callable = None,
|
||||
thinking_callback: callable = None,
|
||||
reasoning_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
@@ -620,6 +626,8 @@ class AIAgent:
|
||||
).start()
|
||||
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.tool_start_callback = tool_start_callback
|
||||
self.tool_complete_callback = tool_complete_callback
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
|
||||
@@ -5553,7 +5561,7 @@ class AIAgent:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
|
||||
for _, name, args in parsed_calls:
|
||||
for tc, name, args in parsed_calls:
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(name, args)
|
||||
@@ -5561,6 +5569,13 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
for tc, name, args in parsed_calls:
|
||||
if self.tool_start_callback:
|
||||
try:
|
||||
self.tool_start_callback(tc.id, name, args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool start callback error: {cb_err}")
|
||||
|
||||
# ── Concurrent execution ─────────────────────────────────────────
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
|
||||
results = [None] * num_tools
|
||||
@@ -5631,6 +5646,12 @@ class AIAgent:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
if self.tool_complete_callback:
|
||||
try:
|
||||
self.tool_complete_callback(tc.id, name, args, function_result)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
# Truncate oversized results
|
||||
MAX_TOOL_RESULT_CHARS = 100_000
|
||||
if len(function_result) > MAX_TOOL_RESULT_CHARS:
|
||||
@@ -5719,6 +5740,12 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
if self.tool_start_callback:
|
||||
try:
|
||||
self.tool_start_callback(tool_call.id, function_name, function_args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool start callback error: {cb_err}")
|
||||
|
||||
# Checkpoint: snapshot working dir before file-mutating tools
|
||||
if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
|
||||
try:
|
||||
@@ -5883,6 +5910,12 @@ class AIAgent:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
|
||||
if self.tool_complete_callback:
|
||||
try:
|
||||
self.tool_complete_callback(tool_call.id, function_name, function_args, function_result)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
# Guard against tools returning absurdly large content that would
|
||||
# blow up the context window. 100K chars ≈ 25K tokens — generous
|
||||
# enough for any reasonable tool output but prevents catastrophic
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
"""Tests for agent/display.py — build_tool_preview()."""
|
||||
"""Tests for agent/display.py — build_tool_preview() and inline diff previews."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from agent.display import build_tool_preview
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.display import (
|
||||
build_tool_preview,
|
||||
capture_local_edit_snapshot,
|
||||
extract_edit_diff,
|
||||
_render_inline_unified_diff,
|
||||
_summarize_rendered_diff_sections,
|
||||
render_edit_diff_with_delta,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildToolPreview:
|
||||
@@ -83,3 +93,110 @@ class TestBuildToolPreview:
|
||||
assert build_tool_preview("terminal", 0) is None
|
||||
assert build_tool_preview("terminal", "") is None
|
||||
assert build_tool_preview("terminal", []) is None
|
||||
|
||||
|
||||
class TestEditDiffPreview:
    """Inline diff previews: extraction, ANSI rendering, printing and truncation."""

    def test_extract_edit_diff_for_patch(self):
        # A patch result that embeds a diff is returned as-is.
        payload = '{"success": true, "diff": "--- a/x\\n+++ b/x\\n"}'
        extracted = extract_edit_diff("patch", payload)

        assert extracted is not None
        assert "+++ b/x" in extracted

    def test_render_inline_unified_diff_colors_added_and_removed_lines(self):
        source = "".join(
            [
                "--- a/cli.py\n",
                "+++ b/cli.py\n",
                "@@ -1,2 +1,2 @@\n",
                "-old line\n",
                "+new line\n",
                " context\n",
            ]
        )

        lines = _render_inline_unified_diff(source)
        header = lines[0]

        assert "a/cli.py" in header
        assert "b/cli.py" in header
        assert any("old line" in line for line in lines)
        assert any("new line" in line for line in lines)
        # "48;2;" is the truecolor background sequence used for +/- lines.
        assert any("48;2;" in line for line in lines)

    def test_extract_edit_diff_ignores_non_edit_tools(self):
        outcome = extract_edit_diff("web_search", '{"diff": "--- a\\n+++ b\\n"}')
        assert outcome is None

    def test_extract_edit_diff_uses_local_snapshot_for_write_file(self, tmp_path):
        note = tmp_path / "note.txt"
        call_args = {"path": str(note)}

        note.write_text("old\n", encoding="utf-8")
        before_state = capture_local_edit_snapshot("write_file", call_args)
        note.write_text("new\n", encoding="utf-8")

        diff = extract_edit_diff(
            "write_file",
            '{"bytes_written": 4}',
            function_args={"path": str(note)},
            snapshot=before_state,
        )

        assert diff is not None
        for fragment in ("--- a/", "+++ b/", "-old", "+new"):
            assert fragment in diff

    def test_render_edit_diff_with_delta_invokes_printer(self):
        sink = MagicMock()

        shown = render_edit_diff_with_delta(
            "patch",
            '{"diff": "--- a/x\\n+++ b/x\\n@@ -1 +1 @@\\n-old\\n+new\\n"}',
            print_fn=sink,
        )

        assert shown is True
        assert sink.call_count >= 2
        printed = [entry.args[0] for entry in sink.call_args_list]
        assert any("a/x" in text and "b/x" in text for text in printed)
        assert any("old" in text for text in printed)
        assert any("new" in text for text in printed)

    def test_render_edit_diff_with_delta_skips_without_diff(self):
        shown = render_edit_diff_with_delta("patch", '{"success": true}')
        assert shown is False

    def test_render_edit_diff_with_delta_handles_renderer_errors(self, monkeypatch):
        sink = MagicMock()
        failing_renderer = MagicMock(side_effect=RuntimeError("boom"))
        monkeypatch.setattr("agent.display._summarize_rendered_diff_sections", failing_renderer)

        shown = render_edit_diff_with_delta(
            "patch",
            '{"diff": "--- a/x\\n+++ b/x\\n"}',
            print_fn=sink,
        )

        assert shown is False
        assert sink.call_count == 0

    def test_summarize_rendered_diff_sections_truncates_large_diff(self):
        added = "".join(f"+line{i}\n" for i in range(120))
        diff = "--- a/x.py\n+++ b/x.py\n" + added

        lines = _summarize_rendered_diff_sections(diff, max_lines=20)

        # 20 rendered lines plus the trailing summary line.
        assert len(lines) == 21
        assert "omitted" in lines[-1]

    def test_summarize_rendered_diff_sections_limits_file_count(self):
        per_file = [f"--- a/file{i}.py\n+++ b/file{i}.py\n+line{i}\n" for i in range(8)]
        diff = "".join(per_file)

        lines = _summarize_rendered_diff_sections(diff, max_files=3, max_lines=50)

        for shown_name in ("a/file0.py", "a/file1.py", "a/file2.py"):
            assert any(shown_name in line for line in lines)
        assert not any("a/file7.py" in line for line in lines)
        assert "additional file" in lines[-1]
|
||||
|
||||
@@ -1239,6 +1239,42 @@ class TestConcurrentToolExecution:
|
||||
)
|
||||
assert result == "result"
|
||||
|
||||
def test_sequential_tool_callbacks_fire_in_order(self, agent):
    """Sequential execution reports one start and one completion per tool call."""
    call = _mock_tool_call(name="web_search", arguments='{"query":"hello"}', call_id="c1")
    assistant_msg = _mock_assistant_msg(content="", tool_calls=[call])
    started = []
    finished = []

    def on_start(tool_call_id, function_name, function_args):
        started.append((tool_call_id, function_name, function_args))

    def on_complete(tool_call_id, function_name, function_args, function_result):
        finished.append((tool_call_id, function_name, function_args, function_result))

    agent.tool_start_callback = on_start
    agent.tool_complete_callback = on_complete

    with patch("run_agent.handle_function_call", return_value='{"success": true}'):
        agent._execute_tool_calls_sequential(assistant_msg, [], "task-1")

    assert started == [("c1", "web_search", {"query": "hello"})]
    assert finished == [("c1", "web_search", {"query": "hello"}, '{"success": true}')]
|
||||
|
||||
def test_concurrent_tool_callbacks_fire_for_each_tool(self, agent):
    """Concurrent execution reports a start and a completion for every tool call."""
    first = _mock_tool_call(name="web_search", arguments='{"query":"one"}', call_id="c1")
    second = _mock_tool_call(name="web_search", arguments='{"query":"two"}', call_id="c2")
    assistant_msg = _mock_assistant_msg(content="", tool_calls=[first, second])
    started = []
    finished = []

    def on_start(tool_call_id, function_name, function_args):
        started.append((tool_call_id, function_name, function_args))

    def on_complete(tool_call_id, function_name, function_args, function_result):
        finished.append((tool_call_id, function_name, function_args, function_result))

    agent.tool_start_callback = on_start
    agent.tool_complete_callback = on_complete

    with patch("run_agent.handle_function_call", side_effect=['{"id":1}', '{"id":2}']):
        agent._execute_tool_calls_concurrent(assistant_msg, [], "task-1")

    # Start order is asserted exactly; completions are compared as sets.
    assert started == [
        ("c1", "web_search", {"query": "one"}),
        ("c2", "web_search", {"query": "two"}),
    ]
    assert len(finished) == 2
    assert {record[0] for record in finished} == {"c1", "c2"}
    assert {record[3] for record in finished} == {'{"id":1}', '{"id":2}'}
|
||||
|
||||
def test_invoke_tool_handles_agent_level_tools(self, agent):
|
||||
"""_invoke_tool should handle todo tool directly."""
|
||||
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
|
||||
@@ -1280,6 +1316,38 @@ class TestPathsOverlap:
|
||||
assert not _paths_overlap(Path("src/a.py"), Path(""))
|
||||
|
||||
|
||||
class TestParallelScopePathNormalization:
    """Relative and absolute spellings of one file must map to the same scope."""

    def test_extract_parallel_scope_path_normalizes_relative_to_cwd(self, tmp_path, monkeypatch):
        from run_agent import _extract_parallel_scope_path

        monkeypatch.chdir(tmp_path)

        scoped = _extract_parallel_scope_path("write_file", {"path": "./notes.txt"})

        assert scoped == tmp_path / "notes.txt"

    def test_extract_parallel_scope_path_treats_relative_and_absolute_same_file_as_same_scope(self, tmp_path, monkeypatch):
        from run_agent import _extract_parallel_scope_path, _paths_overlap

        monkeypatch.chdir(tmp_path)
        abs_path = tmp_path / "notes.txt"

        rel_scoped = _extract_parallel_scope_path("write_file", {"path": "notes.txt"})
        abs_scoped = _extract_parallel_scope_path("write_file", {"path": str(abs_path)})

        assert rel_scoped == abs_scoped
        assert _paths_overlap(rel_scoped, abs_scoped)

    def test_should_parallelize_tool_batch_rejects_same_file_with_mixed_path_spellings(self, tmp_path, monkeypatch):
        import json

        from run_agent import _should_parallelize_tool_batch

        monkeypatch.chdir(tmp_path)
        # Build the argument JSON with json.dumps: interpolating the path
        # into a string literal yields invalid JSON whenever the path
        # contains backslashes or quotes (e.g. Windows temp directories).
        tc1 = _mock_tool_call(
            name="write_file",
            arguments=json.dumps({"path": "notes.txt", "content": "one"}),
            call_id="c1",
        )
        tc2 = _mock_tool_call(
            name="write_file",
            arguments=json.dumps({"path": str(tmp_path / "notes.txt"), "content": "two"}),
            call_id="c2",
        )

        assert not _should_parallelize_tool_batch([tc1, tc2])
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
||||
Reference in New Issue
Block a user