diff --git a/run_agent.py b/run_agent.py index 6c0262ccf2..5ed5d72457 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3657,6 +3657,7 @@ class AIAgent: "id": tc_delta.id or "", "type": "function", "function": {"name": "", "arguments": ""}, + "extra_content": None, } entry = tool_calls_acc[idx] if tc_delta.id: @@ -3666,6 +3667,13 @@ class AIAgent: entry["function"]["name"] += tc_delta.function.name if tc_delta.function.arguments: entry["function"]["arguments"] += tc_delta.function.arguments + extra = getattr(tc_delta, "extra_content", None) + if extra is None and hasattr(tc_delta, "model_extra"): + extra = (tc_delta.model_extra or {}).get("extra_content") + if extra is not None: + if hasattr(extra, "model_dump"): + extra = extra.model_dump() + entry["extra_content"] = extra # Fire once per tool when the full name is available name = entry["function"]["name"] if name and idx not in tool_gen_notified: @@ -3690,6 +3698,7 @@ class AIAgent: mock_tool_calls.append(SimpleNamespace( id=tc["id"], type=tc["type"], + extra_content=tc.get("extra_content"), function=SimpleNamespace( name=tc["function"]["name"], arguments=tc["function"]["arguments"], diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6cc34d972c..9d3ed6f320 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -39,10 +39,15 @@ def _make_stream_chunk( return chunk -def _make_tool_call_delta(index=0, tc_id=None, name=None, arguments=None): +def _make_tool_call_delta(index=0, tc_id=None, name=None, arguments=None, extra_content=None, model_extra=None): """Build a mock tool call delta.""" func = SimpleNamespace(name=name, arguments=arguments) - return SimpleNamespace(index=index, id=tc_id, function=func) + delta = SimpleNamespace(index=index, id=tc_id, function=func) + if extra_content is not None: + delta.extra_content = extra_content + if model_extra is not None: + delta.model_extra = model_extra + return delta def _make_empty_chunk(model=None, usage=None): @@ -132,6 +137,52 @@ class TestStreamingAccumulator: assert tc[0].function.name == "terminal" assert tc[0].function.arguments == '{"command": "ls"}' + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_tool_call_extra_content_preserved(self, mock_close, mock_create): + """Streamed tool calls preserve provider-specific extra_content metadata.""" + from run_agent import AIAgent + + chunks = [ + _make_stream_chunk(tool_calls=[ + _make_tool_call_delta( + index=0, + tc_id="call_gemini", + name="cronjob", + model_extra={ + "extra_content": { + "google": {"thought_signature": "sig-123"} + } + }, + ) + ]), + _make_stream_chunk(tool_calls=[ + _make_tool_call_delta(index=0, arguments='{"task": "deep index on ."}') + ]), + _make_stream_chunk(finish_reason="tool_calls"), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter(chunks) + mock_create.return_value = mock_client + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + tc = response.choices[0].message.tool_calls + assert tc is not None + assert tc[0].extra_content == { + "google": {"thought_signature": "sig-123"} + } + @patch("run_agent.AIAgent._create_request_openai_client") @patch("run_agent.AIAgent._close_request_openai_client") def test_mixed_content_and_tool_calls(self, mock_close, mock_create):