diff --git a/agent/workspace.py b/agent/workspace.py index 56a4d84a734..394da57e8eb 100644 --- a/agent/workspace.py +++ b/agent/workspace.py @@ -361,49 +361,162 @@ def workspace_search( class WorkspaceEmbedder: """Best-effort embedder for workspace retrieval. - Local mode uses a deterministic hashing fallback so retrieval works without - extra dependencies. Hosted providers can use real embedding APIs when - credentials are present; failures fall back to the local hash backend. + Local mode prefers SentenceTransformers with EmbeddingGemma when the + optional runtime is installed. Hosted providers can use real embedding APIs + when credentials are present. Any failure falls back to a deterministic hash + backend so retrieval continues to work. """ + _MODEL_CACHE: dict[tuple[str, str], Any] = {} + _MODEL_CACHE_LOCK = None + def __init__(self, config: dict[str, Any]): kb_cfg = config.get("knowledgebase", {}) or {} emb_cfg = kb_cfg.get("embeddings", {}) or {} self.provider = str(emb_cfg.get("provider", "local") or "local").strip().lower() - self.model = str(emb_cfg.get("model", "embeddinggemma-300m") or "embeddinggemma-300m") + self.model = str(emb_cfg.get("model", "google/embeddinggemma-300m") or "google/embeddinggemma-300m") self.dimensions = int(emb_cfg.get("dimensions", 768) or 768) self.backend = "hash-local-v1" + if WorkspaceEmbedder._MODEL_CACHE_LOCK is None: + import threading + WorkspaceEmbedder._MODEL_CACHE_LOCK = threading.Lock() @property def signature(self) -> str: return f"{self.provider}:{self.model}:{self.dimensions}:{self.backend}" def embed_texts(self, texts: list[str]) -> list[list[float]]: - if self.provider == "openai": - embedded = self._try_openai(texts) - if embedded is not None: + return self.embed_documents(texts) + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + vectors = None + if self.provider == "local": + vectors = self._try_local_documents(texts) + if vectors is not None: + self.backend = "sentence-transformers" + return vectors + elif self.provider == "openai": + vectors = self._try_openai(texts) + if vectors is not None: self.backend = "openai" - return embedded + return vectors elif self.provider == "google": - embedded = self._try_google(texts) - if embedded is not None: + vectors = self._try_google(texts, task_type="RETRIEVAL_DOCUMENT") + if vectors is not None: self.backend = "google" - return embedded + return vectors self.backend = "hash-local-v1" return [self._hash_embed(text) for text in texts] + def embed_query(self, text: str) -> list[float]: + vector = None + if self.provider == "local": + vector = self._try_local_query(text) + if vector is not None: + self.backend = "sentence-transformers" + return vector + elif self.provider == "openai": + vectors = self._try_openai([text]) + if vectors is not None: + self.backend = "openai" + return vectors[0] + elif self.provider == "google": + vectors = self._try_google([text], task_type="RETRIEVAL_QUERY") + if vectors is not None: + self.backend = "google" + return vectors[0] + self.backend = "hash-local-v1" + return self._hash_embed(text) + + def _sentence_transformer_model(self): + try: + import torch + from sentence_transformers import SentenceTransformer + except Exception: + return None + + if torch.cuda.is_available(): + device = "cuda" + elif getattr(getattr(torch, 'backends', None), 'mps', None) and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + cache_key = (self.model, device) + lock = WorkspaceEmbedder._MODEL_CACHE_LOCK + with lock: + cached = WorkspaceEmbedder._MODEL_CACHE.get(cache_key) + if cached is not None: + return cached + try: + model = SentenceTransformer(self.model, device=device) + except TypeError: + model = SentenceTransformer(self.model) + if hasattr(model, 'to'): + model = model.to(device) + except Exception: + return None + WorkspaceEmbedder._MODEL_CACHE[cache_key] = model + return model + + def _st_encode_kwargs(self) -> dict[str, Any]: + kwargs: dict[str, Any] = {"normalize_embeddings": True} + if 0 < self.dimensions < 768: + kwargs["truncate_dim"] = self.dimensions + return kwargs + + @staticmethod + def _vector_to_list(vector: Any) -> list[float]: + if hasattr(vector, 'tolist'): + vector = vector.tolist() + return [float(v) for v in vector] + + def _vectors_to_lists(self, vectors: Any) -> list[list[float]]: + if hasattr(vectors, 'tolist'): + vectors = vectors.tolist() + if not vectors: + return [] + first = vectors[0] + if isinstance(first, (int, float)): + return [self._vector_to_list(vectors)] + return [self._vector_to_list(vector) for vector in vectors] + + def _try_local_documents(self, texts: list[str]) -> list[list[float]] | None: + model = self._sentence_transformer_model() + if model is None: + return None + kwargs = self._st_encode_kwargs() + try: + if hasattr(model, 'encode_document'): + return self._vectors_to_lists(model.encode_document(texts, **kwargs)) + return self._vectors_to_lists(model.encode(texts, prompt_name='Retrieval-document', **kwargs)) + except Exception: + return None + + def _try_local_query(self, text: str) -> list[float] | None: + model = self._sentence_transformer_model() + if model is None: + return None + kwargs = self._st_encode_kwargs() + try: + if hasattr(model, 'encode_query'): + return self._vector_to_list(model.encode_query(text, **kwargs)) + return self._vector_to_list(model.encode(text, prompt_name='Retrieval-query', **kwargs)) + except Exception: + return None + def _try_openai(self, texts: list[str]) -> list[list[float]] | None: try: from openai import OpenAI except Exception: return None - api_key = os.getenv("OPENAI_API_KEY", "").strip() + api_key = os.getenv('OPENAI_API_KEY', '').strip() if not api_key: return None - kwargs: dict[str, Any] = {"api_key": api_key} - base_url = os.getenv("OPENAI_BASE_URL", "").strip() + kwargs: dict[str, Any] = {'api_key': api_key} + base_url = os.getenv('OPENAI_BASE_URL', '').strip() if base_url: - kwargs["base_url"] = base_url + kwargs['base_url'] = base_url try: client = OpenAI(**kwargs) resp = client.embeddings.create(model=self.model, input=texts) @@ -411,8 +524,8 @@ class WorkspaceEmbedder: except Exception: return None - def _try_google(self, texts: list[str]) -> list[list[float]] | None: - api_key = os.getenv("GEMINI_API_KEY", "").strip() or os.getenv("GOOGLE_API_KEY", "").strip() + def _try_google(self, texts: list[str], task_type: str) -> list[list[float]] | None: + api_key = os.getenv('GEMINI_API_KEY', '').strip() or os.getenv('GOOGLE_API_KEY', '').strip() if not api_key: return None try: @@ -423,17 +536,18 @@ class WorkspaceEmbedder: for text in texts: try: response = requests.post( - f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:embedContent", - params={"key": api_key}, + f'https://generativelanguage.googleapis.com/v1beta/models/{self.model}:embedContent', + params={'key': api_key}, json={ - "content": {"parts": [{"text": text}]}, - "outputDimensionality": self.dimensions, + 'content': {'parts': [{'text': text}]}, + 'taskType': task_type, + 'outputDimensionality': self.dimensions, }, timeout=30, ) response.raise_for_status() payload = response.json() - values = payload.get("embedding", {}).get("values") + values = payload.get('embedding', {}).get('values') if not values: return None results.append([float(v) for v in values]) @@ -448,14 +562,13 @@ class WorkspaceEmbedder: if not tokens: return vec for token in tokens: - digest = hashlib.sha256(token.encode("utf-8")).digest() - idx = int.from_bytes(digest[:4], "big") % dims + digest = hashlib.sha256(token.encode('utf-8')).digest() + idx = int.from_bytes(digest[:4], 'big') % dims sign = 1.0 if digest[4] % 2 == 0 else -1.0 vec[idx] += sign norm = math.sqrt(sum(value * value for value in vec)) or 1.0 return [value / norm for value in vec] - def _index_db_path(paths: WorkspacePaths) -> Path: return paths.indexes_dir / "workspace.sqlite" diff --git a/docs/workspace-knowledgebase-rag-spec.md b/docs/workspace-knowledgebase-rag-spec.md index 296c9d93f1a..690b9136fb1 100644 --- a/docs/workspace-knowledgebase-rag-spec.md +++ b/docs/workspace-knowledgebase-rag-spec.md @@ -228,7 +228,7 @@ knowledgebase: markdown_strategy: headings embeddings: provider: local # local | google | openai | voyage | custom - model: embeddinggemma-300m + model: google/embeddinggemma-300m dimensions: 768 reranker: enabled: false diff --git a/hermes_cli/config.py b/hermes_cli/config.py index af7f823af59..566d00daf85 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -280,7 +280,7 @@ DEFAULT_CONFIG = { }, "embeddings": { "provider": "local", - "model": "embeddinggemma-300m", + "model": "google/embeddinggemma-300m", "dimensions": 768, }, "reranker": { diff --git a/pyproject.toml b/pyproject.toml index fa248cd0e59..36631911c1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,10 @@ honcho = ["honcho-ai>=2.0.1"] mcp = ["mcp>=1.2.0"] homeassistant = ["aiohttp>=3.9.0"] acp = ["agent-client-protocol>=0.8.1,<1.0"] +workspace-rag = [ + "sentence-transformers>=5.0.0", + "torch>=2.4.0", +] rl = [ "atroposlib @ git+https://github.com/NousResearch/atropos.git", "tinker @ git+https://github.com/thinking-machines-lab/tinker.git", diff --git a/tests/agent/test_workspace.py b/tests/agent/test_workspace.py index cd8beb6d172..4838287b327 100644 --- a/tests/agent/test_workspace.py +++ b/tests/agent/test_workspace.py @@ -1,7 +1,9 @@ from __future__ import annotations import json +import sys from pathlib import Path +from types import SimpleNamespace def _config(tmp_path: Path) -> dict: @@ -35,7 +37,7 @@ def _config(tmp_path: Path) -> dict: }, "embeddings": { "provider": "local", - "model": "embeddinggemma-300m", + "model": "google/embeddinggemma-300m", "dimensions": 768, }, "reranker": { @@ -128,6 +130,48 @@ class TestWorkspaceSearch: assert result["matches"][0]["relative_path"] == "docs/a.md" +class TestWorkspaceEmbedder: + def test_local_embeddinggemma_uses_sentence_transformers_when_available(self, tmp_path, monkeypatch): + from agent.workspace import WorkspaceEmbedder + + calls = {} + + class FakeVector(list): + def tolist(self): + return list(self) + + class FakeModel: + def __init__(self, model_id, **kwargs): + calls["model_id"] = model_id + calls["kwargs"] = kwargs + + def encode_query(self, text, **kwargs): + calls["query"] = (text, kwargs) + return FakeVector([0.1, 0.2, 0.3]) + + def encode_document(self, texts, **kwargs): + calls["documents"] = (list(texts), kwargs) + return [FakeVector([0.4, 0.5, 0.6]) for _ in texts] + + fake_torch = SimpleNamespace( + cuda=SimpleNamespace(is_available=lambda: False), + backends=SimpleNamespace(mps=SimpleNamespace(is_available=lambda: False)), + ) + monkeypatch.setitem(sys.modules, "torch", fake_torch) + monkeypatch.setitem(sys.modules, "sentence_transformers", SimpleNamespace(SentenceTransformer=FakeModel)) + + embedder = WorkspaceEmbedder(_config(tmp_path)) + docs = embedder.embed_documents(["alpha doc"]) + query = embedder.embed_query("alpha query") + + assert embedder.backend == "sentence-transformers" + assert calls["model_id"] == "google/embeddinggemma-300m" + assert calls["documents"][0] == ["alpha doc"] + assert calls["query"][0] == "alpha query" + assert docs == [[0.4, 0.5, 0.6]] + assert query == [0.1, 0.2, 0.3] + + class TestWorkspaceRetrieval: def test_index_workspace_builds_chunk_db_and_retrieves_ranked_chunks(self, tmp_path): from agent.workspace import index_workspace_knowledgebase, workspace_retrieve diff --git a/tests/tools/test_workspace_tool.py b/tests/tools/test_workspace_tool.py index fcd8dec9d04..bf6dd8f21cd 100644 --- a/tests/tools/test_workspace_tool.py +++ b/tests/tools/test_workspace_tool.py @@ -33,7 +33,7 @@ def _config(tmp_path: Path) -> dict: "code_strategy": "structural", "markdown_strategy": "headings", }, - "embeddings": {"provider": "local", "model": "embeddinggemma-300m", "dimensions": 768}, + "embeddings": {"provider": "local", "model": "google/embeddinggemma-300m", "dimensions": 768}, "reranker": {"enabled": False, "provider": "local", "model": "bge-reranker-v2-m3"}, "indexing": { "respect_gitignore": True,