def _jwt_exp_claim(token: str) -> Optional[float]:
    """Return the numeric ``exp`` claim carried by *token*, or ``None``.

    The token is only inspected, never verified: the second dot-separated
    segment is base64url-decoded and parsed as JSON.  ``None`` is returned
    whenever that is not possible or not meaningful:

    * the token has no second segment (an opaque, non-JWT token),
    * the segment is not valid base64 / UTF-8 / JSON,
    * the decoded payload is not a JSON object,
    * ``exp`` is missing or is not a number.
    """
    import base64  # function-scope import: keeps the helper self-contained

    segments = token.split(".")
    if len(segments) < 2:
        return None  # opaque token, nothing to decode

    payload = segments[1]
    # JWT segments are unpadded base64url; restore the "=" padding that
    # the decoder requires.
    payload += "=" * (-len(payload) % 4)
    try:
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except ValueError:
        # binascii.Error, json.JSONDecodeError and UnicodeDecodeError are
        # all ValueError subclasses, so this covers every decode failure
        # without hiding unrelated bugs.
        return None

    if not isinstance(claims, dict):
        return None
    exp = claims.get("exp")
    return exp if isinstance(exp, (int, float)) else None


def _read_codex_access_token() -> Optional[str]:
    """Read a valid, non-expired Codex OAuth access token from Hermes auth store.

    Returns:
        The access token with surrounding whitespace removed, or ``None``
        when no usable token is stored, when the token is a JWT whose
        ``exp`` claim lies in the past, or when reading the store fails
        (the failure is logged at debug level and never raised).

    Tokens whose expiry cannot be determined (non-JWT tokens, undecodable
    payloads, missing ``exp``) are returned as-is.
    """
    try:
        from hermes_cli.auth import _read_codex_tokens

        data = _read_codex_tokens()
        tokens = data.get("tokens", {})
        access_token = tokens.get("access_token")
        if not isinstance(access_token, str):
            return None
        # Strip once, up front, so the value that is inspected is exactly
        # the value that is returned.
        access_token = access_token.strip()
        if not access_token:
            return None

        # Check JWT expiry — expired tokens block the auto chain and
        # prevent fallback to working providers (e.g. Anthropic).
        exp = _jwt_exp_claim(access_token)
        if exp and time.time() > exp:
            logger.debug("Codex access token expired (exp=%s), skipping", exp)
            return None

        return access_token
    except Exception as exc:
        logger.debug("Could not read Codex auth for auxiliary client: %s", exc)
        return None
result is None + def test_expired_jwt_returns_none(self, tmp_path, monkeypatch): + """Expired JWT tokens should be skipped so auto chain continues.""" + import base64 + import time as _time + + # Build a JWT with exp in the past + header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode() + payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode() + payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode() + expired_jwt = f"{header}.{payload}.fakesig" + + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + (hermes_home / "auth.json").write_text(json.dumps({ + "version": 1, + "providers": { + "openai-codex": { + "tokens": {"access_token": expired_jwt, "refresh_token": "r"}, + }, + }, + })) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + result = _read_codex_access_token() + assert result is None, "Expired JWT should return None" + + def test_valid_jwt_returns_token(self, tmp_path, monkeypatch): + """Non-expired JWT tokens should be returned.""" + import base64 + import time as _time + + header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode() + payload_data = json.dumps({"exp": int(_time.time()) + 3600}).encode() + payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode() + valid_jwt = f"{header}.{payload}.fakesig" + + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + (hermes_home / "auth.json").write_text(json.dumps({ + "version": 1, + "providers": { + "openai-codex": { + "tokens": {"access_token": valid_jwt, "refresh_token": "r"}, + }, + }, + })) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + result = _read_codex_access_token() + assert result == valid_jwt + + def test_non_jwt_token_passes_through(self, tmp_path, monkeypatch): + """Non-JWT tokens (no dots) should be returned as-is.""" + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + (hermes_home / 
class TestAnthropicOAuthFlag:
    """Check the ``_is_oauth`` flag on the adapter built by ``_try_anthropic``."""

    def test_oauth_token_sets_flag(self, monkeypatch):
        """An ``sk-ant-oat01-*`` token in ANTHROPIC_TOKEN gives ``_is_oauth is True``."""
        monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-test-token")
        build_target = "agent.anthropic_adapter.build_anthropic_client"
        with patch(build_target, return_value=MagicMock()):
            from agent.auxiliary_client import AnthropicAuxiliaryClient, _try_anthropic

            resolved, _model = _try_anthropic()

        assert resolved is not None
        assert isinstance(resolved, AnthropicAuxiliaryClient)
        # The flag is read from the object exposed as chat.completions.
        completions = resolved.chat.completions
        assert completions._is_oauth is True

    def test_api_key_no_oauth_flag(self, monkeypatch):
        """An ``sk-ant-api*`` key returned by the resolver gives ``_is_oauth is False``."""
        token_target = "agent.anthropic_adapter.resolve_anthropic_token"
        build_target = "agent.anthropic_adapter.build_anthropic_client"
        with patch(token_target, return_value="sk-ant-api03-testkey1234"):
            with patch(build_target, return_value=MagicMock()):
                from agent.auxiliary_client import AnthropicAuxiliaryClient, _try_anthropic

                resolved, _model = _try_anthropic()

        assert resolved is not None
        assert isinstance(resolved, AnthropicAuxiliaryClient)
        completions = resolved.chat.completions
        assert completions._is_oauth is False
class TestExpiredCodexFallback:
    """Test that expired Codex tokens don't block the auto chain.

    NOTE(review): the last four tests exercise the Anthropic OAuth flag and
    ``_read_codex_access_token`` directly rather than the fallback chain;
    they look like they belong in ``TestAnthropicOAuthFlag`` /
    ``TestReadCodexAccessToken``.  They are kept here so test IDs stay stable.
    """

    # ------------------------------------------------------------------
    # Shared fixtures (underscore-prefixed, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _fake_jwt(payload: bytes, header: bytes = b'{"alg":"RS256","typ":"JWT"}') -> str:
        """Build an unsigned ``header.payload.fakesig`` token from raw bytes."""
        import base64

        def encode(raw: bytes) -> str:
            # JWT segments are base64url without "=" padding.
            return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

        return f"{encode(header)}.{encode(payload)}.fakesig"

    @classmethod
    def _expired_jwt(cls) -> str:
        """Return a fake JWT whose ``exp`` claim is one hour in the past."""
        import time

        claims = {"exp": int(time.time()) - 3600}
        return cls._fake_jwt(json.dumps(claims).encode())

    @staticmethod
    def _store_codex_token(tmp_path, monkeypatch, access_token: str) -> None:
        """Write *access_token* to ``<tmp_path>/hermes/auth.json`` and set HERMES_HOME to that directory."""
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir(parents=True, exist_ok=True)
        (hermes_home / "auth.json").write_text(json.dumps({
            "version": 1,
            "providers": {
                "openai-codex": {
                    "tokens": {"access_token": access_token, "refresh_token": "r"},
                },
            },
        }))
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    # ------------------------------------------------------------------
    # Auto-chain behaviour with an expired Codex token
    # ------------------------------------------------------------------

    def test_expired_codex_falls_through_to_next(self, tmp_path, monkeypatch):
        """When Codex token is expired, auto chain should skip it and try next provider."""
        self._store_codex_token(tmp_path, monkeypatch, self._expired_jwt())

        # Set up Anthropic as fallback
        monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-test-fallback")
        with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
            mock_build.return_value = MagicMock()
            from agent.auxiliary_client import _resolve_auto

            client, _model = _resolve_auto()
            # NOTE(review): this only proves *some* provider was resolved; it
            # does not pin which one, because other providers may be
            # configured in the environment.
            assert client is not None, "Should find a provider after expired Codex"

    def test_expired_codex_openrouter_wins(self, tmp_path, monkeypatch):
        """With expired Codex + OpenRouter key, OpenRouter should win (1st in chain)."""
        self._store_codex_token(tmp_path, monkeypatch, self._expired_jwt())
        monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")

        with patch("agent.auxiliary_client.OpenAI") as mock_openai:
            mock_openai.return_value = MagicMock()
            from agent.auxiliary_client import _resolve_auto

            client, _model = _resolve_auto()
            assert client is not None
            # OpenRouter is 1st in chain, should win
            mock_openai.assert_called()

    def test_expired_codex_custom_endpoint_wins(self, tmp_path, monkeypatch):
        """With expired Codex + custom endpoint (Ollama), custom should win (3rd in chain)."""
        self._store_codex_token(tmp_path, monkeypatch, self._expired_jwt())

        # Simulate Ollama or custom endpoint
        with patch("agent.auxiliary_client._resolve_custom_runtime",
                   return_value=("http://localhost:11434/v1", "sk-dummy")):
            with patch("agent.auxiliary_client.OpenAI") as mock_openai:
                mock_openai.return_value = MagicMock()
                from agent.auxiliary_client import _resolve_auto

                client, _model = _resolve_auto()
                assert client is not None

    # ------------------------------------------------------------------
    # Anthropic OAuth flag
    # ------------------------------------------------------------------

    def test_hermes_oauth_file_sets_oauth_flag(self, monkeypatch):
        """Hermes OAuth credentials should get is_oauth=True (token is not sk-ant-api-*)."""
        # Mock resolve_anthropic_token to return an OAuth-style token
        # (simulates what read_hermes_oauth_credentials would return)
        with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="hermes-oauth-jwt-token"), \
             patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
            mock_build.return_value = MagicMock()
            from agent.auxiliary_client import _try_anthropic

            client, _model = _try_anthropic()
            assert client is not None, "Should resolve token"
            adapter = client.chat.completions
            assert adapter._is_oauth is True, "Non-sk-ant-api token should set is_oauth=True"

    def test_claude_code_oauth_env_sets_flag(self, monkeypatch):
        """CLAUDE_CODE_OAUTH_TOKEN env var should get is_oauth=True."""
        monkeypatch.setenv("CLAUDE_CODE_OAUTH_TOKEN", "cc-oauth-token-test")
        monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
        with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
            mock_build.return_value = MagicMock()
            from agent.auxiliary_client import _try_anthropic

            client, _model = _try_anthropic()
            assert client is not None
            adapter = client.chat.completions
            assert adapter._is_oauth is True

    # ------------------------------------------------------------------
    # Tokens whose expiry cannot be determined are passed through
    # ------------------------------------------------------------------

    def test_jwt_missing_exp_passes_through(self, tmp_path, monkeypatch):
        """JWT with valid JSON but no exp claim should pass through."""
        no_exp_jwt = self._fake_jwt(json.dumps({"sub": "user123"}).encode())  # no exp
        self._store_codex_token(tmp_path, monkeypatch, no_exp_jwt)

        result = _read_codex_access_token()
        assert result == no_exp_jwt, "JWT without exp should pass through"

    def test_jwt_invalid_json_payload_passes_through(self, tmp_path, monkeypatch):
        """JWT with valid base64 but invalid JSON payload should pass through."""
        bad_jwt = self._fake_jwt(b"not-json-content", header=b'{"alg":"RS256"}')
        self._store_codex_token(tmp_path, monkeypatch, bad_jwt)

        result = _read_codex_access_token()
        assert result == bad_jwt, "JWT with invalid JSON payload should pass through"
class TestExplicitProviderRouting:
    """Test explicit provider selection bypasses auto chain correctly."""

    @staticmethod
    def _client_for_env_key(monkeypatch, env_var, key, provider):
        """Set *env_var* to *key*, resolve *provider* with ``OpenAI`` patched, and return the client."""
        monkeypatch.setenv(env_var, key)
        with patch("agent.auxiliary_client.OpenAI", return_value=MagicMock()):
            resolved, _model = resolve_provider_client(provider)
        return resolved

    def test_explicit_anthropic_oauth(self, monkeypatch):
        """provider='anthropic' + OAuth token should work with is_oauth=True."""
        monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-explicit-test")
        with patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()):
            resolved, _model = resolve_provider_client("anthropic")

        assert resolved is not None
        # Verify OAuth flag propagated
        assert resolved.chat.completions._is_oauth is True

    def test_explicit_anthropic_api_key(self, monkeypatch):
        """provider='anthropic' + regular API key should work with is_oauth=False."""
        token_target = "agent.anthropic_adapter.resolve_anthropic_token"
        build_target = "agent.anthropic_adapter.build_anthropic_client"
        with patch(token_target, return_value="sk-ant-api-regular-key"):
            with patch(build_target, return_value=MagicMock()):
                resolved, _model = resolve_provider_client("anthropic")

        assert resolved is not None
        assert resolved.chat.completions._is_oauth is False

    def test_explicit_openrouter(self, monkeypatch):
        """provider='openrouter' should use OPENROUTER_API_KEY."""
        resolved = self._client_for_env_key(
            monkeypatch, "OPENROUTER_API_KEY", "or-explicit", "openrouter"
        )
        assert resolved is not None

    def test_explicit_kimi(self, monkeypatch):
        """provider='kimi-coding' should use KIMI_API_KEY."""
        resolved = self._client_for_env_key(
            monkeypatch, "KIMI_API_KEY", "kimi-test-key", "kimi-coding"
        )
        assert resolved is not None

    def test_explicit_minimax(self, monkeypatch):
        """provider='minimax' should use MINIMAX_API_KEY."""
        resolved = self._client_for_env_key(
            monkeypatch, "MINIMAX_API_KEY", "mm-test-key", "minimax"
        )
        assert resolved is not None

    def test_explicit_deepseek(self, monkeypatch):
        """provider='deepseek' should use DEEPSEEK_API_KEY."""
        resolved = self._client_for_env_key(
            monkeypatch, "DEEPSEEK_API_KEY", "ds-test-key", "deepseek"
        )
        assert resolved is not None

    def test_explicit_zai(self, monkeypatch):
        """provider='zai' should use GLM_API_KEY."""
        resolved = self._client_for_env_key(
            monkeypatch, "GLM_API_KEY", "zai-test-key", "zai"
        )
        assert resolved is not None

    def test_explicit_unknown_returns_none(self, monkeypatch):
        """Unknown provider should return None."""
        resolved, _model = resolve_provider_client("nonexistent-provider")
        assert resolved is None