Skip to content

Commit c028211

Browse files
authored
Fix new_messages() and capture_run_messages() when history processors are used (#2921)
1 parent 77b5660 commit c028211

File tree

4 files changed

+377
-63
lines changed

4 files changed

+377
-63
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,8 @@ async def run( # noqa: C901
187187
messages = ctx_messages.messages
188188
ctx_messages.used = True
189189

190-
message_history = _clean_message_history(ctx.state.message_history)
191-
# Add message history to the `capture_run_messages` list, which will be empty at this point
192-
messages.extend(message_history)
190+
# Replace the `capture_run_messages` list with the message history
191+
messages[:] = _clean_message_history(ctx.state.message_history)
193192
# Use the `capture_run_messages` list as the message history so that new messages are added to it
194193
ctx.state.message_history = messages
195194
ctx.deps.new_message_index = len(messages)
@@ -456,7 +455,18 @@ async def _prepare_request(
456455
# This will raise errors for any tool name conflicts
457456
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
458457

459-
message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
458+
original_history = ctx.state.message_history[:]
459+
message_history = await _process_message_history(original_history, ctx.deps.history_processors, run_context)
460+
# Never merge the new `ModelRequest` with the one preceding it, to keep `new_messages()` from accidentally including part of the existing message history
461+
message_history = [*_clean_message_history(message_history[:-1]), message_history[-1]]
462+
# `ctx.state.message_history` is the same list used by `capture_run_messages`, so we should replace its contents, not the reference
463+
ctx.state.message_history[:] = message_history
464+
# Update the new message index to ensure `result.new_messages()` returns the correct messages
465+
ctx.deps.new_message_index -= len(original_history) - len(message_history)
466+
467+
# Do one more cleaning pass to merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
468+
# but don't store it in the message history on state.
469+
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary
460470
message_history = _clean_message_history(message_history)
461471

462472
model_request_parameters = await _prepare_request_parameters(ctx)
@@ -1087,12 +1097,11 @@ def build_agent_graph(
10871097

10881098

10891099
async def _process_message_history(
1090-
state: GraphAgentState,
1100+
messages: list[_messages.ModelMessage],
10911101
processors: Sequence[HistoryProcessor[DepsT]],
10921102
run_context: RunContext[DepsT],
10931103
) -> list[_messages.ModelMessage]:
10941104
"""Process message history through a sequence of processors."""
1095-
messages = state.message_history
10961105
for processor in processors:
10971106
takes_ctx = is_takes_ctx(processor)
10981107

@@ -1110,8 +1119,12 @@ async def _process_message_history(
11101119
sync_processor = cast(_HistoryProcessorSync, processor)
11111120
messages = await run_in_executor(sync_processor, messages)
11121121

1113-
# Replaces the message history in the state with the processed messages
1114-
state.message_history = messages
1122+
if len(messages) == 0:
1123+
raise exceptions.UserError('Processed history cannot be empty.')
1124+
1125+
if not isinstance(messages[-1], _messages.ModelRequest):
1126+
raise exceptions.UserError('Processed history must end with a `ModelRequest`.')
1127+
11151128
return messages
11161129

11171130

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,12 +614,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
614614
instrumentation_settings = None
615615
tracer = NoOpTracer()
616616

617-
graph_deps = _agent_graph.GraphAgentDeps[
618-
AgentDepsT, RunOutputDataT
619-
](
617+
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
620618
user_deps=deps,
621619
prompt=user_prompt,
622-
new_message_index=0, # This will be set in `UserPromptNode` based on the length of the cleaned message history
620+
new_message_index=len(message_history) if message_history else 0,
623621
model=model_used,
624622
model_settings=model_settings,
625623
usage_limits=usage_limits,

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,7 @@ def all_messages_json(self, *, output_tool_return_content: str | None = None) ->
322322
self.all_messages(output_tool_return_content=output_tool_return_content)
323323
)
324324

325-
def new_messages(
326-
self, *, output_tool_return_content: str | None = None
327-
) -> list[_messages.ModelMessage]: # pragma: no cover
325+
def new_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]:
328326
"""Return new messages associated with this run.
329327
330328
Messages from older runs are excluded.

0 commit comments

Comments
 (0)