Skip to content

Commit b26a687

Browse files
DouweMclaude[bot]
andauthored
When starting run with message history ending in ModelRequest, make its content available in RunContext.prompt (#2891)
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Douwe Maan <[email protected]>
1 parent 8d4889c commit b26a687

File tree

2 files changed

+68
-21
lines changed

2 files changed

+68
-21
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
167167
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
168168
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
169169

170-
async def run(
170+
async def run( # noqa: C901
171171
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
172172
) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]:
173173
try:
@@ -186,15 +186,6 @@ async def run(
186186
# Use the `capture_run_messages` list as the message history so that new messages are added to it
187187
ctx.state.message_history = messages
188188

189-
run_context = build_run_context(ctx)
190-
191-
parts: list[_messages.ModelRequestPart] = []
192-
if messages:
193-
# Reevaluate any dynamic system prompt parts
194-
await self._reevaluate_dynamic_prompts(messages, run_context)
195-
else:
196-
parts.extend(await self._sys_parts(run_context))
197-
198189
if (tool_call_results := ctx.deps.tool_call_results) is not None:
199190
if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
200191
# If tool call results were provided, that means the previous run ended on deferred tool calls.
@@ -209,21 +200,52 @@ async def run(
209200
if not messages:
210201
raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
211202

203+
next_message: _messages.ModelRequest | None = None
204+
212205
if messages and (last_message := messages[-1]):
213206
if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
214207
# Drop last message from history and reuse its parts
215208
messages.pop()
216-
parts.extend(last_message.parts)
209+
next_message = _messages.ModelRequest(parts=last_message.parts)
210+
211+
# Extract `UserPromptPart` content from the popped message and add to `ctx.deps.prompt`
212+
user_prompt_parts = [part for part in last_message.parts if isinstance(part, _messages.UserPromptPart)]
213+
if user_prompt_parts:
214+
if len(user_prompt_parts) == 1:
215+
ctx.deps.prompt = user_prompt_parts[0].content
216+
else:
217+
combined_content: list[_messages.UserContent] = []
218+
for part in user_prompt_parts:
219+
if isinstance(part.content, str):
220+
combined_content.append(part.content)
221+
else:
222+
combined_content.extend(part.content)
223+
ctx.deps.prompt = combined_content
217224
elif isinstance(last_message, _messages.ModelResponse):
218225
call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
219226
if call_tools_node is not None:
220227
return call_tools_node
221228

222-
if self.user_prompt is not None:
223-
parts.append(_messages.UserPromptPart(self.user_prompt))
229+
# Build the run context after `ctx.deps.prompt` has been updated
230+
run_context = build_run_context(ctx)
231+
232+
parts: list[_messages.ModelRequestPart] = []
233+
if messages:
234+
await self._reevaluate_dynamic_prompts(messages, run_context)
235+
236+
if next_message:
237+
await self._reevaluate_dynamic_prompts([next_message], run_context)
238+
else:
239+
parts: list[_messages.ModelRequestPart] = []
240+
if not messages:
241+
parts.extend(await self._sys_parts(run_context))
242+
243+
if self.user_prompt is not None:
244+
parts.append(_messages.UserPromptPart(self.user_prompt))
245+
246+
next_message = _messages.ModelRequest(parts=parts)
224247

225-
instructions = await ctx.deps.get_instructions(run_context)
226-
next_message = _messages.ModelRequest(parts, instructions=instructions)
248+
next_message.instructions = await ctx.deps.get_instructions(run_context)
227249

228250
return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
229251

tests/test_agent.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,28 +2003,53 @@ async def ret_a(x: str) -> str:
20032003

20042004

20052005
def test_run_with_history_ending_on_model_request_and_no_user_prompt():
2006+
m = TestModel()
2007+
agent = Agent(m)
2008+
2009+
@agent.system_prompt(dynamic=True)
2010+
async def system_prompt(ctx: RunContext) -> str:
2011+
return f'System prompt: user prompt length = {len(ctx.prompt or [])}'
2012+
20062013
messages: list[ModelMessage] = [
2007-
ModelRequest(parts=[UserPromptPart(content='Hello')], instructions='Original instructions'),
2014+
ModelRequest(
2015+
parts=[
2016+
SystemPromptPart(content='System prompt', dynamic_ref=system_prompt.__qualname__),
2017+
UserPromptPart(content=['Hello', ImageUrl('https://example.com/image.jpg')]),
2018+
UserPromptPart(content='How goes it?'),
2019+
],
2020+
instructions='Original instructions',
2021+
),
20082022
]
20092023

2010-
m = TestModel()
2011-
agent = Agent(m, instructions='New instructions')
2024+
@agent.instructions
2025+
async def instructions(ctx: RunContext) -> str:
2026+
assert ctx.prompt == ['Hello', ImageUrl('https://example.com/image.jpg'), 'How goes it?']
2027+
return 'New instructions'
20122028

20132029
result = agent.run_sync(message_history=messages)
20142030
assert result.all_messages() == snapshot(
20152031
[
20162032
ModelRequest(
20172033
parts=[
2034+
SystemPromptPart(
2035+
content='System prompt: user prompt length = 3',
2036+
timestamp=IsDatetime(),
2037+
dynamic_ref=IsStr(),
2038+
),
20182039
UserPromptPart(
2019-
content='Hello',
2040+
content=['Hello', ImageUrl(url='https://example.com/image.jpg', identifier='39cfc4')],
20202041
timestamp=IsDatetime(),
2021-
)
2042+
),
2043+
UserPromptPart(
2044+
content='How goes it?',
2045+
timestamp=IsDatetime(),
2046+
),
20222047
],
20232048
instructions='New instructions',
20242049
),
20252050
ModelResponse(
20262051
parts=[TextPart(content='success (no tool calls)')],
2027-
usage=RequestUsage(input_tokens=51, output_tokens=4),
2052+
usage=RequestUsage(input_tokens=61, output_tokens=4),
20282053
model_name='test',
20292054
timestamp=IsDatetime(),
20302055
),

0 commit comments

Comments
 (0)