diff --git a/docs/agents.md b/docs/agents.md index a4e05ae40a..aa02228863 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -494,6 +494,4 @@ with capture_run_messages() as messages: # (2)! _(This example is complete, it can be run "as is")_ !!! note - You may not call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context. - - If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised. + If you call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b8fd0ac5e6..1e5b60b8a6 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -6,7 +6,6 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar -from dataclasses import dataclass, field from types import FrameType from typing import Any, Callable, Generic, Literal, cast, final, overload @@ -60,7 +59,7 @@ @final -@dataclass(init=False) +@dataclasses.dataclass(init=False) class Agent(Generic[AgentDeps, ResultData]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. @@ -100,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]): be merged with this value, with the runtime argument taking priority. """ - _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False) - _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) - _allow_text_result: bool = field(repr=False) - _system_prompts: tuple[str, ...] = field(repr=False) - _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False) - _default_retries: int = field(repr=False) - _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False) - _deps_type: type[AgentDeps] = field(repr=False) - _max_result_retries: int = field(repr=False) - _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False) - _override_model: _utils.Option[models.Model] = field(default=None, repr=False) + _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False) + _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False) + _allow_text_result: bool = dataclasses.field(repr=False) + _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) + _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False) + _default_retries: int = dataclasses.field(repr=False) + _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False) + _deps_type: type[AgentDeps] = dataclasses.field(repr=False) + _max_result_retries: int = dataclasses.field(repr=False) + _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False) + _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) def __init__( self, @@ -836,15 +835,15 @@ async def _prepare_messages( self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps] ) -> list[_messages.ModelMessage]: try: - messages = _messages_ctx_var.get() + ctx_messages = _messages_ctx_var.get() except LookupError: - messages = [] + messages: list[_messages.ModelMessage] = [] else: - if messages: - raise exceptions.UserError( - 'The capture_run_messages() context manager may only be used to wrap ' - 'one call to run(), run_sync(), or run_stream().' - ) + if ctx_messages.used: + messages = [] + else: + messages = ctx_messages.messages + ctx_messages.used = True if message_history: # shallow copy messages @@ -1138,7 +1137,13 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') -_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var') +@dataclasses.dataclass +class _RunMessages: + messages: list[_messages.ModelMessage] + used: bool = False + + +_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var') @contextmanager @@ -1162,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: ``` !!! note - You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context. - If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised. + If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context, + `messages` will represent the messages exchanged during the first call only. """ try: - yield _messages_ctx_var.get() + yield _messages_ctx_var.get().messages except LookupError: messages: list[_messages.ModelMessage] = [] - token = _messages_ctx_var.set(messages) + token = _messages_ctx_var.set(_RunMessages(messages)) try: yield messages finally: _messages_ctx_var.reset(token) -@dataclass +@dataclasses.dataclass class _MarkFinalResult(Generic[ResultData]): """Marker class to indicate that the result is the final result. diff --git a/tests/test_agent.py b/tests/test_agent.py index 0885d20dd3..6b23208771 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1218,15 +1218,44 @@ def test_double_capture_run_messages(set_event_loop: None) -> None: assert messages == [] result = agent.run_sync('Hello') assert result.data == 'success (no tool calls)' - with pytest.raises(UserError) as exc_info: - agent.run_sync('Hello') - assert ( - str(exc_info.value) - == 'The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().' - ) + result2 = agent.run_sync('Hello 2') + assert result2.data == 'success (no tool calls)' + assert messages == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), ModelResponse(parts=[TextPart(content='success (no tool calls)')], timestamp=IsNow(tz=timezone.utc)), ] ) + + +def test_capture_run_messages_tool_agent(set_event_loop: None) -> None: + agent_outer = Agent('test') + agent_inner = Agent(TestModel(custom_result_text='inner agent result')) + + @agent_outer.tool_plain + async def foobar(x: str) -> str: + result_ = await agent_inner.run(x) + return result_.data + + with capture_run_messages() as messages: + result = agent_outer.run_sync('foobar') + + assert result.data == snapshot('{"foobar":"inner agent result"}') + assert messages == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='foobar', args=ArgsDict(args_dict={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart(tool_name='foobar', content='inner agent result', timestamp=IsNow(tz=timezone.utc)) + ] + ), + ModelResponse( + parts=[TextPart(content='{"foobar":"inner agent result"}')], timestamp=IsNow(tz=timezone.utc) + ), + ] + )