52 changes: 37 additions & 15 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -167,7 +167,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]

async def run(
async def run( # noqa: C901
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
try:
@@ -186,15 +186,6 @@ async def run(
# Use the `capture_run_messages` list as the message history so that new messages are added to it
ctx.state.message_history = messages

run_context = build_run_context(ctx)

parts: list[_messages.ModelRequestPart] = []
if messages:
# Reevaluate any dynamic system prompt parts
await self._reevaluate_dynamic_prompts(messages, run_context)
else:
parts.extend(await self._sys_parts(run_context))

if (tool_call_results := ctx.deps.tool_call_results) is not None:
if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
# If tool call results were provided, that means the previous run ended on deferred tool calls.
@@ -209,21 +200,52 @@
if not messages:
raise exceptions.UserError('Tool call results were provided, but the message history is empty.')

next_message: _messages.ModelRequest | None = None

if messages and (last_message := messages[-1]):
if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
# Drop last message from history and reuse its parts
messages.pop()
parts.extend(last_message.parts)
next_message = _messages.ModelRequest(parts=last_message.parts)

# Extract `UserPromptPart` content from the popped message and add to `ctx.deps.prompt`
user_prompt_parts = [part for part in last_message.parts if isinstance(part, _messages.UserPromptPart)]
if user_prompt_parts:
if len(user_prompt_parts) == 1:
ctx.deps.prompt = user_prompt_parts[0].content
else:
combined_content: list[_messages.UserContent] = []
for part in user_prompt_parts:
if isinstance(part.content, str):
combined_content.append(part.content)
else:
combined_content.extend(part.content)
ctx.deps.prompt = combined_content
elif isinstance(last_message, _messages.ModelResponse):
call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
if call_tools_node is not None:
return call_tools_node

if self.user_prompt is not None:
parts.append(_messages.UserPromptPart(self.user_prompt))
# Build the run context after `ctx.deps.prompt` has been updated
run_context = build_run_context(ctx)

parts: list[_messages.ModelRequestPart] = []
if messages:
await self._reevaluate_dynamic_prompts(messages, run_context)

if next_message:
await self._reevaluate_dynamic_prompts([next_message], run_context)
else:
parts: list[_messages.ModelRequestPart] = []
if not messages:
parts.extend(await self._sys_parts(run_context))

if self.user_prompt is not None:
parts.append(_messages.UserPromptPart(self.user_prompt))

next_message = _messages.ModelRequest(parts=parts)

instructions = await ctx.deps.get_instructions(run_context)
next_message = _messages.ModelRequest(parts, instructions=instructions)
next_message.instructions = await ctx.deps.get_instructions(run_context)

return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)

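For context, a minimal usage sketch of the behavior this change enables — running an agent whose message history ends on a ModelRequest with no new user prompt, so the trailing request's parts are reused rather than duplicated. This is an illustration, not part of the diff; it assumes the public pydantic_ai API matches the imports used in the test below, and that the result exposes `.output` as in recent releases.

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage, ModelRequest, UserPromptPart
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel(), instructions='New instructions')

    # History ending on a ModelRequest and no new user prompt: the trailing
    # request's parts are reused for the next model call instead of being
    # sent again as a separate request.
    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='Hello')]),
    ]

    result = agent.run_sync(message_history=history)
    print(result.output)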
37 changes: 31 additions & 6 deletions tests/test_agent.py
@@ -2003,28 +2003,53 @@ async def ret_a(x: str) -> str:


def test_run_with_history_ending_on_model_request_and_no_user_prompt():
m = TestModel()
agent = Agent(m)

@agent.system_prompt(dynamic=True)
async def system_prompt(ctx: RunContext) -> str:
return f'System prompt: user prompt length = {len(ctx.prompt or [])}'

messages: list[ModelMessage] = [
ModelRequest(parts=[UserPromptPart(content='Hello')], instructions='Original instructions'),
ModelRequest(
parts=[
SystemPromptPart(content='System prompt', dynamic_ref=system_prompt.__qualname__),
UserPromptPart(content=['Hello', ImageUrl('https://example.com/image.jpg')]),
UserPromptPart(content='How goes it?'),
],
instructions='Original instructions',
),
]

m = TestModel()
agent = Agent(m, instructions='New instructions')
@agent.instructions
async def instructions(ctx: RunContext) -> str:
assert ctx.prompt == ['Hello', ImageUrl('https://example.com/image.jpg'), 'How goes it?']
return 'New instructions'

result = agent.run_sync(message_history=messages)
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
SystemPromptPart(
content='System prompt: user prompt length = 3',
timestamp=IsDatetime(),
dynamic_ref=IsStr(),
),
UserPromptPart(
content='Hello',
content=['Hello', ImageUrl(url='https://example.com/image.jpg', identifier='39cfc4')],
timestamp=IsDatetime(),
)
),
UserPromptPart(
content='How goes it?',
timestamp=IsDatetime(),
),
],
instructions='New instructions',
),
ModelResponse(
parts=[TextPart(content='success (no tool calls)')],
usage=RequestUsage(input_tokens=51, output_tokens=4),
usage=RequestUsage(input_tokens=61, output_tokens=4),
model_name='test',
timestamp=IsDatetime(),
),
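The test above also pins down how the user content extracted from the reused request becomes visible to prompt functions via `ctx.prompt`. A minimal sketch of that pattern, assuming the same public API the test relies on (Agent, RunContext, TestModel, and the message classes); the function and variable names here are illustrative only.

    from pydantic_ai import Agent, RunContext
    from pydantic_ai.messages import ModelMessage, ModelRequest, UserPromptPart
    from pydantic_ai.models.test import TestModel

    agent = Agent(TestModel())

    @agent.instructions
    def show_prompt(ctx: RunContext) -> str:
        # ctx.prompt reflects the UserPromptPart content pulled from the
        # reused trailing ModelRequest in the provided message history.
        return f'The user previously said: {ctx.prompt!r}'

    history: list[ModelMessage] = [
        ModelRequest(parts=[UserPromptPart(content='How goes it?')]),
    ]

    result = agent.run_sync(message_history=history)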