Skip to content

Commit ccd5e6a

Browse files
authored
Fix new_messages() when deferred_tool_results is used with message_history ending in ToolReturnParts (#2913)
1 parent 9b7c429 commit ccd5e6a

File tree

13 files changed

+572
-198
lines changed

13 files changed

+572
-198
lines changed

docs/agents.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ async def main():
280280
[
281281
UserPromptNode(
282282
user_prompt='What is the capital of France?',
283-
instructions=None,
284283
instructions_functions=[],
285284
system_prompts=(),
286285
system_prompt_functions=[],
@@ -343,7 +342,6 @@ async def main():
343342
[
344343
UserPromptNode(
345344
user_prompt='What is the capital of France?',
346-
instructions=None,
347345
instructions_functions=[],
348346
system_prompts=(),
349347
system_prompt_functions=[],

docs/deferred-tools.md

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,21 @@ print(result.all_messages())
132132
),
133133
ModelRequest(
134134
parts=[
135-
ToolReturnPart(
136-
tool_name='delete_file',
137-
content='Deleting files is not allowed',
138-
tool_call_id='delete_file',
139-
timestamp=datetime.datetime(...),
140-
),
141135
ToolReturnPart(
142136
tool_name='update_file',
143137
content="File 'README.md' updated: 'Hello, world!'",
144138
tool_call_id='update_file_readme',
145139
timestamp=datetime.datetime(...),
140+
)
141+
]
142+
),
143+
ModelRequest(
144+
parts=[
145+
ToolReturnPart(
146+
tool_name='delete_file',
147+
content='Deleting files is not allowed',
148+
tool_call_id='delete_file',
149+
timestamp=datetime.datetime(...),
146150
),
147151
ToolReturnPart(
148152
tool_name='update_file',

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 173 additions & 114 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,11 @@
4545
from ..settings import ModelSettings, merge_model_settings
4646
from ..tools import (
4747
AgentDepsT,
48-
DeferredToolCallResult,
49-
DeferredToolResult,
5048
DeferredToolResults,
5149
DocstringFormat,
5250
GenerateToolJsonSchema,
5351
RunContext,
5452
Tool,
55-
ToolApproved,
56-
ToolDenied,
5753
ToolFuncContext,
5854
ToolFuncEither,
5955
ToolFuncPlain,
@@ -462,7 +458,7 @@ def iter(
462458
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
463459

464460
@asynccontextmanager
465-
async def iter( # noqa: C901
461+
async def iter(
466462
self,
467463
user_prompt: str | Sequence[_messages.UserContent] | None = None,
468464
*,
@@ -505,7 +501,6 @@ async def main():
505501
[
506502
UserPromptNode(
507503
user_prompt='What is the capital of France?',
508-
instructions=None,
509504
instructions_functions=[],
510505
system_prompts=(),
511506
system_prompt_functions=[],
@@ -559,7 +554,6 @@ async def main():
559554
del model
560555

561556
deps = self._get_deps(deps)
562-
new_message_index = len(message_history) if message_history else 0
563557
output_schema = self._prepare_output_schema(output_type, model_used.profile)
564558

565559
output_type_ = output_type or self.output_type
@@ -620,27 +614,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
620614
instrumentation_settings = None
621615
tracer = NoOpTracer()
622616

623-
tool_call_results: dict[str, DeferredToolResult] | None = None
624-
if deferred_tool_results is not None:
625-
tool_call_results = {}
626-
for tool_call_id, approval in deferred_tool_results.approvals.items():
627-
if approval is True:
628-
approval = ToolApproved()
629-
elif approval is False:
630-
approval = ToolDenied()
631-
tool_call_results[tool_call_id] = approval
632-
633-
if calls := deferred_tool_results.calls:
634-
call_result_types = _utils.get_union_args(DeferredToolCallResult)
635-
for tool_call_id, result in calls.items():
636-
if not isinstance(result, call_result_types):
637-
result = _messages.ToolReturn(result)
638-
tool_call_results[tool_call_id] = result
639-
640-
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
617+
graph_deps = _agent_graph.GraphAgentDeps[
618+
AgentDepsT, RunOutputDataT
619+
](
641620
user_deps=deps,
642621
prompt=user_prompt,
643-
new_message_index=new_message_index,
622+
new_message_index=0, # This will be set in `UserPromptNode` based on the length of the cleaned message history
644623
model=model_used,
645624
model_settings=model_settings,
646625
usage_limits=usage_limits,
@@ -651,13 +630,13 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
651630
history_processors=self.history_processors,
652631
builtin_tools=list(self._builtin_tools),
653632
tool_manager=tool_manager,
654-
tool_call_results=tool_call_results,
655633
tracer=tracer,
656634
get_instructions=get_instructions,
657635
instrumentation_settings=instrumentation_settings,
658636
)
659637
start_node = _agent_graph.UserPromptNode[AgentDepsT](
660638
user_prompt=user_prompt,
639+
deferred_tool_results=deferred_tool_results,
661640
instructions=self._instructions,
662641
instructions_functions=self._instructions_functions,
663642
system_prompts=self._system_prompts,

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -499,12 +499,13 @@ async def on_complete() -> None:
499499
]
500500

501501
parts: list[_messages.ModelRequestPart] = []
502-
async for _event in _agent_graph.process_function_tools(
503-
graph_ctx.deps.tool_manager,
504-
tool_calls,
505-
final_result,
506-
graph_ctx,
507-
parts,
502+
async for _event in _agent_graph.process_tool_calls(
503+
tool_manager=graph_ctx.deps.tool_manager,
504+
tool_calls=tool_calls,
505+
tool_call_results=None,
506+
final_result=final_result,
507+
ctx=graph_ctx,
508+
output_parts=parts,
508509
):
509510
pass
510511
if parts:
@@ -621,7 +622,6 @@ async def main():
621622
[
622623
UserPromptNode(
623624
user_prompt='What is the capital of France?',
624-
instructions=None,
625625
instructions_functions=[],
626626
system_prompts=(),
627627
system_prompt_functions=[],

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ async def main():
144144
[
145145
UserPromptNode(
146146
user_prompt='What is the capital of France?',
147-
instructions=None,
148147
instructions_functions=[],
149148
system_prompts=(),
150149
system_prompt_functions=[],

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,6 @@ async def main():
627627
[
628628
UserPromptNode(
629629
user_prompt='What is the capital of France?',
630-
instructions=None,
631630
instructions_functions=[],
632631
system_prompts=(),
633632
system_prompt_functions=[],

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,6 @@ async def main():
660660
[
661661
UserPromptNode(
662662
user_prompt='What is the capital of France?',
663-
instructions=None,
664663
instructions_functions=[],
665664
system_prompts=(),
666665
system_prompt_functions=[],

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ async def main():
4848
[
4949
UserPromptNode(
5050
user_prompt='What is the capital of France?',
51-
instructions=None,
5251
instructions_functions=[],
5352
system_prompts=(),
5453
system_prompt_functions=[],
@@ -183,7 +182,6 @@ async def main():
183182
[
184183
UserPromptNode(
185184
user_prompt='What is the capital of France?',
186-
instructions=None,
187185
instructions_functions=[],
188186
system_prompts=(),
189187
system_prompt_functions=[],

tests/test_a2a.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,10 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon
623623
content='Final result processed.',
624624
tool_call_id=IsStr(),
625625
timestamp=IsDatetime(),
626-
)
627-
]
626+
),
627+
UserPromptPart(content='Second message', timestamp=IsDatetime()),
628+
],
628629
),
629-
ModelRequest(parts=[UserPromptPart(content='Second message', timestamp=IsDatetime())]),
630630
]
631631
)
632632

0 commit comments

Comments
 (0)