diff --git a/agent/model_metadata.py b/agent/model_metadata.py
index 84a3448448..2b65766dc9 100644
--- a/agent/model_metadata.py
+++ b/agent/model_metadata.py
@@ -549,10 +549,34 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
     return None
 
 
+def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
+    """Return True if *candidate_id* (from the server) matches *lookup_model* (configured).
+
+    Supports two forms:
+    - Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
+    - Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
+      (the part after the last "/" equals lookup_model)
+
+    This covers LM Studio's native API, which stores models as "publisher/slug"
+    while users typically configure only the slug after the "local:" prefix.
+    """
+    if candidate_id == lookup_model:
+        return True
+    # Slug match: the basename of the candidate equals the lookup name
+    if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
+        return True
+    return False
+
+
 def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
     """Query a local server for the model's context length."""
     import httpx
 
+    # Strip the "local:" provider prefix (e.g., "local:model-name" → "model-name").
+    # Only this known prefix is stripped: Ollama IDs use ":" as a tag separator.
+    if model.startswith("local:"):
+        model = model[len("local:"):]
+
     # Strip /v1 suffix to get the server root
     server_url = base_url.rstrip("/")
     if server_url.endswith("/v1"):
@@ -587,6 +611,28 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
         except ValueError:
             pass
 
+    # LM Studio native API: /api/v1/models returns max_context_length.
+    # This is more reliable than the OpenAI-compatible /v1/models, which
+    # doesn't include context window information for LM Studio servers.
+    # Use _model_id_matches for fuzzy matching: LM Studio stores models as
+    # "publisher/slug" but users configure only the "slug" after the "local:" prefix.
+    if server_type == "lm-studio":
+        resp = client.get(f"{server_url}/api/v1/models")
+        if resp.status_code == 200:
+            data = resp.json()
+            for m in data.get("models", []):
+                if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
+                    # Prefer the loaded instance context (actual runtime value)
+                    for inst in m.get("loaded_instances", []):
+                        cfg = inst.get("config", {})
+                        ctx = cfg.get("context_length")
+                        if ctx and isinstance(ctx, (int, float)):
+                            return int(ctx)
+                    # Fall back to max_context_length (theoretical model max)
+                    ctx = m.get("max_context_length") or m.get("context_length")
+                    if ctx and isinstance(ctx, (int, float)):
+                        return int(ctx)
+
     # LM Studio / vLLM / llama.cpp: try /v1/models/{model}
     resp = client.get(f"{server_url}/v1/models/{model}")
     if resp.status_code == 200:
@@ -596,13 +642,14 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
         if ctx and isinstance(ctx, (int, float)):
             return int(ctx)
 
-    # Try /v1/models and find the model in the list
+    # Try /v1/models and find the model in the list.
+    # Use _model_id_matches to handle "publisher/slug" vs. bare "slug".
     resp = client.get(f"{server_url}/v1/models")
     if resp.status_code == 200:
         data = resp.json()
         models_list = data.get("data", [])
         for m in models_list:
-            if m.get("id") == model:
+            if _model_id_matches(m.get("id", ""), model):
                 ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
                 if ctx and isinstance(ctx, (int, float)):
                     return int(ctx)
@@ -633,6 +680,12 @@ def get_model_context_length(
     if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
         return config_context_length
 
+    # Normalise "local:"-prefixed model names ("local:model-name" →
+    # "model-name") so cache lookups and server queries use the bare ID that
+    # local servers actually know about. Only this known prefix is stripped.
+    if model.startswith("local:"):
+        model = model[len("local:"):]
+
     # 1. Check persistent cache (model+provider)
     if base_url:
         cached = get_cached_context_length(model, base_url)
diff --git a/tests/test_model_metadata_local_ctx.py b/tests/test_model_metadata_local_ctx.py
index 513edaff75..e5ad0dc58c 100644
--- a/tests/test_model_metadata_local_ctx.py
+++ b/tests/test_model_metadata_local_ctx.py
@@ -206,6 +206,186 @@ class TestQueryLocalContextLengthModelsList:
         assert result is None
 
 
+class TestQueryLocalContextLengthLmStudio:
+    """_query_local_context_length with LM Studio's native /api/v1/models response."""
+
+    def _make_resp(self, status_code, body):
+        resp = MagicMock()
+        resp.status_code = status_code
+        resp.json.return_value = body
+        return resp
+
+    def _make_client(self, native_resp, detail_resp, list_resp):
+        """Build a mock httpx.Client with sequenced GET responses."""
+        client_mock = MagicMock()
+        client_mock.__enter__ = lambda s: client_mock
+        client_mock.__exit__ = MagicMock(return_value=False)
+        client_mock.post.return_value = self._make_resp(404, {})
+
+        responses = [native_resp, detail_resp, list_resp]
+        call_idx = [0]
+
+        def get_side_effect(url, **kwargs):
+            idx = call_idx[0]
+            call_idx[0] += 1
+            if idx < len(responses):
+                return responses[idx]
+            return self._make_resp(404, {})
+
+        client_mock.get.side_effect = get_side_effect
+        return client_mock
+
+    def test_lmstudio_exact_key_match(self):
+        """Reads max_context_length when the key matches exactly."""
+        from agent.model_metadata import _query_local_context_length
+
+        native_resp = self._make_resp(200, {
+            "models": [
+                {"key": "nvidia/nvidia-nemotron-super-49b-v1", "id": "nvidia/nvidia-nemotron-super-49b-v1",
+                 "max_context_length": 131072},
+            ]
+        })
+        client_mock = self._make_client(
+            native_resp,
+            self._make_resp(404, {}),
+            self._make_resp(404, {}),
+        )
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("httpx.Client", return_value=client_mock):
+            result = _query_local_context_length(
+                "nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
+            )
+
+        assert result == 131072
+
+    def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self):
+        """Fuzzy match: a bare model slug matches a key that includes the publisher prefix.
+
+        When the user configures the model as "local:nvidia-nemotron-super-49b-v1"
+        (slug only, no publisher), but LM Studio's native API stores it as
+        "nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed.
+        """
+        from agent.model_metadata import _query_local_context_length
+
+        native_resp = self._make_resp(200, {
+            "models": [
+                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
+                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
+                 "max_context_length": 131072},
+            ]
+        })
+        client_mock = self._make_client(
+            native_resp,
+            self._make_resp(404, {}),
+            self._make_resp(404, {}),
+        )
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("httpx.Client", return_value=client_mock):
+            # The model passed in is just the slug after stripping the "local:" prefix
+            result = _query_local_context_length(
+                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
+            )
+
+        assert result == 131072
+
+    def test_lmstudio_v1_models_list_slug_fuzzy_match(self):
+        """Fuzzy match also works for the /v1/models list when the exact match fails.
+
+        LM Studio's OpenAI-compat /v1/models returns an id like
+        "nvidia/nvidia-nemotron-super-49b-v1", which must match the bare slug.
+        """
+        from agent.model_metadata import _query_local_context_length
+
+        # native /api/v1/models: no match
+        native_resp = self._make_resp(404, {})
+        # /v1/models/{model}: no match
+        detail_resp = self._make_resp(404, {})
+        # /v1/models list: model found with publisher prefix, includes context_length
+        list_resp = self._make_resp(200, {
+            "data": [
+                {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072},
+            ]
+        })
+        client_mock = self._make_client(native_resp, detail_resp, list_resp)
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("httpx.Client", return_value=client_mock):
+            result = _query_local_context_length(
+                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
+            )
+
+        assert result == 131072
+
+    def test_lmstudio_loaded_instances_context_length(self):
+        """Reads the active context_length from loaded_instances when max_context_length is absent."""
+        from agent.model_metadata import _query_local_context_length
+
+        native_resp = self._make_resp(200, {
+            "models": [
+                {
+                    "key": "nvidia/nvidia-nemotron-super-49b-v1",
+                    "id": "nvidia/nvidia-nemotron-super-49b-v1",
+                    "loaded_instances": [
+                        {"config": {"context_length": 65536}},
+                    ],
+                },
+            ]
+        })
+        client_mock = self._make_client(
+            native_resp,
+            self._make_resp(404, {}),
+            self._make_resp(404, {}),
+        )
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("httpx.Client", return_value=client_mock):
+            result = _query_local_context_length(
+                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
+            )
+
+        assert result == 65536
+
+    def test_lmstudio_loaded_instance_beats_max_context_length(self):
+        """loaded_instances context_length takes priority over max_context_length.
+
+        LM Studio may report max_context_length=1_048_576 (the theoretical model max)
+        while the actual loaded context is 122_651 (the runtime setting). The loaded
+        value is the real constraint and must be preferred.
+        """
+        from agent.model_metadata import _query_local_context_length
+
+        native_resp = self._make_resp(200, {
+            "models": [
+                {
+                    "key": "nvidia/nvidia-nemotron-3-nano-4b",
+                    "id": "nvidia/nvidia-nemotron-3-nano-4b",
+                    "max_context_length": 1_048_576,
+                    "loaded_instances": [
+                        {"config": {"context_length": 122_651}},
+                    ],
+                },
+            ]
+        })
+        client_mock = self._make_client(
+            native_resp,
+            self._make_resp(404, {}),
+            self._make_resp(404, {}),
+        )
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("httpx.Client", return_value=client_mock):
+            result = _query_local_context_length(
+                "nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1"
+            )
+
+        assert result == 122_651, (
+            f"Expected loaded instance context (122651) but got {result}. "
+            "max_context_length (1048576) must not win over loaded_instances."
+        )
+
+
 class TestQueryLocalContextLengthNetworkError:
     """_query_local_context_length handles network failures gracefully."""
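
To sanity-check the matching rules outside the test harness, here is a minimal standalone sketch. `model_id_matches` mirrors the `_model_id_matches` helper added in the diff (it is an illustrative copy, not an import), and the model names are taken from the test fixtures above.

    def model_id_matches(candidate_id: str, lookup_model: str) -> bool:
        if candidate_id == lookup_model:
            return True  # exact match
        # Slug match: the basename of "publisher/slug" equals the configured name.
        return "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model

    # A slug configured without the publisher matches the server-side ID:
    assert model_id_matches("nvidia/nvidia-nemotron-super-49b-v1",
                            "nvidia-nemotron-super-49b-v1")
    # Identical IDs match trivially:
    assert model_id_matches("nvidia-nemotron-super-49b-v1",
                            "nvidia-nemotron-super-49b-v1")
    # A different slug does not match:
    assert not model_id_matches("nvidia/other-model", "nvidia-nemotron-super-49b-v1")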
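Likewise, a sketch of the context-length priority implemented by the native-API branch. `pick_context_length` is a hypothetical helper named for illustration, and the payload shape is assumed from the mocked fixtures above, not captured from a live LM Studio server.

    from typing import Optional

    def pick_context_length(entry: dict) -> Optional[int]:
        # A loaded instance's context_length (the actual runtime setting) wins
        # over max_context_length (the theoretical model maximum).
        for inst in entry.get("loaded_instances", []):
            ctx = inst.get("config", {}).get("context_length")
            if ctx and isinstance(ctx, (int, float)):
                return int(ctx)
        ctx = entry.get("max_context_length") or entry.get("context_length")
        if ctx and isinstance(ctx, (int, float)):
            return int(ctx)
        return None

    entry = {
        "key": "nvidia/nvidia-nemotron-3-nano-4b",
        "max_context_length": 1_048_576,  # theoretical max
        "loaded_instances": [{"config": {"context_length": 122_651}}],  # runtime value
    }
    assert pick_context_length(entry) == 122_651
    assert pick_context_length({"max_context_length": 131_072}) == 131_072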