diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index 2878702c94..e445c74841 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -3,7 +3,7 @@ import inspect import sys import types -from collections.abc import Awaitable +from collections.abc import Awaitable, Iterable from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin @@ -113,14 +113,24 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat return cls(tools=tools, allow_text_result=allow_text_result) + def find_named_tool( + self, parts: Iterable[_messages.ModelResponsePart], tool_name: str + ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None: + """Find a tool that matches one of the calls, with a specific name.""" + for part in parts: + if isinstance(part, _messages.ToolCallPart): + if part.tool_name == tool_name: + return part, self.tools[tool_name] + def find_tool( - self, message: _messages.ModelResponse + self, + parts: Iterable[_messages.ModelResponsePart], ) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None: """Find a tool that matches one of the calls.""" - for item in message.parts: - if isinstance(item, _messages.ToolCallPart): - if result := self.tools.get(item.tool_name): - return item, result + for part in parts: + if isinstance(part, _messages.ToolCallPart): + if result := self.tools.get(part.tool_name): + return part, result def tool_names(self) -> list[str]: """Return the names of the tools.""" diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9074aed87b..6e437c79d6 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Generic, Literal, cast, final, overload import logfire_api +from typing_extensions import assert_never from . import ( _result, @@ -272,8 +273,8 @@ async def run( else: # continue the conversation handle_span.set_attribute('tool_responses', tool_responses) - tool_responses = ' '.join(r.part_kind for r in tool_responses) - handle_span.message = f'handle model response -> {tool_responses}' + tool_responses_str = ' '.join(r.part_kind for r in tool_responses) + handle_span.message = f'handle model response -> {tool_responses_str}' def run_sync( self, @@ -287,7 +288,8 @@ def run_sync( ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. - This is a convenience method that wraps `self.run` with `loop.run_until_complete()`. + This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. + You therefore can't use this method inside async code or if there's an active event loop. Example: ```python @@ -314,8 +316,7 @@ async def main(): """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - loop = asyncio.get_event_loop() - return loop.run_until_complete( + return asyncio.get_event_loop().run_until_complete( self.run( user_prompt, message_history=message_history, @@ -402,19 +403,34 @@ async def main(): model_req_span.__exit__(None, None, None) with _logfire.span('handle model response') as handle_span: - final_result, new_messages = await self._handle_streamed_model_response( + maybe_final_result = await self._handle_streamed_model_response( model_response, deps, messages ) - # Add all messages to the conversation - messages.extend(new_messages) - # Check if we got a final result - if final_result is not None: - result_stream = final_result.data - run_span.set_attribute('all_messages', messages) - handle_span.set_attribute('result_type', result_stream.__class__.__name__) + if isinstance(maybe_final_result, _MarkFinalResult): + result_stream = maybe_final_result.data + result_tool_name = maybe_final_result.tool_name handle_span.message = 'handle model response -> final result' + + async def on_complete(): + """Called when the stream has completed. + + The model response will have been added to messages by now + by `StreamedRunResult._marked_completed`. + """ + last_message = messages[-1] + assert isinstance(last_message, _messages.ModelResponse) + tool_calls = [ + part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) + ] + parts = await self._process_function_tools( + tool_calls, result_tool_name, deps, messages + ) + if parts: + messages.append(_messages.ModelRequest(parts)) + run_span.set_attribute('all_messages', messages) + yield result.StreamedRunResult( messages, new_message_index, @@ -423,17 +439,22 @@ async def main(): self._result_schema, deps, self._result_validators, - lambda m: run_span.set_attribute('all_messages', messages), + result_tool_name, + on_complete, ) return else: # continue the conversation - # the last new message should always be a model request - tool_response_msg = new_messages[-1] - assert isinstance(tool_response_msg, _messages.ModelRequest) - handle_span.set_attribute('tool_responses', tool_response_msg.parts) - tool_responses = ' '.join(r.part_kind for r in tool_response_msg.parts) - handle_span.message = f'handle model response -> {tool_responses}' + model_response_msg, tool_responses = maybe_final_result + # if we got a model response add that to messages + messages.append(model_response_msg) + if tool_responses: + # if we got one or more tool response parts, add a model request message + messages.append(_messages.ModelRequest(tool_responses)) + + handle_span.set_attribute('tool_responses', tool_responses) + tool_responses_str = ' '.join(r.part_kind for r in tool_responses) + handle_span.message = f'handle model response -> {tool_responses_str}' # the model_response should have been fully streamed by now, we can add it's cost cost += model_response.cost() @@ -846,7 +867,7 @@ async def _handle_text_response( self._incr_result_retry() return None, [e.tool_retry] else: - return _MarkFinalResult(result_data), [] + return _MarkFinalResult(result_data, None), [] else: self._incr_result_retry() response = _messages.RetryPromptPart( @@ -860,34 +881,13 @@ async def _handle_structured_response( """Handle a structured response containing tool calls from the model for non-streaming responses.""" assert tool_calls, 'Expected at least one tool call' - # First process any final result tool calls - final_result, final_messages = await self._process_final_tool_calls(tool_calls, deps, conv_messages) - - # Then process regular tools based on end strategy - if self.end_strategy == 'early' and final_result: - tool_parts = self._mark_skipped_function_tools(tool_calls) - else: - tool_parts = await self._process_function_tools(tool_calls, deps, conv_messages) - - return final_result, [*final_messages, *tool_parts] - - async def _process_final_tool_calls( - self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage] - ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]: - """Process any final result tool calls and return the first valid result.""" - if not self._result_schema: - return None, [] + # first look for the result tool call + final_result: _MarkFinalResult[ResultData] | None = None parts: list[_messages.ModelRequestPart] = [] - final_result = None - - for call in tool_calls: - result_tool = self._result_schema.tools.get(call.tool_name) - if not result_tool: - continue - - if final_result is None: - # This is the first result tool - try to use it + if result_schema := self._result_schema: + if match := result_schema.find_tool(tool_calls): + call, result_tool = match try: result_data = result_tool.validate(call) result_data = await self._validate_result(result_data, deps, call, conv_messages) @@ -895,66 +895,73 @@ async def _process_final_tool_calls( self._incr_result_retry() parts.append(e.tool_retry) else: - final_result = _MarkFinalResult(result_data) - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Final result processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - # We already have a final result - mark this one as unused - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Result tool not used - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) + final_result = _MarkFinalResult(result_data, call.tool_name) + + # Then build the other request parts based on end strategy + parts += await self._process_function_tools( + tool_calls, final_result and final_result.tool_name, deps, conv_messages + ) return final_result, parts async def _process_function_tools( - self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage] - ) -> list[_messages.ModelRequestPart]: - """Process function (non-final) tool calls in parallel.""" - parts: list[_messages.ModelRequestPart] = [] - tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] - - for call in tool_calls: - if tool := self._function_tools.get(call.tool_name): - tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name)) - elif self._result_schema is None or call.tool_name not in self._result_schema.tools: - parts.append(self._unknown_tool(call.tool_name)) - - # Run all tool tasks in parallel - if tasks: - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) - parts.extend(task_results) - - return parts - - def _mark_skipped_function_tools( self, tool_calls: list[_messages.ToolCallPart], + result_tool_name: str | None, + deps: AgentDeps, + conv_messages: list[_messages.ModelMessage], ) -> list[_messages.ModelRequestPart]: - """Mark function tools as skipped when a final result was found with 'early' end strategy.""" + """Process function (non-result) tool calls in parallel. + + Also add stub return parts for any other tools that need it. + """ parts: list[_messages.ModelRequestPart] = [] + tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] + stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early' + + # we rely on the fact that if we found a result, it's the first result tool in the last + found_used_result_tool = False for call in tool_calls: - if call.tool_name in self._function_tools: + if call.tool_name == result_tool_name and not found_used_result_tool: + found_used_result_tool = True parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', + content='Final result processed.', tool_call_id=call.tool_call_id, ) ) - elif self._result_schema is None or call.tool_name not in self._result_schema.tools: + elif tool := self._function_tools.get(call.tool_name): + if stub_function_tools: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name)) + elif self._result_schema is not None and call.tool_name in self._result_schema.tools: + # if tool_name is in _result_schema, it means we found a result tool but an error occurred in + # validation, we don't add another part here + if result_tool_name is not None: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Result tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: parts.append(self._unknown_tool(call.tool_name)) + # Run all tool tasks in parallel + if tasks: + with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): + task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) + parts.extend(task_results) return parts async def _handle_streamed_model_response( @@ -962,16 +969,20 @@ async def _handle_streamed_model_response( model_response: models.EitherStreamedResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage], - ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.ModelMessage]]: + ) -> ( + _MarkFinalResult[models.EitherStreamedResponse] + | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]] + ): """Process a streamed response from the model. Returns: - A tuple of (final_result, messages). If final_result is not None, the conversation should end. + Either a final result or a tuple of the model response and the tool responses for the next request. + If a final result is returned, the conversation should end. """ if isinstance(model_response, models.StreamTextResponse): # plain string response if self._allow_text_result: - return _MarkFinalResult(model_response), [] + return _MarkFinalResult(model_response, None) else: self._incr_result_retry() response = _messages.RetryPromptPart( @@ -981,9 +992,9 @@ async def _handle_streamed_model_response( async for _ in model_response: pass - return None, [_messages.ModelRequest([response])] - else: - assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}' + text = ''.join(model_response.get(final=True)) + return _messages.ModelResponse([_messages.TextPart(text)]), [response] + elif isinstance(model_response, models.StreamStructuredResponse): if self._result_schema is not None: # if there's a result schema, iterate over the stream until we find at least one tool # NOTE: this means we ignore any other tools called here @@ -995,14 +1006,9 @@ async def _handle_streamed_model_response( break structured_msg = model_response.get() - if match := self._result_schema.find_tool(structured_msg): + if match := self._result_schema.find_tool(structured_msg.parts): call, _ = match - tool_return = _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Final result processed.', - tool_call_id=call.tool_call_id, - ) - return _MarkFinalResult(model_response), [_messages.ModelRequest([tool_return])] + return _MarkFinalResult(model_response, call.tool_name) # the model is calling a tool function, consume the response to get the next message async for _ in model_response: @@ -1010,7 +1016,6 @@ async def _handle_streamed_model_response( model_response_msg = model_response.get() if not model_response_msg.parts: raise exceptions.UnexpectedModelBehavior('Received empty tool call message') - messages: list[_messages.ModelMessage] = [model_response_msg] # we now run all tool functions in parallel tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] @@ -1026,8 +1031,9 @@ async def _handle_streamed_model_response( with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) parts.extend(task_results) - messages.append(_messages.ModelRequest(parts)) - return None, messages + return model_response_msg, parts + else: + assert_never(model_response) async def _validate_result( self, @@ -1110,3 +1116,6 @@ class _MarkFinalResult(Generic[ResultData]): """ data: ResultData + """The final result data.""" + tool_name: str | None + """Name of the final result tool, None if the result is a string.""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 2cb6be4276..281b8f5fcb 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime from typing import Generic, TypeVar, cast @@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat _result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] - _on_complete: Callable[[list[_messages.ModelMessage]], None] + _result_tool_name: str | None + _on_complete: Callable[[], Awaitable[None]] is_complete: bool = field(default=False, init=False) """Whether the stream has all been received. @@ -205,7 +206,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = combined = await self._validate_text_result(''.join(chunks)) yield combined lf_span.set_attribute('combined_text', combined) - self._marked_completed(_messages.ModelResponse.from_text(combined)) + await self._marked_completed(_messages.ModelResponse.from_text(combined)) async def stream_structured( self, *, debounce_by: float | None = 0.1 @@ -244,7 +245,7 @@ async def stream_structured( msg = self._stream_response.get(final=True) yield msg, True lf_span.set_attribute('structured_response', msg) - self._marked_completed(msg) + await self._marked_completed(msg) async def get_data(self) -> ResultData: """Stream the whole response, validate and return it.""" @@ -253,11 +254,11 @@ async def get_data(self) -> ResultData: if isinstance(self._stream_response, models.StreamTextResponse): text = ''.join(self._stream_response.get(final=True)) text = await self._validate_text_result(text) - self._marked_completed(_messages.ModelResponse.from_text(text)) + await self._marked_completed(_messages.ModelResponse.from_text(text)) return cast(ResultData, text) else: message = self._stream_response.get(final=True) - self._marked_completed(message) + await self._marked_completed(message) return await self.validate_structured_result(message) @property @@ -282,7 +283,8 @@ async def validate_structured_result( ) -> ResultData: """Validate a structured result message.""" assert self._result_schema is not None, 'Expected _result_schema to not be None' - match = self._result_schema.find_tool(message) + assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None' + match = self._result_schema.find_named_tool(message.parts, self._result_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( f'Invalid message, unable to find tool: {self._result_schema.tool_names()}' @@ -306,7 +308,7 @@ async def _validate_text_result(self, text: str) -> str: ) return text - def _marked_completed(self, message: _messages.ModelResponse) -> None: + async def _marked_completed(self, message: _messages.ModelResponse) -> None: self.is_complete = True self._all_messages.append(message) - self._on_complete(self._all_messages) + await self._on_complete() diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 6820620af1..d08d28ca10 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -669,13 +669,6 @@ async def bar(y: str) -> str: ToolReturnPart(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)), ] ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) - ) - ] - ), ModelResponse( parts=[ ToolCallPart( @@ -685,6 +678,13 @@ async def bar(y: str) -> str: ], timestamp=IsNow(tz=timezone.utc), ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) + ) + ] + ), ] ) assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"]) diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 52004134a6..e374c9f549 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -415,19 +415,21 @@ async def test_stream_structured(allow_model_requests: None): assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) - ) - ] - ), ModelResponse( parts=[ ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"first": "One", "second": "Two"}')) ], timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), ] ) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index c226ac1c1b..86532dee80 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1327,6 +1327,12 @@ async def get_location(loc_name: str) -> str: ) ] ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1') + ], + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), ModelRequest( parts=[ ToolReturnPart( @@ -1337,12 +1343,6 @@ async def get_location(loc_name: str) -> str: ) ] ), - ModelResponse( - parts=[ - ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1') - ], - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), ] ) diff --git a/tests/test_agent.py b/tests/test_agent.py index a079e245e9..297513ec48 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1060,13 +1060,13 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6dc06a007e..667f8fe858 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -6,6 +6,7 @@ import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ( @@ -214,15 +215,6 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ToolReturnPart(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc))] ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ) - ] - ), ] ) assert await result.get_data() == snapshot(('hello world', 2)) @@ -236,15 +228,6 @@ async def ret_a(x: str) -> str: ModelRequest( parts=[ToolReturnPart(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc))] ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ) - ] - ), ModelResponse( parts=[ ToolCallPart( @@ -254,6 +237,15 @@ async def ret_a(x: str) -> str: ], timestamp=IsNow(tz=timezone.utc), ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), ] ) @@ -299,3 +291,274 @@ async def ret_a(x: str) -> str: # pragma: no cover ), ] ) + + +class ResultType(BaseModel): + """Result type used by all tests.""" + + value: str + + +async def test_early_strategy_stops_after_first_final_result(): + """Test that 'early' strategy stops processing regular tools after first final result.""" + tool_called: list[str] = [] + + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + assert info.result_tools is not None + yield {1: DeltaToolCall('final_result', '{"value": "final"}')} + yield {2: DeltaToolCall('regular_tool', '{"x": 1}')} + yield {3: DeltaToolCall('another_tool', '{"y": 2}')} + + agent = Agent(FunctionModel(stream_function=sf), result_type=ResultType, end_strategy='early') + + @agent.tool_plain + def regular_tool(x: int) -> int: # pragma: no cover + """A regular tool that should not be called.""" + tool_called.append('regular_tool') + return x + + @agent.tool_plain + def another_tool(y: int) -> int: # pragma: no cover + """Another tool that should not be called.""" + tool_called.append('another_tool') + return y + + async with agent.run_stream('test early strategy') as result: + response = await result.get_data() + assert response.value == snapshot('final') + messages = result.all_messages() + + # Verify no tools were called after final result + assert tool_called == [] + + # Verify we got tool returns for all calls + assert messages == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='test early strategy', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"value": "final"}')), + ToolCallPart(tool_name='regular_tool', args=ArgsJson(args_json='{"x": 1}')), + ToolCallPart(tool_name='another_tool', args=ArgsJson(args_json='{"y": 2}')), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart( + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart( + tool_name='another_tool', + content='Tool not executed - a final result was already processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ] + ), + ] + ) + + +async def test_early_strategy_uses_first_final_result(): + """Test that 'early' strategy uses the first final result and ignores subsequent ones.""" + + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + assert info.result_tools is not None + yield {1: DeltaToolCall('final_result', '{"value": "first"}')} + yield {2: DeltaToolCall('final_result', '{"value": "second"}')} + + agent = Agent(FunctionModel(stream_function=sf), result_type=ResultType, end_strategy='early') + + async with agent.run_stream('test multiple final results') as result: + response = await result.get_data() + assert response.value == snapshot('first') + messages = result.all_messages() + + # Verify we got appropriate tool returns + assert messages == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='test multiple final results', timestamp=IsNow(tz=timezone.utc))] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"value": "first"}')), + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"value": "second"}')), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) + ), + ToolReturnPart( + tool_name='final_result', + content='Result tool not used - a final result was already processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ] + ), + ] + ) + + +async def test_exhaustive_strategy_executes_all_tools(): + """Test that 'exhaustive' strategy executes all tools while using first final result.""" + tool_called: list[str] = [] + + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + assert info.result_tools is not None + yield {1: DeltaToolCall('final_result', '{"value": "first"}')} + yield {2: DeltaToolCall('regular_tool', '{"x": 42}')} + yield {3: DeltaToolCall('another_tool', '{"y": 2}')} + yield {4: DeltaToolCall('final_result', '{"value": "second"}')} + yield {5: DeltaToolCall('unknown_tool', '{"value": "???"}')} + + agent = Agent(FunctionModel(stream_function=sf), result_type=ResultType, end_strategy='exhaustive') + + @agent.tool_plain + def regular_tool(x: int) -> int: + """A regular tool that should be called.""" + tool_called.append('regular_tool') + return x + + @agent.tool_plain + def another_tool(y: int) -> int: + """Another tool that should be called.""" + tool_called.append('another_tool') + return y + + async with agent.run_stream('test exhaustive strategy') as result: + response = await result.get_data() + assert response.value == snapshot('first') + messages = result.all_messages() + + # Verify we got tool returns in the correct order + assert messages == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"value": "first"}')), + ToolCallPart(tool_name='regular_tool', args=ArgsJson(args_json='{"x": 42}')), + ToolCallPart(tool_name='another_tool', args=ArgsJson(args_json='{"y": 2}')), + ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"value": "second"}')), + ToolCallPart(tool_name='unknown_tool', args=ArgsJson(args_json='{"value": "???"}')), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart( + tool_name='final_result', + content='Result tool not used - a final result was already processed.', + timestamp=IsNow(tz=timezone.utc), + ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + timestamp=IsNow(tz=timezone.utc), + ), + ToolReturnPart(tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc)), + ToolReturnPart(tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc)), + ] + ), + ] + ) + + +@pytest.mark.xfail(reason='final result tool not first is not yet supported') +async def test_early_strategy_with_final_result_in_middle(): + """Test that 'early' strategy stops at first final result, regardless of position.""" + tool_called: list[str] = [] + + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + assert info.result_tools is not None + yield {1: DeltaToolCall('regular_tool', '{"x": 1}')} + yield {2: DeltaToolCall('final_result', '{"value": "final"}')} + yield {3: DeltaToolCall('another_tool', '{"y": 2}')} + yield {4: DeltaToolCall('unknown_tool', '{"value": "???"}')} + + agent = Agent(FunctionModel(stream_function=sf), result_type=ResultType, end_strategy='early') + + @agent.tool_plain + def regular_tool(x: int) -> int: # pragma: no cover + """A regular tool that should not be called.""" + tool_called.append('regular_tool') + return x + + @agent.tool_plain + def another_tool(y: int) -> int: # pragma: no cover + """A tool that should not be called.""" + tool_called.append('another_tool') + return y + + async with agent.run_stream('test early strategy with final result in middle') as result: + response = await result.get_data() + assert response.value == snapshot('first') + messages = result.all_messages() + + # Verify no tools were called + assert tool_called == [] + + # Verify we got appropriate tool returns + assert messages == snapshot() + + +async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(): + """Test that 'early' strategy does not apply to tool calls without final tool.""" + tool_called: list[str] = [] + agent = Agent(TestModel(), result_type=ResultType, end_strategy='early') + + @agent.tool_plain + def regular_tool(x: int) -> int: + """A regular tool that should be called.""" + tool_called.append('regular_tool') + return x + + async with agent.run_stream('test early strategy with regular tool calls') as result: + response = await result.get_data() + assert response.value == snapshot('a') + messages = result.all_messages() + + assert tool_called == ['regular_tool'] + + assert messages == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='test early strategy with regular tool calls', timestamp=IsNow(tz=timezone.utc) + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='regular_tool', args=ArgsDict(args_dict={'x': 0}))], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest(parts=[ToolReturnPart(tool_name='regular_tool', content=0, timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'value': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc) + ) + ] + ), + ] + )