diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e121ec475a..e81711e8e6 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -167,7 +167,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]): system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] - async def run( + async def run( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]: try: @@ -186,15 +186,6 @@ async def run( # Use the `capture_run_messages` list as the message history so that new messages are added to it ctx.state.message_history = messages - run_context = build_run_context(ctx) - - parts: list[_messages.ModelRequestPart] = [] - if messages: - # Reevaluate any dynamic system prompt parts - await self._reevaluate_dynamic_prompts(messages, run_context) - else: - parts.extend(await self._sys_parts(run_context)) - if (tool_call_results := ctx.deps.tool_call_results) is not None: if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest): # If tool call results were provided, that means the previous run ended on deferred tool calls. @@ -209,21 +200,52 @@ async def run( if not messages: raise exceptions.UserError('Tool call results were provided, but the message history is empty.') + next_message: _messages.ModelRequest | None = None + if messages and (last_message := messages[-1]): if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None: # Drop last message from history and reuse its parts messages.pop() - parts.extend(last_message.parts) + next_message = _messages.ModelRequest(parts=last_message.parts) + + # Extract `UserPromptPart` content from the popped message and add to `ctx.deps.prompt` + user_prompt_parts = [part for part in last_message.parts if isinstance(part, _messages.UserPromptPart)] + if user_prompt_parts: + if len(user_prompt_parts) == 1: + ctx.deps.prompt = user_prompt_parts[0].content + else: + combined_content: list[_messages.UserContent] = [] + for part in user_prompt_parts: + if isinstance(part.content, str): + combined_content.append(part.content) + else: + combined_content.extend(part.content) + ctx.deps.prompt = combined_content elif isinstance(last_message, _messages.ModelResponse): call_tools_node = await self._handle_message_history_model_response(ctx, last_message) if call_tools_node is not None: return call_tools_node - if self.user_prompt is not None: - parts.append(_messages.UserPromptPart(self.user_prompt)) + # Build the run context after `ctx.deps.prompt` has been updated + run_context = build_run_context(ctx) + + parts: list[_messages.ModelRequestPart] = [] + if messages: + await self._reevaluate_dynamic_prompts(messages, run_context) + + if next_message: + await self._reevaluate_dynamic_prompts([next_message], run_context) + else: + parts: list[_messages.ModelRequestPart] = [] + if not messages: + parts.extend(await self._sys_parts(run_context)) + + if self.user_prompt is not None: + parts.append(_messages.UserPromptPart(self.user_prompt)) + + next_message = _messages.ModelRequest(parts=parts) - instructions = await ctx.deps.get_instructions(run_context) - next_message = _messages.ModelRequest(parts, instructions=instructions) + next_message.instructions = await ctx.deps.get_instructions(run_context) return ModelRequestNode[DepsT, NodeRunEndT](request=next_message) diff --git a/tests/test_agent.py b/tests/test_agent.py index 261cdfaee1..6ee8a0225f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2003,28 +2003,53 @@ async def ret_a(x: str) -> str: def test_run_with_history_ending_on_model_request_and_no_user_prompt(): + m = TestModel() + agent = Agent(m) + + @agent.system_prompt(dynamic=True) + async def system_prompt(ctx: RunContext) -> str: + return f'System prompt: user prompt length = {len(ctx.prompt or [])}' + messages: list[ModelMessage] = [ - ModelRequest(parts=[UserPromptPart(content='Hello')], instructions='Original instructions'), + ModelRequest( + parts=[ + SystemPromptPart(content='System prompt', dynamic_ref=system_prompt.__qualname__), + UserPromptPart(content=['Hello', ImageUrl('https://example.com/image.jpg')]), + UserPromptPart(content='How goes it?'), + ], + instructions='Original instructions', + ), ] - m = TestModel() - agent = Agent(m, instructions='New instructions') + @agent.instructions + async def instructions(ctx: RunContext) -> str: + assert ctx.prompt == ['Hello', ImageUrl('https://example.com/image.jpg'), 'How goes it?'] + return 'New instructions' result = agent.run_sync(message_history=messages) assert result.all_messages() == snapshot( [ ModelRequest( parts=[ + SystemPromptPart( + content='System prompt: user prompt length = 3', + timestamp=IsDatetime(), + dynamic_ref=IsStr(), + ), UserPromptPart( - content='Hello', + content=['Hello', ImageUrl(url='https://example.com/image.jpg', identifier='39cfc4')], timestamp=IsDatetime(), - ) + ), + UserPromptPart( + content='How goes it?', + timestamp=IsDatetime(), + ), ], instructions='New instructions', ), ModelResponse( parts=[TextPart(content='success (no tool calls)')], - usage=RequestUsage(input_tokens=51, output_tokens=4), + usage=RequestUsage(input_tokens=61, output_tokens=4), model_name='test', timestamp=IsDatetime(), ),