diff --git a/agent/workspace.py b/agent/workspace.py index 394da57e8eb..3d858879ef8 100644 --- a/agent/workspace.py +++ b/agent/workspace.py @@ -202,16 +202,21 @@ def workspace_status(config: dict[str, Any] | None = None) -> dict[str, Any]: index_path = _index_db_path(paths) chunk_count = 0 + index_info: dict[str, Any] = {} if index_path.exists(): try: conn = _open_index_db(paths) try: row = conn.execute("SELECT COUNT(*) AS count FROM chunks").fetchone() chunk_count = int(row["count"] if row else 0) + meta_row = conn.execute("SELECT value FROM meta WHERE key = 'index_info'").fetchone() + if meta_row and meta_row["value"]: + index_info = json.loads(meta_row["value"]) finally: conn.close() except Exception: chunk_count = 0 + index_info = {} return { "success": True, @@ -224,6 +229,8 @@ def workspace_status(config: dict[str, Any] | None = None) -> dict[str, Any]: "chunk_count": chunk_count, "file_count": len(entries), "category_counts": category_counts, + "embedding_backend": index_info.get("embedding_backend", ""), + "dense_backend": index_info.get("dense_backend", ""), "default_subdirs": list(DEFAULT_WORKSPACE_SUBDIRS), } @@ -569,6 +576,180 @@ class WorkspaceEmbedder: norm = math.sqrt(sum(value * value for value in vec)) or 1.0 return [value / norm for value in vec] +class WorkspaceReranker: + """Optional second-stage reranker for fused retrieval candidates.""" + + _MODEL_CACHE: dict[tuple[str, str], Any] = {} + _MODEL_CACHE_LOCK = None + + def __init__(self, config: dict[str, Any]): + kb_cfg = config.get("knowledgebase", {}) or {} + rerank_cfg = kb_cfg.get("reranker", {}) or {} + self.enabled = bool(rerank_cfg.get("enabled", False)) + self.provider = str(rerank_cfg.get("provider", "local") or "local").strip().lower() + self.model = str(rerank_cfg.get("model", "bge-reranker-v2-m3") or "bge-reranker-v2-m3") + self.backend = "disabled" + if WorkspaceReranker._MODEL_CACHE_LOCK is None: + import threading + WorkspaceReranker._MODEL_CACHE_LOCK = threading.Lock() + + def 
rerank(self, query: str, candidates: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not self.enabled or not candidates: + self.backend = "disabled" + return list(candidates) + if self.provider == "local": + ranked = self._try_local_cross_encoder(query, candidates) + if ranked is not None: + self.backend = "cross-encoder" + return ranked + elif self.provider == "cohere": + ranked = self._try_cohere(query, candidates) + if ranked is not None: + self.backend = "cohere" + return ranked + elif self.provider == "voyage": + ranked = self._try_voyage(query, candidates) + if ranked is not None: + self.backend = "voyage" + return ranked + self.backend = "heuristic" + return self._heuristic(query, candidates) + + def _local_model(self): + try: + import torch + from sentence_transformers import CrossEncoder + except Exception: + return None + if torch.cuda.is_available(): + device = "cuda" + elif getattr(getattr(torch, "backends", None), "mps", None) and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + cache_key = (self.model, device) + lock = WorkspaceReranker._MODEL_CACHE_LOCK + with lock: + cached = WorkspaceReranker._MODEL_CACHE.get(cache_key) + if cached is not None: + return cached + try: + model = CrossEncoder(self.model, device=device) + except TypeError: + model = CrossEncoder(self.model) + except Exception: + return None + WorkspaceReranker._MODEL_CACHE[cache_key] = model + return model + + def _try_local_cross_encoder(self, query: str, candidates: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + model = self._local_model() + if model is None: + return None + pairs = [(query, candidate.get("content", "")) for candidate in candidates] + try: + scores = model.predict(pairs) + except Exception: + return None + if hasattr(scores, "tolist"): + scores = scores.tolist() + enriched = [] + for candidate, score in zip(candidates, scores): + item = dict(candidate) + item["rerank_score"] = float(score) + enriched.append(item) + 
enriched.sort(key=lambda item: (item.get("rerank_score", 0.0), item.get("rrf_score", 0.0), item.get("dense_score", 0.0)), reverse=True) + return enriched + + def _try_cohere(self, query: str, candidates: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + api_key = os.getenv("COHERE_API_KEY", "").strip() + if not api_key: + return None + try: + import requests + response = requests.post( + "https://api.cohere.com/v2/rerank", + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + json={ + "model": self.model, + "query": query, + "documents": [candidate.get("content", "") for candidate in candidates], + "top_n": len(candidates), + }, + timeout=30, + ) + response.raise_for_status() + payload = response.json() + except Exception: + return None + results = payload.get("results") or [] + if not results: + return None + return self._apply_remote_ranking(candidates, results) + + def _try_voyage(self, query: str, candidates: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + api_key = os.getenv("VOYAGE_API_KEY", "").strip() + if not api_key: + return None + try: + import requests + response = requests.post( + "https://api.voyageai.com/v1/rerank", + headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, + json={ + "model": self.model, + "query": query, + "documents": [candidate.get("content", "") for candidate in candidates], + "top_k": len(candidates), + }, + timeout=30, + ) + response.raise_for_status() + payload = response.json() + except Exception: + return None + results = payload.get("data") or payload.get("results") or [] + if not results: + return None + return self._apply_remote_ranking(candidates, results) + + def _apply_remote_ranking(self, candidates: list[dict[str, Any]], results: list[dict[str, Any]]) -> list[dict[str, Any]]: + enriched: list[dict[str, Any]] = [] + seen: set[int] = set() + for entry in results: + idx = entry.get("index") + if idx is None: + idx = entry.get("document_index") 
+ if idx is None or idx in seen or idx < 0 or idx >= len(candidates): + continue + seen.add(idx) + item = dict(candidates[idx]) + item["rerank_score"] = float(entry.get("relevance_score", entry.get("score", 0.0))) + enriched.append(item) + if len(enriched) != len(candidates): + for idx, candidate in enumerate(candidates): + if idx in seen: + continue + item = dict(candidate) + item.setdefault("rerank_score", item.get("rrf_score", 0.0)) + enriched.append(item) + enriched.sort(key=lambda item: (item.get("rerank_score", 0.0), item.get("rrf_score", 0.0), item.get("dense_score", 0.0)), reverse=True) + return enriched + + def _heuristic(self, query: str, candidates: list[dict[str, Any]]) -> list[dict[str, Any]]: + query_terms = set(re.findall(r"[A-Za-z0-9_./:-]+", query.lower())) + enriched: list[dict[str, Any]] = [] + for candidate in candidates: + content_terms = set(re.findall(r"[A-Za-z0-9_./:-]+", candidate.get("content", "").lower())) + overlap = len(query_terms & content_terms) + lexical = overlap / max(1, len(query_terms)) + item = dict(candidate) + item["rerank_score"] = lexical + float(item.get("dense_score", 0.0)) * 0.1 + enriched.append(item) + enriched.sort(key=lambda item: (item.get("rerank_score", 0.0), item.get("rrf_score", 0.0), item.get("dense_score", 0.0)), reverse=True) + return enriched + + def _index_db_path(paths: WorkspacePaths) -> Path: return paths.indexes_dir / "workspace.sqlite" @@ -611,20 +792,51 @@ def _open_index_db(paths: WorkspacePaths) -> sqlite3.Connection: return conn +def _maybe_enable_sqlite_vec(conn: sqlite3.Connection, dimensions: int | None = None): + try: + import sqlite_vec + except Exception: + return None + try: + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False) + if dimensions: + conn.execute( + f"CREATE VIRTUAL TABLE IF NOT EXISTS chunks_vec USING vec0(embedding float[{int(dimensions)}])" + ) + return sqlite_vec + except Exception: + return None + + +def _delete_chunk_rows(conn: 
sqlite3.Connection, rel_path: str, sqlite_vec_module=None) -> None: + rowids = [row["rowid"] for row in conn.execute("SELECT rowid FROM chunks WHERE rel_path = ?", (rel_path,)).fetchall()] + if sqlite_vec_module and rowids: + for rowid in rowids: + conn.execute("DELETE FROM chunks_vec WHERE rowid = ?", (rowid,)) + conn.execute("DELETE FROM chunks WHERE rel_path = ?", (rel_path,)) + conn.execute("DELETE FROM chunks_fts WHERE rel_path = ?", (rel_path,)) + conn.execute("DELETE FROM files WHERE rel_path = ?", (rel_path,)) + + def _text_hash(text: str) -> str: return hashlib.sha256(text.encode("utf-8")).hexdigest() -def _chunk_text(text: str, path: Path, config: dict[str, Any]) -> list[dict[str, Any]]: +def _chunk_cfg(config: dict[str, Any]) -> tuple[int, int]: kb_cfg = config.get("knowledgebase", {}) or {} chunk_cfg = kb_cfg.get("chunking", {}) or {} target_chars = max(256, int(chunk_cfg.get("default_tokens", 512) or 512) * 4) overlap_chars = max(0, int(chunk_cfg.get("overlap_tokens", 80) or 80) * 4) + return target_chars, overlap_chars + + +def _yield_chunk_windows(text: str, target_chars: int, overlap_chars: int) -> list[str]: normalized = text.replace("\r\n", "\n").strip() if not normalized: return [] - - chunks: list[dict[str, Any]] = [] + windows: list[str] = [] start = 0 text_len = len(normalized) while start < text_len: @@ -637,19 +849,115 @@ def _chunk_text(text: str, path: Path, config: dict[str, Any]) -> list[dict[str, end = boundary chunk = normalized[start:end].strip() if chunk: - chunks.append({ - "content": chunk, - "token_estimate": estimate_tokens_rough(chunk), - }) + windows.append(chunk) if end >= text_len: break next_start = max(start + 1, end - overlap_chars) if next_start <= start: next_start = end start = next_start + return windows + + +def _build_chunk(path: Path, content: str, kind: str, section: str = "") -> dict[str, Any]: + prefix_lines = [f"Path: {path.as_posix()}"] + if section: + prefix_lines.append(f"Section: {section}") + if kind: + 
prefix_lines.append(f"Kind: {kind}") + body = "\n".join(prefix_lines) + "\n\n" + content.strip() + return { + "content": body, + "token_estimate": estimate_tokens_rough(body), + "chunk_kind": kind, + "section_title": section, + } + + +def _chunk_markdown(text: str, path: Path, target_chars: int, overlap_chars: int) -> list[dict[str, Any]]: + lines = text.replace("\r\n", "\n").splitlines() + sections: list[tuple[str, str]] = [] + current_heading = "" + current_lines: list[str] = [] + for line in lines: + if re.match(r"^#{1,6}\s+", line.strip()): + if current_lines: + sections.append((current_heading, "\n".join(current_lines).strip())) + current_heading = line.strip().lstrip("#").strip() + current_lines = [line] + else: + current_lines.append(line) + if current_lines: + sections.append((current_heading, "\n".join(current_lines).strip())) + + chunks: list[dict[str, Any]] = [] + for heading, section_text in sections: + for window in _yield_chunk_windows(section_text, target_chars, overlap_chars): + chunks.append(_build_chunk(path, window, "markdown", heading)) return chunks +def _chunk_code(text: str, path: Path, target_chars: int, overlap_chars: int) -> list[dict[str, Any]]: + lines = text.replace("\r\n", "\n").splitlines() + marker_re = re.compile( + r"^\s*(?:async\s+def|def|class)\s+|^\s*(?:export\s+)?(?:async\s+)?function\s+|^\s*(?:export\s+)?class\s+|^\s*(?:const|let|var)\s+\w+\s*=\s*(?:async\s*)?\(" + ) + blocks: list[str] = [] + current: list[str] = [] + for line in lines: + if marker_re.match(line) and current: + blocks.append("\n".join(current).strip()) + current = [line] + else: + current.append(line) + if current: + blocks.append("\n".join(current).strip()) + + chunks: list[dict[str, Any]] = [] + for block in blocks: + first_line = next((ln.strip() for ln in block.splitlines() if ln.strip()), "") + section = first_line[:120] + for window in _yield_chunk_windows(block, target_chars, overlap_chars): + chunks.append(_build_chunk(path, window, "code", section)) 
def _chunk_generic(text: str, path: Path, target_chars: int, overlap_chars: int) -> list[dict[str, Any]]:
    """Chunk plain text by packing blank-line-separated paragraphs up to the target size.

    Paragraphs are joined with a blank line until adding the next one would
    exceed ``target_chars``; each packed group is then windowed.  Every chunk
    has kind ``"text"``.
    """
    unix_text = text.replace("\r\n", "\n")
    paragraphs = []
    for part in re.split(r"\n\s*\n", unix_text):
        trimmed = part.strip()
        if trimmed:
            paragraphs.append(trimmed)

    groups: list[str] = []
    pending = ""
    for paragraph in paragraphs:
        if not pending:
            pending = paragraph
            continue
        joined = f"{pending}\n\n{paragraph}".strip()
        if len(joined) > target_chars:
            # The pending group is full; start a new one with this paragraph.
            groups.append(pending)
            pending = paragraph
        else:
            pending = joined
    if pending:
        groups.append(pending)

    return [
        _build_chunk(path, window, "text")
        for block in (groups or [text])
        for window in _yield_chunk_windows(block, target_chars, overlap_chars)
    ]


def _chunk_text(text: str, path: Path, config: dict[str, Any]) -> list[dict[str, Any]]:
    """Chunk *text* with the strategy selected by the file extension of *path*.

    Markdown-like extensions use heading-based chunking, known source-code
    extensions use definition-based chunking, and everything else uses
    paragraph packing.  Empty or whitespace-only text yields no chunks; if a
    strategy returns nothing for non-empty text, the whole text becomes a
    single ``"text"`` chunk.
    """
    target_chars, overlap_chars = _chunk_cfg(config)
    normalized = text.replace("\r\n", "\n").strip()
    if not normalized:
        return []

    markdown_exts = {".md", ".markdown", ".rst"}
    code_exts = {".py", ".js", ".ts", ".tsx", ".jsx", ".rs", ".go", ".java", ".c", ".cpp", ".h", ".hpp"}
    suffix = path.suffix.lower()
    if suffix in markdown_exts:
        chunker = _chunk_markdown
    elif suffix in code_exts:
        chunker = _chunk_code
    else:
        chunker = _chunk_generic

    chunks = chunker(normalized, path, target_chars, overlap_chars)
    if chunks:
        return chunks
    return [_build_chunk(path, normalized, "text")]
chunks = _chunk_text(text, file_path, cfg) embeddings = embedder.embed_texts([chunk["content"] for chunk in chunks]) if chunks else [] - conn.execute("DELETE FROM chunks WHERE rel_path = ?", (rel_path,)) - conn.execute("DELETE FROM chunks_fts WHERE rel_path = ?", (rel_path,)) - conn.execute("DELETE FROM files WHERE rel_path = ?", (rel_path,)) + _delete_chunk_rows(conn, rel_path, sqlite_vec_module) for idx, chunk in enumerate(chunks): chunk_id = f"{rel_path}#chunk-{idx:04d}" - embedding_json = json.dumps(embeddings[idx] if idx < len(embeddings) else []) - conn.execute( + embedding_vector = embeddings[idx] if idx < len(embeddings) else [] + embedding_json = json.dumps(embedding_vector) + cursor = conn.execute( "INSERT INTO chunks(chunk_id, rel_path, chunk_index, content, token_estimate, embedding) VALUES (?, ?, ?, ?, ?, ?)", (chunk_id, rel_path, idx, chunk["content"], chunk["token_estimate"], embedding_json), ) + if sqlite_vec_module and embedding_vector: + serialized = ( + sqlite_vec_module.serialize_float32(embedding_vector) + if hasattr(sqlite_vec_module, "serialize_float32") + else json.dumps(embedding_vector) + ) + conn.execute( + "INSERT OR REPLACE INTO chunks_vec(rowid, embedding) VALUES (?, ?)", + (cursor.lastrowid, serialized), + ) conn.execute( "INSERT INTO chunks_fts(chunk_id, rel_path, content) VALUES (?, ?, ?)", (chunk_id, rel_path, chunk["content"]), @@ -735,13 +1053,16 @@ def index_workspace_knowledgebase(config: dict[str, Any] | None = None) -> dict[ rel_path = row["rel_path"] if rel_path in current_files: continue - conn.execute("DELETE FROM chunks WHERE rel_path = ?", (rel_path,)) - conn.execute("DELETE FROM chunks_fts WHERE rel_path = ?", (rel_path,)) - conn.execute("DELETE FROM files WHERE rel_path = ?", (rel_path,)) + _delete_chunk_rows(conn, rel_path, sqlite_vec_module) conn.execute( "INSERT OR REPLACE INTO meta(key, value) VALUES (?, ?)", - ("index_info", json.dumps({"updated_at": _utc_now_iso(), "config_signature": config_signature, "backend": 
embedder.backend})), + ("index_info", json.dumps({ + "updated_at": _utc_now_iso(), + "config_signature": config_signature, + "embedding_backend": embedder.backend, + "dense_backend": "sqlite-vec" if sqlite_vec_module else "python-cosine", + })), ) conn.commit() finally: @@ -752,6 +1073,7 @@ def index_workspace_knowledgebase(config: dict[str, Any] | None = None) -> dict[ manifest["indexed_files"] = indexed_files manifest["skipped_files"] = skipped_files manifest["embedding_backend"] = embedder.backend + manifest["dense_backend"] = "sqlite-vec" if sqlite_vec_module else "python-cosine" return manifest @@ -787,11 +1109,15 @@ def workspace_retrieve( dense_limit = int(dense_top_k or kb_cfg.get("dense_top_k", 40) or 40) sparse_limit = int(sparse_top_k or kb_cfg.get("sparse_top_k", 40) or 40) + fused_limit = int(kb_cfg.get("fused_top_k", 30) or 30) final_limit = int(limit or kb_cfg.get("final_top_k", 8) or 8) embedder = WorkspaceEmbedder(cfg) - query_embedding = embedder.embed_texts([query])[0] + query_embedding = embedder.embed_query(query) + dense_backend = "python-cosine" + reranker = WorkspaceReranker(cfg) conn = _open_index_db(paths) + sqlite_vec_module = _maybe_enable_sqlite_vec(conn, len(query_embedding)) try: sparse_rows: list[sqlite3.Row] = [] fts_query = _fts_terms(query) @@ -805,18 +1131,39 @@ def workspace_retrieve( sparse_rows = [] dense_rows: list[tuple[str, str, str, float]] = [] - chunk_rows = conn.execute( - "SELECT chunk_id, rel_path, content, embedding FROM chunks" - ).fetchall() - for row in chunk_rows: + if sqlite_vec_module: try: - embedding = json.loads(row["embedding"]) + serialized = ( + sqlite_vec_module.serialize_float32(query_embedding) + if hasattr(sqlite_vec_module, "serialize_float32") + else json.dumps(query_embedding) + ) + vec_rows = conn.execute( + "SELECT chunks.chunk_id, chunks.rel_path, chunks.content, chunks_vec.distance " + "FROM chunks_vec JOIN chunks ON chunks.rowid = chunks_vec.rowid " + "WHERE chunks_vec.embedding MATCH ? 
ORDER BY chunks_vec.distance LIMIT ?", + (serialized, dense_limit), + ).fetchall() + dense_rows = [ + (row["chunk_id"], row["rel_path"], row["content"], 1.0 / (1.0 + float(row["distance"]))) + for row in vec_rows + ] + dense_backend = "sqlite-vec" except Exception: - embedding = [] - score = _cosine_similarity(query_embedding, embedding) - dense_rows.append((row["chunk_id"], row["rel_path"], row["content"], score)) - dense_rows.sort(key=lambda item: item[3], reverse=True) - dense_rows = dense_rows[:dense_limit] + dense_rows = [] + if not dense_rows: + chunk_rows = conn.execute( + "SELECT chunk_id, rel_path, content, embedding FROM chunks" + ).fetchall() + for row in chunk_rows: + try: + embedding = json.loads(row["embedding"]) + except Exception: + embedding = [] + score = _cosine_similarity(query_embedding, embedding) + dense_rows.append((row["chunk_id"], row["rel_path"], row["content"], score)) + dense_rows.sort(key=lambda item: item[3], reverse=True) + dense_rows = dense_rows[:dense_limit] merged: dict[str, dict[str, Any]] = {} sparse_match_count = len(sparse_rows) @@ -845,14 +1192,19 @@ def workspace_retrieve( item["rrf_score"] += 1.0 / (_RRF_K + rank) results = sorted(merged.values(), key=lambda item: (item["rrf_score"], item["dense_score"]), reverse=True) - final = results[:final_limit] + fused_candidates = results[:fused_limit] + reranked = reranker.rerank(query, fused_candidates) + final = reranked[:final_limit] return { "success": True, "query": query, "count": len(final), "total_count": len(results), + "fused_candidate_count": len(fused_candidates), "sparse_match_count": sparse_match_count, "embedding_backend": embedder.backend, + "dense_backend": dense_backend, + "rerank_backend": reranker.backend, "index_path": str(db_path), "results": final, } diff --git a/hermes_cli/workspace.py b/hermes_cli/workspace.py index 565dbe83290..c015c5589e2 100644 --- a/hermes_cli/workspace.py +++ b/hermes_cli/workspace.py @@ -23,6 +23,10 @@ def _print_status(console: 
Console) -> None: console.print(f"Index DB: {data.get('index_path', '(not built)')}") console.print(f"Files: {data['file_count']}") console.print(f"Chunks: {data.get('chunk_count', 0)}") + if data.get('embedding_backend'): + console.print(f"Embedding backend: {data['embedding_backend']}") + if data.get('dense_backend'): + console.print(f"Dense backend: {data['dense_backend']}") counts = data.get("category_counts") or {} if counts: for key in sorted(counts): @@ -37,6 +41,10 @@ def _print_index(console: Console) -> None: console.print(f"Indexed {data['file_count']} files into {data.get('chunk_count', 0)} chunks") console.print(f"Manifest: {data['manifest_path']}") console.print(f"Index DB: {data['index_path']}") + if data.get('embedding_backend'): + console.print(f"Embedding backend: {data['embedding_backend']}") + if data.get('dense_backend'): + console.print(f"Dense backend: {data['dense_backend']}") def _print_list(console: Console, path: str = "", recursive: bool = True, limit: int = 20, offset: int = 0) -> None: @@ -78,8 +86,12 @@ def _print_retrieve(console: Console, query: str, limit: int = 8) -> None: if not results: console.print("No retrieval results found.") return + if data.get('dense_backend') or data.get('rerank_backend'): + console.print(f"Dense backend: {data.get('dense_backend', '')} Rerank backend: {data.get('rerank_backend', '')}") for result in results: - console.print(f"{result['relative_path']} [score={result['rrf_score']:.4f} dense={result['dense_score']:.3f}]") + rerank_score = result.get('rerank_score') + rerank_text = f" rerank={rerank_score:.3f}" if isinstance(rerank_score, (int, float)) else "" + console.print(f"{result['relative_path']} [rrf={result['rrf_score']:.4f} dense={result['dense_score']:.3f}{rerank_text}]") console.print(result["content"]) console.print() diff --git a/pyproject.toml b/pyproject.toml index 36631911c1f..c359c4b89be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ acp = 
class TestWorkspaceChunking:
    """Structure-aware chunking keeps headings and symbols in separate chunks."""

    def test_markdown_chunking_prefers_headings(self, tmp_path):
        from agent.workspace import _chunk_text

        cfg = _config(tmp_path)
        text = "# Intro\n\nAlpha overview.\n\n## Deploy\n\nBlue green rollout plan.\n\n## Rollback\n\nRollback steps.\n"
        chunks = _chunk_text(text, Path("docs/plan.md"), cfg)

        assert len(chunks) >= 3
        assert any("deploy" in chunk["content"].lower() for chunk in chunks)
        assert any("rollback" in chunk["content"].lower() for chunk in chunks)

    def test_code_chunking_prefers_symbol_boundaries(self, tmp_path):
        from agent.workspace import _chunk_text

        cfg = _config(tmp_path)
        text = "def alpha():\n return 'a'\n\n\ndef beta():\n return 'b'\n"
        chunks = _chunk_text(text, Path("code/example.py"), cfg)

        assert len(chunks) >= 2
        assert any("def alpha" in chunk["content"] for chunk in chunks)
        assert any("def beta" in chunk["content"] for chunk in chunks)


class TestWorkspaceReranker:
    """The local cross-encoder path reorders candidates by model score."""

    def test_local_cross_encoder_reranker_reorders_candidates(self, tmp_path, monkeypatch):
        from agent.workspace import WorkspaceReranker

        calls = {}

        class FakeCrossEncoder:
            def __init__(self, model_name, **kwargs):
                calls["model_name"] = model_name
                calls["kwargs"] = kwargs

            def predict(self, pairs, **kwargs):
                calls["pairs"] = pairs
                calls["predict_kwargs"] = kwargs
                return [0.1, 0.9]

        fake_torch = SimpleNamespace(
            cuda=SimpleNamespace(is_available=lambda: False),
            backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)),
        )
        monkeypatch.setitem(sys.modules, "torch", fake_torch)
        monkeypatch.setitem(sys.modules, "sentence_transformers", SimpleNamespace(CrossEncoder=FakeCrossEncoder))
        # The model cache is class-level; give this test a private one so the
        # fake encoder is neither reused from nor leaked into other tests.
        monkeypatch.setattr(WorkspaceReranker, "_MODEL_CACHE", {})

        cfg = _config(tmp_path)
        # setdefault keeps the test independent of whether _config pre-populates these sections.
        rerank_cfg = cfg.setdefault("knowledgebase", {}).setdefault("reranker", {})
        rerank_cfg["enabled"] = True
        rerank_cfg["provider"] = "local"
        rerank_cfg["model"] = "cross-encoder/ms-marco-MiniLM-L6-v2"

        reranker = WorkspaceReranker(cfg)
        ranked = reranker.rerank(
            "rollback plan",
            [
                {"content": "deployment overview", "rrf_score": 0.9, "dense_score": 0.9},
                {"content": "rollback plan details", "rrf_score": 0.3, "dense_score": 0.2},
            ],
        )

        assert reranker.backend == "cross-encoder"
        assert calls["model_name"] == "cross-encoder/ms-marco-MiniLM-L6-v2"
        assert len(calls["pairs"]) == 2
        assert ranked[0]["content"] == "rollback plan details"